Skip to main content

synth_ai_core/streaming/
streamer.rs

1//! Job streamer for polling and streaming job events.
2
3use super::{
4    config::StreamConfig,
5    endpoints::StreamEndpoints,
6    handler::StreamHandler,
7    types::{StreamMessage, StreamType},
8};
9use crate::errors::CoreError;
10use crate::http::HttpClient;
11use serde_json::Value;
12use std::collections::HashSet;
13use std::sync::Arc;
14
/// Per-request HTTP timeout, in seconds, passed to every `HttpClient` this module creates.
const DEFAULT_TIMEOUT_SECS: u64 = 60;

/// Job `status` strings that end streaming.
///
/// Matching is exact (case-sensitive). Once a polled status equals one of these,
/// the streaming loops notify handlers of the end and stop polling.
const TERMINAL_STATUSES: &[&str] = &[
    // Finished normally.
    "succeeded",
    "completed",
    // Finished abnormally.
    "failed",
    "error",
    // Stopped; both spellings are accepted.
    "cancelled",
    "canceled",
];
27
28/// Job streamer that polls endpoints and dispatches to handlers.
29pub struct JobStreamer {
30    base_url: String,
31    api_key: String,
32    job_id: String,
33    endpoints: StreamEndpoints,
34    config: StreamConfig,
35    handlers: Vec<Arc<dyn StreamHandler>>,
36    seen_messages: HashSet<String>,
37    last_event_seq: Option<i64>,
38}
39
40impl JobStreamer {
41    /// Create a new job streamer.
42    pub fn new(
43        base_url: impl Into<String>,
44        api_key: impl Into<String>,
45        job_id: impl Into<String>,
46    ) -> Self {
47        let job_id = job_id.into();
48        Self {
49            base_url: base_url.into().trim_end_matches('/').to_string(),
50            api_key: api_key.into(),
51            job_id: job_id.clone(),
52            endpoints: StreamEndpoints::learning(&job_id),
53            config: StreamConfig::default(),
54            handlers: vec![],
55            seen_messages: HashSet::new(),
56            last_event_seq: None,
57        }
58    }
59
60    /// Set the stream endpoints.
61    pub fn with_endpoints(mut self, endpoints: StreamEndpoints) -> Self {
62        self.endpoints = endpoints;
63        self
64    }
65
66    /// Set the stream config.
67    pub fn with_config(mut self, config: StreamConfig) -> Self {
68        self.config = config;
69        self
70    }
71
72    /// Add a handler.
73    pub fn with_handler(mut self, handler: Arc<dyn StreamHandler>) -> Self {
74        self.handlers.push(handler);
75        self
76    }
77
78    /// Add a handler (convenience method).
79    pub fn add_handler<H: StreamHandler + 'static>(&mut self, handler: H) {
80        self.handlers.push(Arc::new(handler));
81    }
82
83    /// Poll status once.
84    pub async fn poll_status(&mut self) -> Result<Option<Value>, CoreError> {
85        let client = self.create_client()?;
86
87        for endpoint in self.endpoints.all_status_endpoints() {
88            match client.get::<Value>(endpoint, None).await {
89                Ok(status) => {
90                    self.dispatch_status(&status);
91                    return Ok(Some(status));
92                }
93                Err(e) => {
94                    // Check if it's a 404 - try next fallback
95                    if let Some(404) = e.status() {
96                        continue;
97                    }
98                    return Err(e.into());
99                }
100            }
101        }
102
103        Ok(None)
104    }
105
106    /// Poll events once.
107    pub async fn poll_events(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
108        if !self.config.is_stream_enabled(StreamType::Events) {
109            return Ok(vec![]);
110        }
111
112        let client = self.create_client()?;
113        let mut all_messages = vec![];
114
115        for endpoint in self.endpoints.all_event_endpoints() {
116            // Add since_seq parameter if we have one
117            let url = if let Some(seq) = self.last_event_seq {
118                format!("{}?since_seq={}", endpoint, seq)
119            } else {
120                endpoint.to_string()
121            };
122
123            match client.get::<Value>(&url, None).await {
124                Ok(response) => {
125                    if let Some(events) = response.get("events").and_then(|v| v.as_array()) {
126                        for event in events {
127                            if self.config.should_include_event(event) {
128                                let seq = event.get("seq").and_then(|v| v.as_i64());
129                                let msg = StreamMessage::event(&self.job_id, event.clone(), seq.unwrap_or(0));
130
131                                // Update last seen seq
132                                if let Some(s) = seq {
133                                    self.last_event_seq = Some(self.last_event_seq.map(|l| l.max(s)).unwrap_or(s));
134                                }
135
136                                self.dispatch_message(&msg);
137                                all_messages.push(msg);
138                            }
139                        }
140                    }
141                    break; // Success, don't try fallbacks
142                }
143                Err(e) => {
144                    if let Some(404) = e.status() {
145                        continue; // Try next fallback
146                    }
147                    return Err(e.into());
148                }
149            }
150        }
151
152        Ok(all_messages)
153    }
154
155    /// Poll metrics once.
156    pub async fn poll_metrics(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
157        if !self.config.is_stream_enabled(StreamType::Metrics) {
158            return Ok(vec![]);
159        }
160
161        let client = self.create_client()?;
162        let mut all_messages = vec![];
163
164        if let Some(ref endpoint) = self.endpoints.metrics {
165            match client.get::<Value>(endpoint, None).await {
166                Ok(response) => {
167                    if let Some(metrics) = response.get("metrics").and_then(|v| v.as_array()) {
168                        for metric in metrics {
169                            if self.config.should_include_metric(metric) {
170                                let step = metric.get("step").and_then(|v| v.as_i64()).unwrap_or(0);
171                                let msg = StreamMessage::metrics(&self.job_id, metric.clone(), step);
172                                self.dispatch_message(&msg);
173                                all_messages.push(msg);
174                            }
175                        }
176                    }
177                }
178                Err(e) => {
179                    if e.status() != Some(404) {
180                        return Err(e.into());
181                    }
182                }
183            }
184        }
185
186        Ok(all_messages)
187    }
188
189    /// Stream until the job reaches a terminal state.
190    pub async fn stream_until_terminal(&mut self) -> Result<Value, CoreError> {
191        // Notify handlers of start
192        for handler in &self.handlers {
193            handler.on_start(&self.job_id);
194        }
195
196        loop {
197            // Poll status
198            if let Some(status) = self.poll_status().await? {
199                if Self::is_terminal(&status) {
200                    let final_status = status.get("status").and_then(|v| v.as_str());
201
202                    // Notify handlers of end
203                    for handler in &self.handlers {
204                        handler.on_end(&self.job_id, final_status);
205                        handler.flush();
206                    }
207
208                    return Ok(status);
209                }
210            }
211
212            // Poll events
213            let _ = self.poll_events().await?;
214
215            // Poll metrics (less frequently)
216            let _ = self.poll_metrics().await?;
217
218            // Wait before next poll
219            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
220                self.config.poll_interval_seconds,
221            ))
222            .await;
223        }
224    }
225
226    /// Stream for a maximum duration, returning early if terminal.
227    pub async fn stream_for_duration(
228        &mut self,
229        max_seconds: f64,
230    ) -> Result<Option<Value>, CoreError> {
231        let start = std::time::Instant::now();
232        let max_duration = std::time::Duration::from_secs_f64(max_seconds);
233
234        for handler in &self.handlers {
235            handler.on_start(&self.job_id);
236        }
237
238        loop {
239            if start.elapsed() >= max_duration {
240                for handler in &self.handlers {
241                    handler.on_end(&self.job_id, Some("timeout"));
242                    handler.flush();
243                }
244                return Ok(None);
245            }
246
247            if let Some(status) = self.poll_status().await? {
248                if Self::is_terminal(&status) {
249                    let final_status = status.get("status").and_then(|v| v.as_str());
250                    for handler in &self.handlers {
251                        handler.on_end(&self.job_id, final_status);
252                        handler.flush();
253                    }
254                    return Ok(Some(status));
255                }
256            }
257
258            let _ = self.poll_events().await?;
259            let _ = self.poll_metrics().await?;
260
261            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
262                self.config.poll_interval_seconds,
263            ))
264            .await;
265        }
266    }
267
268    fn create_client(&self) -> Result<HttpClient, CoreError> {
269        HttpClient::new(&self.base_url, &self.api_key, DEFAULT_TIMEOUT_SECS)
270            .map_err(|e| CoreError::Internal(format!("Failed to create HTTP client: {}", e)))
271    }
272
273    fn dispatch_status(&mut self, status: &Value) {
274        let msg = StreamMessage::status(&self.job_id, status.clone());
275        self.dispatch_message(&msg);
276    }
277
278    fn dispatch_message(&mut self, message: &StreamMessage) {
279        // Deduplication
280        if self.config.deduplicate {
281            let key = message.key();
282            if self.seen_messages.contains(&key) {
283                return;
284            }
285            self.seen_messages.insert(key);
286        }
287
288        // Dispatch to handlers
289        for handler in &self.handlers {
290            if handler.should_handle(message) {
291                handler.handle(message);
292            }
293        }
294    }
295
296    fn is_terminal(status: &Value) -> bool {
297        status
298            .get("status")
299            .and_then(|v| v.as_str())
300            .map(|s| TERMINAL_STATUSES.contains(&s))
301            .unwrap_or(false)
302    }
303
304    /// Get the job ID.
305    pub fn job_id(&self) -> &str {
306        &self.job_id
307    }
308
309    /// Get the last event sequence number.
310    pub fn last_event_seq(&self) -> Option<i64> {
311        self.last_event_seq
312    }
313
314    /// Clear seen messages (for re-streaming).
315    pub fn clear_seen(&mut self) {
316        self.seen_messages.clear();
317        self.last_event_seq = None;
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal status payload carrying only a `status` field.
    fn payload(status: &str) -> Value {
        serde_json::json!({ "status": status })
    }

    #[test]
    fn test_terminal_detection() {
        for finished in ["succeeded", "failed", "cancelled"] {
            assert!(
                JobStreamer::is_terminal(&payload(finished)),
                "expected {} to be terminal",
                finished
            );
        }
        for active in ["running", "pending"] {
            assert!(
                !JobStreamer::is_terminal(&payload(active)),
                "expected {} to be non-terminal",
                active
            );
        }
    }

    #[test]
    fn test_streamer_creation() {
        let job = "job-123";
        let streamer = JobStreamer::new("https://api.example.com", "sk-test", job)
            .with_config(StreamConfig::minimal())
            .with_endpoints(StreamEndpoints::prompt_learning(job));

        assert_eq!(streamer.job_id(), job);
        assert_eq!(streamer.last_event_seq(), None);
    }

    #[test]
    fn test_clear_seen() {
        let mut streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123");
        streamer.seen_messages.insert(String::from("test"));
        streamer.last_event_seq = Some(42);

        streamer.clear_seen();

        assert!(streamer.seen_messages.is_empty());
        assert_eq!(streamer.last_event_seq, None);
    }
}