// synth_ai_core/streaming/streamer.rs
use super::{
    config::StreamConfig,
    endpoints::StreamEndpoints,
    handler::StreamHandler,
    types::{StreamMessage, StreamType},
};
use crate::errors::CoreError;
use crate::http::HttpClient;
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;

/// Per-request timeout, in seconds, for the `HttpClient` built on each poll.
const DEFAULT_TIMEOUT_SECS: u64 = 60;

/// Job `status` values that end the streaming loops.
/// Both "cancelled" and "canceled" spellings are accepted; note that
/// "paused" is deliberately treated as terminal (pinned by the unit tests).
const TERMINAL_STATUSES: &[&str] = &[
    "succeeded",
    "failed",
    "cancelled",
    "canceled",
    "completed",
    "error",
    "paused",
];

/// Polls a remote job's status/event/metric/timeline endpoints and fans
/// each accepted message out to the registered [`StreamHandler`]s.
pub struct JobStreamer {
    // Base URL of the API; stored without a trailing slash (see `new`).
    base_url: String,
    // Credential handed to the HTTP client on every poll.
    api_key: String,
    // Identifier of the job being streamed.
    job_id: String,
    // Candidate endpoint sets; each poll tries them in order, skipping 404s.
    endpoints: StreamEndpoints,
    // Controls which streams/messages are polled and included.
    config: StreamConfig,
    // Handlers notified of every dispatched message.
    handlers: Vec<Arc<dyn StreamHandler>>,
    // Keys of messages already dispatched; consulted when `config.deduplicate`.
    seen_messages: HashSet<String>,
    // Highest event `seq` observed so far; sent as `since_seq` on event polls.
    last_event_seq: Option<i64>,
}

impl JobStreamer {
    /// Creates a streamer for `job_id` using the default learning endpoints
    /// and a default [`StreamConfig`]; customize via the `with_*` builders.
    pub fn new(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        job_id: impl Into<String>,
    ) -> Self {
        let job_id = job_id.into();
        Self {
            // Strip trailing slashes so endpoint paths can be appended safely.
            base_url: base_url.into().trim_end_matches('/').to_string(),
            api_key: api_key.into(),
            job_id: job_id.clone(),
            endpoints: StreamEndpoints::learning(&job_id),
            config: StreamConfig::default(),
            handlers: vec![],
            seen_messages: HashSet::new(),
            last_event_seq: None,
        }
    }

    /// Builder: replaces the endpoint set to poll.
    pub fn with_endpoints(mut self, endpoints: StreamEndpoints) -> Self {
        self.endpoints = endpoints;
        self
    }

    /// Builder: replaces the streaming configuration.
    pub fn with_config(mut self, config: StreamConfig) -> Self {
        self.config = config;
        self
    }

    /// Builder: appends an already-`Arc`ed handler.
    pub fn with_handler(mut self, handler: Arc<dyn StreamHandler>) -> Self {
        self.handlers.push(handler);
        self
    }

    /// Appends a handler by value (wrapped in an `Arc` internally).
    pub fn add_handler<H: StreamHandler + 'static>(&mut self, handler: H) {
        self.handlers.push(Arc::new(handler));
    }

    /// Fetches the job status from the first status endpoint that responds.
    ///
    /// Endpoints answering 404 are skipped; any other error is returned.
    /// Returns `Ok(None)` when every candidate endpoint 404s. A successful
    /// status is also dispatched to handlers as a status message.
    pub async fn poll_status(&mut self) -> Result<Option<Value>, CoreError> {
        let client = self.create_client()?;

        for endpoint in self.endpoints.all_status_endpoints() {
            match client.get::<Value>(endpoint, None).await {
                Ok(status) => {
                    self.dispatch_status(&status);
                    return Ok(Some(status));
                }
                Err(e) => {
                    // 404 means this candidate endpoint doesn't exist; try the next.
                    if let Some(404) = e.status() {
                        continue;
                    }
                    return Err(e.into());
                }
            }
        }

        Ok(None)
    }

    /// Polls event endpoints and dispatches/returns new events.
    ///
    /// Only events newer than `last_event_seq` are requested (via a
    /// `since_seq` query parameter). The response may be either an object
    /// with an `"events"` array or a bare array. Stops early once
    /// `config.max_events_per_poll` events have been collected. Only the
    /// first non-404 endpoint is consumed (`break` after success).
    pub async fn poll_events(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
        if !self.config.is_stream_enabled(StreamType::Events) {
            return Ok(vec![]);
        }

        let client = self.create_client()?;
        let mut all_messages = vec![];
        let mut total_events: usize = 0;

        for endpoint in self.endpoints.all_event_endpoints() {
            // NOTE(review): assumes the endpoint has no query string of its
            // own; `?since_seq=` would otherwise be malformed — confirm.
            let url = if let Some(seq) = self.last_event_seq {
                format!("{}?since_seq={}", endpoint, seq)
            } else {
                endpoint.to_string()
            };

            match client.get::<Value>(&url, None).await {
                Ok(response) => {
                    // Accept {"events": [...]} or a bare [...] payload.
                    let events_list = response
                        .get("events")
                        .and_then(|v| v.as_array())
                        .or_else(|| response.as_array());
                    if let Some(events) = events_list {
                        for event in events {
                            if self.config.should_include_event(event) {
                                let seq = event.get("seq").and_then(|v| v.as_i64());
                                // Events without a seq are forwarded with seq 0.
                                let msg = StreamMessage::event(
                                    &self.job_id,
                                    event.clone(),
                                    seq.unwrap_or(0),
                                );

                                // Track the maximum seq seen for the next poll's cursor.
                                if let Some(s) = seq {
                                    self.last_event_seq =
                                        Some(self.last_event_seq.map(|l| l.max(s)).unwrap_or(s));
                                }

                                self.dispatch_message(&msg);
                                all_messages.push(msg);
                                total_events += 1;
                                if let Some(max_events) = self.config.max_events_per_poll {
                                    if total_events >= max_events {
                                        return Ok(all_messages);
                                    }
                                }
                            }
                        }
                    }
                    break;
                }
                Err(e) => {
                    // 404: fall through to the next candidate endpoint.
                    if let Some(404) = e.status() {
                        continue;
                    }
                    return Err(e.into());
                }
            }
        }

        Ok(all_messages)
    }

    /// Polls metric endpoints and dispatches/returns metric messages.
    ///
    /// The response may be `{"points": [...]}`, `{"metrics": [...]}`, or a
    /// bare array. Metrics missing a `"step"` default to step 0. Like the
    /// other pollers, only the first non-404 endpoint is consumed.
    pub async fn poll_metrics(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
        if !self.config.is_stream_enabled(StreamType::Metrics) {
            return Ok(vec![]);
        }

        let client = self.create_client()?;
        let mut all_messages = vec![];

        let metric_endpoints = self.endpoints.all_metric_endpoints();
        if metric_endpoints.is_empty() {
            return Ok(all_messages);
        }

        for endpoint in metric_endpoints {
            match client.get::<Value>(endpoint, None).await {
                Ok(response) => {
                    let mut metrics: Option<&Vec<Value>> = None;
                    if let Some(items) = response.get("points").and_then(|v| v.as_array()) {
                        metrics = Some(items);
                    } else if let Some(items) = response.get("metrics").and_then(|v| v.as_array()) {
                        metrics = Some(items);
                    } else if let Some(items) = response.as_array() {
                        metrics = Some(items);
                    }

                    if let Some(metrics) = metrics {
                        for metric in metrics {
                            if self.config.should_include_metric(metric) {
                                let step = metric.get("step").and_then(|v| v.as_i64()).unwrap_or(0);
                                let msg =
                                    StreamMessage::metrics(&self.job_id, metric.clone(), step);
                                self.dispatch_message(&msg);
                                all_messages.push(msg);
                            }
                        }
                    }
                    break;
                }
                Err(e) => {
                    if let Some(404) = e.status() {
                        continue;
                    }
                    return Err(e.into());
                }
            }
        }

        Ok(all_messages)
    }

    /// Polls timeline endpoints and dispatches/returns timeline messages.
    ///
    /// The response may be `{"events": [...]}`, `{"timeline": [...]}`, or a
    /// bare array. Each entry's `"job_id"` (if present) overrides this
    /// streamer's job id on the emitted message; `"phase"` defaults to "".
    pub async fn poll_timeline(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
        if !self.config.is_stream_enabled(StreamType::Timeline) {
            return Ok(vec![]);
        }

        let client = self.create_client()?;
        let mut all_messages = vec![];
        let timeline_endpoints = self.endpoints.all_timeline_endpoints();
        if timeline_endpoints.is_empty() {
            return Ok(all_messages);
        }

        for endpoint in timeline_endpoints {
            match client.get::<Value>(endpoint, None).await {
                Ok(response) => {
                    let mut entries: Option<&Vec<Value>> = None;
                    if let Some(items) = response.get("events").and_then(|v| v.as_array()) {
                        entries = Some(items);
                    } else if let Some(items) = response.get("timeline").and_then(|v| v.as_array())
                    {
                        entries = Some(items);
                    } else if let Some(items) = response.as_array() {
                        entries = Some(items);
                    }

                    if let Some(entries) = entries {
                        for entry in entries {
                            if !self.config.should_include_timeline(entry) {
                                continue;
                            }
                            let phase = entry.get("phase").and_then(|v| v.as_str()).unwrap_or("");
                            let job_id = entry
                                .get("job_id")
                                .and_then(|v| v.as_str())
                                .unwrap_or(&self.job_id);
                            let msg = StreamMessage::timeline(job_id, phase, entry.clone());
                            self.dispatch_message(&msg);
                            all_messages.push(msg);
                        }
                    }
                    break;
                }
                Err(e) => {
                    if let Some(404) = e.status() {
                        continue;
                    }
                    return Err(e.into());
                }
            }
        }

        Ok(all_messages)
    }

    /// Polls all enabled streams until the job reports a terminal status,
    /// then notifies/flushes handlers and returns the final status JSON.
    ///
    /// Any poll error aborts the loop. There is no timeout — use
    /// [`Self::stream_for_duration`] for a bounded variant.
    pub async fn stream_until_terminal(&mut self) -> Result<Value, CoreError> {
        for handler in &self.handlers {
            handler.on_start(&self.job_id);
        }

        loop {
            if let Some(status) = self.poll_status().await? {
                if Self::is_terminal(&status) {
                    let final_status = status.get("status").and_then(|v| v.as_str());

                    for handler in &self.handlers {
                        handler.on_end(&self.job_id, final_status);
                        handler.flush();
                    }

                    return Ok(status);
                }
            }

            // Messages are delivered to handlers inside the pollers; the
            // returned vectors are intentionally discarded here.
            let _ = self.poll_events().await?;

            let _ = self.poll_metrics().await?;
            let _ = self.poll_timeline().await?;

            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
                self.config.poll_interval_seconds,
            ))
            .await;
        }
    }

    /// Like [`Self::stream_until_terminal`] but gives up after `max_seconds`.
    ///
    /// On timeout, handlers receive `on_end(.., Some("timeout"))` and the
    /// method returns `Ok(None)`; on a terminal status it returns
    /// `Ok(Some(status))`.
    pub async fn stream_for_duration(
        &mut self,
        max_seconds: f64,
    ) -> Result<Option<Value>, CoreError> {
        let start = std::time::Instant::now();
        let max_duration = std::time::Duration::from_secs_f64(max_seconds);

        for handler in &self.handlers {
            handler.on_start(&self.job_id);
        }

        loop {
            // Deadline is checked before each round of polling.
            if start.elapsed() >= max_duration {
                for handler in &self.handlers {
                    handler.on_end(&self.job_id, Some("timeout"));
                    handler.flush();
                }
                return Ok(None);
            }

            if let Some(status) = self.poll_status().await? {
                if Self::is_terminal(&status) {
                    let final_status = status.get("status").and_then(|v| v.as_str());
                    for handler in &self.handlers {
                        handler.on_end(&self.job_id, final_status);
                        handler.flush();
                    }
                    return Ok(Some(status));
                }
            }

            let _ = self.poll_events().await?;
            let _ = self.poll_metrics().await?;
            let _ = self.poll_timeline().await?;

            tokio::time::sleep(tokio::time::Duration::from_secs_f64(
                self.config.poll_interval_seconds,
            ))
            .await;
        }
    }

    // Builds a fresh HTTP client for one polling round; errors are wrapped
    // as `CoreError::Internal`.
    fn create_client(&self) -> Result<HttpClient, CoreError> {
        HttpClient::new(&self.base_url, &self.api_key, DEFAULT_TIMEOUT_SECS)
            .map_err(|e| CoreError::Internal(format!("Failed to create HTTP client: {}", e)))
    }

    // Wraps a raw status payload in a status StreamMessage and dispatches it.
    fn dispatch_status(&mut self, status: &Value) {
        let msg = StreamMessage::status(&self.job_id, status.clone());
        self.dispatch_message(&msg);
    }

    // Routes one message to every handler that accepts it, optionally
    // suppressing duplicates by the message's key.
    fn dispatch_message(&mut self, message: &StreamMessage) {
        if self.config.deduplicate {
            let key = message.key();
            if self.seen_messages.contains(&key) {
                return;
            }
            self.seen_messages.insert(key);
        }

        for handler in &self.handlers {
            if handler.should_handle(message) {
                handler.handle(message);
            }
        }
    }

    // True when the payload's "status" string is one of TERMINAL_STATUSES;
    // a missing or non-string "status" is treated as non-terminal.
    fn is_terminal(status: &Value) -> bool {
        status
            .get("status")
            .and_then(|v| v.as_str())
            .map(|s| TERMINAL_STATUSES.contains(&s))
            .unwrap_or(false)
    }

    /// The job id this streamer was created for.
    pub fn job_id(&self) -> &str {
        &self.job_id
    }

    /// Highest event sequence number seen so far, if any.
    pub fn last_event_seq(&self) -> Option<i64> {
        self.last_event_seq
    }

    /// Resets deduplication state and the event cursor, so the next poll
    /// re-fetches and re-dispatches everything.
    pub fn clear_seen(&mut self) {
        self.seen_messages.clear();
        self.last_event_seq = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every terminal spelling is accepted; in-progress states are rejected.
    #[test]
    fn test_terminal_detection() {
        for status in ["succeeded", "failed", "cancelled", "paused"] {
            assert!(JobStreamer::is_terminal(
                &serde_json::json!({"status": status})
            ));
        }
        for status in ["running", "pending"] {
            assert!(!JobStreamer::is_terminal(
                &serde_json::json!({"status": status})
            ));
        }
    }

    /// Builder chain preserves the job id and starts with no event cursor.
    #[test]
    fn test_streamer_creation() {
        let streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123")
            .with_config(StreamConfig::minimal())
            .with_endpoints(StreamEndpoints::prompt_learning("job-123"));

        assert_eq!("job-123", streamer.job_id());
        assert_eq!(None, streamer.last_event_seq());
    }

    /// `clear_seen` wipes both dedup state and the event cursor.
    #[test]
    fn test_clear_seen() {
        let mut streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123");

        streamer.seen_messages.insert(String::from("test"));
        streamer.last_event_seq = Some(42);

        streamer.clear_seen();

        assert!(streamer.seen_messages.is_empty());
        assert_eq!(None, streamer.last_event_seq);
    }
}