synth_ai_core/streaming/
streamer.rs1use super::{
4 config::StreamConfig,
5 endpoints::StreamEndpoints,
6 handler::StreamHandler,
7 types::{StreamMessage, StreamType},
8};
9use crate::errors::CoreError;
10use crate::http::HttpClient;
11use serde_json::Value;
12use std::collections::HashSet;
13use std::sync::Arc;
14
/// Request timeout, in seconds, passed to every HTTP client this module creates.
const DEFAULT_TIMEOUT_SECS: u64 = 60;

/// Values of a job's `status` field after which streaming stops.
///
/// Both the British ("cancelled") and American ("canceled") spellings are listed so either
/// backend convention ends the stream.
const TERMINAL_STATUSES: &[&str] = &[
    // successful outcomes
    "succeeded",
    "failed",
    // cancellation, both spellings
    "cancelled",
    "canceled",
    // alternate success / failure wording
    "completed",
    "error",
];
27
28pub struct JobStreamer {
30 base_url: String,
31 api_key: String,
32 job_id: String,
33 endpoints: StreamEndpoints,
34 config: StreamConfig,
35 handlers: Vec<Arc<dyn StreamHandler>>,
36 seen_messages: HashSet<String>,
37 last_event_seq: Option<i64>,
38}
39
40impl JobStreamer {
41 pub fn new(
43 base_url: impl Into<String>,
44 api_key: impl Into<String>,
45 job_id: impl Into<String>,
46 ) -> Self {
47 let job_id = job_id.into();
48 Self {
49 base_url: base_url.into().trim_end_matches('/').to_string(),
50 api_key: api_key.into(),
51 job_id: job_id.clone(),
52 endpoints: StreamEndpoints::learning(&job_id),
53 config: StreamConfig::default(),
54 handlers: vec![],
55 seen_messages: HashSet::new(),
56 last_event_seq: None,
57 }
58 }
59
60 pub fn with_endpoints(mut self, endpoints: StreamEndpoints) -> Self {
62 self.endpoints = endpoints;
63 self
64 }
65
66 pub fn with_config(mut self, config: StreamConfig) -> Self {
68 self.config = config;
69 self
70 }
71
72 pub fn with_handler(mut self, handler: Arc<dyn StreamHandler>) -> Self {
74 self.handlers.push(handler);
75 self
76 }
77
78 pub fn add_handler<H: StreamHandler + 'static>(&mut self, handler: H) {
80 self.handlers.push(Arc::new(handler));
81 }
82
83 pub async fn poll_status(&mut self) -> Result<Option<Value>, CoreError> {
85 let client = self.create_client()?;
86
87 for endpoint in self.endpoints.all_status_endpoints() {
88 match client.get::<Value>(endpoint, None).await {
89 Ok(status) => {
90 self.dispatch_status(&status);
91 return Ok(Some(status));
92 }
93 Err(e) => {
94 if let Some(404) = e.status() {
96 continue;
97 }
98 return Err(e.into());
99 }
100 }
101 }
102
103 Ok(None)
104 }
105
106 pub async fn poll_events(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
108 if !self.config.is_stream_enabled(StreamType::Events) {
109 return Ok(vec![]);
110 }
111
112 let client = self.create_client()?;
113 let mut all_messages = vec![];
114
115 for endpoint in self.endpoints.all_event_endpoints() {
116 let url = if let Some(seq) = self.last_event_seq {
118 format!("{}?since_seq={}", endpoint, seq)
119 } else {
120 endpoint.to_string()
121 };
122
123 match client.get::<Value>(&url, None).await {
124 Ok(response) => {
125 if let Some(events) = response.get("events").and_then(|v| v.as_array()) {
126 for event in events {
127 if self.config.should_include_event(event) {
128 let seq = event.get("seq").and_then(|v| v.as_i64());
129 let msg = StreamMessage::event(&self.job_id, event.clone(), seq.unwrap_or(0));
130
131 if let Some(s) = seq {
133 self.last_event_seq = Some(self.last_event_seq.map(|l| l.max(s)).unwrap_or(s));
134 }
135
136 self.dispatch_message(&msg);
137 all_messages.push(msg);
138 }
139 }
140 }
141 break; }
143 Err(e) => {
144 if let Some(404) = e.status() {
145 continue; }
147 return Err(e.into());
148 }
149 }
150 }
151
152 Ok(all_messages)
153 }
154
155 pub async fn poll_metrics(&mut self) -> Result<Vec<StreamMessage>, CoreError> {
157 if !self.config.is_stream_enabled(StreamType::Metrics) {
158 return Ok(vec![]);
159 }
160
161 let client = self.create_client()?;
162 let mut all_messages = vec![];
163
164 if let Some(ref endpoint) = self.endpoints.metrics {
165 match client.get::<Value>(endpoint, None).await {
166 Ok(response) => {
167 if let Some(metrics) = response.get("metrics").and_then(|v| v.as_array()) {
168 for metric in metrics {
169 if self.config.should_include_metric(metric) {
170 let step = metric.get("step").and_then(|v| v.as_i64()).unwrap_or(0);
171 let msg = StreamMessage::metrics(&self.job_id, metric.clone(), step);
172 self.dispatch_message(&msg);
173 all_messages.push(msg);
174 }
175 }
176 }
177 }
178 Err(e) => {
179 if e.status() != Some(404) {
180 return Err(e.into());
181 }
182 }
183 }
184 }
185
186 Ok(all_messages)
187 }
188
189 pub async fn stream_until_terminal(&mut self) -> Result<Value, CoreError> {
191 for handler in &self.handlers {
193 handler.on_start(&self.job_id);
194 }
195
196 loop {
197 if let Some(status) = self.poll_status().await? {
199 if Self::is_terminal(&status) {
200 let final_status = status.get("status").and_then(|v| v.as_str());
201
202 for handler in &self.handlers {
204 handler.on_end(&self.job_id, final_status);
205 handler.flush();
206 }
207
208 return Ok(status);
209 }
210 }
211
212 let _ = self.poll_events().await?;
214
215 let _ = self.poll_metrics().await?;
217
218 tokio::time::sleep(tokio::time::Duration::from_secs_f64(
220 self.config.poll_interval_seconds,
221 ))
222 .await;
223 }
224 }
225
226 pub async fn stream_for_duration(
228 &mut self,
229 max_seconds: f64,
230 ) -> Result<Option<Value>, CoreError> {
231 let start = std::time::Instant::now();
232 let max_duration = std::time::Duration::from_secs_f64(max_seconds);
233
234 for handler in &self.handlers {
235 handler.on_start(&self.job_id);
236 }
237
238 loop {
239 if start.elapsed() >= max_duration {
240 for handler in &self.handlers {
241 handler.on_end(&self.job_id, Some("timeout"));
242 handler.flush();
243 }
244 return Ok(None);
245 }
246
247 if let Some(status) = self.poll_status().await? {
248 if Self::is_terminal(&status) {
249 let final_status = status.get("status").and_then(|v| v.as_str());
250 for handler in &self.handlers {
251 handler.on_end(&self.job_id, final_status);
252 handler.flush();
253 }
254 return Ok(Some(status));
255 }
256 }
257
258 let _ = self.poll_events().await?;
259 let _ = self.poll_metrics().await?;
260
261 tokio::time::sleep(tokio::time::Duration::from_secs_f64(
262 self.config.poll_interval_seconds,
263 ))
264 .await;
265 }
266 }
267
268 fn create_client(&self) -> Result<HttpClient, CoreError> {
269 HttpClient::new(&self.base_url, &self.api_key, DEFAULT_TIMEOUT_SECS)
270 .map_err(|e| CoreError::Internal(format!("Failed to create HTTP client: {}", e)))
271 }
272
273 fn dispatch_status(&mut self, status: &Value) {
274 let msg = StreamMessage::status(&self.job_id, status.clone());
275 self.dispatch_message(&msg);
276 }
277
278 fn dispatch_message(&mut self, message: &StreamMessage) {
279 if self.config.deduplicate {
281 let key = message.key();
282 if self.seen_messages.contains(&key) {
283 return;
284 }
285 self.seen_messages.insert(key);
286 }
287
288 for handler in &self.handlers {
290 if handler.should_handle(message) {
291 handler.handle(message);
292 }
293 }
294 }
295
296 fn is_terminal(status: &Value) -> bool {
297 status
298 .get("status")
299 .and_then(|v| v.as_str())
300 .map(|s| TERMINAL_STATUSES.contains(&s))
301 .unwrap_or(false)
302 }
303
304 pub fn job_id(&self) -> &str {
306 &self.job_id
307 }
308
309 pub fn last_event_seq(&self) -> Option<i64> {
311 self.last_event_seq
312 }
313
314 pub fn clear_seen(&mut self) {
316 self.seen_messages.clear();
317 self.last_event_seq = None;
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_terminal_detection() {
        // Every listed terminal status, including both spellings of "cancelled".
        for status in TERMINAL_STATUSES {
            assert!(
                JobStreamer::is_terminal(&serde_json::json!({ "status": status })),
                "expected {:?} to be terminal",
                status
            );
        }
        assert!(!JobStreamer::is_terminal(&serde_json::json!({"status": "running"})));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({"status": "pending"})));
    }

    #[test]
    fn test_terminal_detection_requires_string_status() {
        assert!(!JobStreamer::is_terminal(&serde_json::json!({})));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({"status": null})));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({"status": 1})));
        assert!(!JobStreamer::is_terminal(&serde_json::json!({"state": "succeeded"})));
    }

    #[test]
    fn test_streamer_creation() {
        let streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123")
            .with_config(StreamConfig::minimal())
            .with_endpoints(StreamEndpoints::prompt_learning("job-123"));

        assert_eq!(streamer.job_id(), "job-123");
        assert!(streamer.last_event_seq().is_none());
    }

    #[test]
    fn test_base_url_trailing_slashes_trimmed() {
        let streamer = JobStreamer::new("https://api.example.com///", "sk-test", "job-123");

        assert_eq!(streamer.base_url, "https://api.example.com");
        assert!(streamer.handlers.is_empty());
        assert!(streamer.seen_messages.is_empty());
    }

    #[test]
    fn test_clear_seen() {
        let mut streamer = JobStreamer::new("https://api.example.com", "sk-test", "job-123");

        streamer.seen_messages.insert("test".to_string());
        streamer.last_event_seq = Some(42);

        streamer.clear_seen();

        assert!(streamer.seen_messages.is_empty());
        assert!(streamer.last_event_seq.is_none());
    }
}