// synth_ai_core/orchestration/streaming.rs

1//! Event streaming for optimization jobs.
2//!
3//! This module provides SSE-style event streaming with deduplication
4//! for optimization jobs.
5
6use std::collections::HashSet;
7use std::time::{Duration, Instant};
8
9use serde_json::Value;
10
11use crate::errors::CoreError;
12use crate::http::HttpClient;
13
14use super::events::{EventCategory, EventParser, ParsedEvent};
15
/// Format a duration for human-readable logging.
///
/// Renders `"Ns"` under a minute, `"Mm{S}s"` under an hour,
/// and `"Hh{M}m"` otherwise (seconds are dropped at hour scale).
fn format_duration(d: Duration) -> String {
    let total = d.as_secs();
    match total {
        0..=59 => format!("{}s", total),
        60..=3599 => format!("{}m{}s", total / 60, total % 60),
        _ => format!("{}h{}m", total / 3600, (total % 3600) / 60),
    }
}
27
28/// Log an event summary based on its category.
29pub fn log_event_summary(event: &ParsedEvent) {
30    let path = EventParser::parse_path(&event.event_type);
31    if let (Some(entity), Some(action)) = (path.entity.as_deref(), path.action.as_deref()) {
32        eprintln!(
33            "[STREAM] Event path: {}.{} (alg={:?} detail={:?})",
34            entity, action, path.algorithm, path.detail
35        );
36    }
37    match event.category {
38        EventCategory::Baseline => {
39            let baseline = EventParser::parse_baseline(event);
40            eprintln!("[STREAM] Baseline: reward={:.3?}", baseline.reward);
41        }
42        EventCategory::Candidate => {
43            let candidate = EventParser::parse_candidate(event);
44            eprintln!(
45                "[STREAM] Candidate {}: reward={:.3?} accepted={} gen={:?}",
46                candidate.candidate_id, candidate.reward, candidate.accepted, candidate.generation
47            );
48        }
49        EventCategory::Frontier => {
50            let frontier = EventParser::parse_frontier(event);
51            eprintln!(
52                "[STREAM] Frontier updated: size={} best={:.3?}",
53                frontier.frontier_size, frontier.best_reward
54            );
55        }
56        EventCategory::Progress => {
57            let progress = EventParser::parse_progress(event);
58            eprintln!(
59                "[STREAM] Progress: rollouts={}/{:?} best={:.3?}",
60                progress.rollouts_completed, progress.rollouts_total, progress.best_reward
61            );
62        }
63        EventCategory::Generation => {
64            let gen = EventParser::parse_generation(event);
65            eprintln!(
66                "[STREAM] Generation {}: best_acc={:.3} proposed={} accepted={}",
67                gen.generation, gen.best_reward, gen.candidates_proposed, gen.candidates_accepted
68            );
69        }
70        EventCategory::Validation => {
71            eprintln!("[STREAM] Validation event: {:?}", event.event_type);
72        }
73        EventCategory::Complete => {
74            let complete = EventParser::parse_complete(event);
75            eprintln!(
76                "[STREAM] COMPLETE: best={:.3?} baseline={:.3?} reason={:?}",
77                complete.best_reward, complete.baseline_reward, complete.finish_reason
78            );
79        }
80        EventCategory::Termination => {
81            let term = EventParser::parse_termination(event);
82            eprintln!("[STREAM] TERMINATION: reason={}", term.reason);
83        }
84        EventCategory::Usage => {
85            let usage = EventParser::parse_usage(event);
86            eprintln!(
87                "[STREAM] Usage: total=${:.4} tokens=${:.4} sandbox=${:.4}",
88                usage.total_usd, usage.tokens_usd, usage.sandbox_usd
89            );
90        }
91        EventCategory::Throughput => {
92            eprintln!("[STREAM] Throughput event");
93        }
94        EventCategory::Unknown => {
95            eprintln!("[STREAM] Unknown event: {}", event.event_type);
96        }
97    }
98}
99
/// Event stream for polling job events.
///
/// Maintains a `since_seq` cursor (`last_seq`) that advances as events
/// arrive, and optionally drops already-seen sequence numbers so callers
/// never observe the same event twice.
pub struct EventStream {
    /// HTTP client used for the polling requests.
    client: HttpClient,
    /// Job ID to stream events for.
    job_id: String,
    /// Base URL for API (stored without a trailing slash; see `new`).
    base_url: String,
    /// Highest sequence number seen so far; sent as `since_seq` on each poll.
    last_seq: i64,
    /// Whether to deduplicate events by sequence number.
    deduplicate: bool,
    /// Set of seen sequence numbers (pruned once it grows past 10k entries).
    seen_seqs: HashSet<i64>,
    /// Maximum events requested per poll (the `limit` query parameter).
    max_events_per_poll: i32,
}
117
118impl EventStream {
119    /// Create a new event stream for a job.
120    pub fn new(client: HttpClient, base_url: &str, job_id: &str) -> Self {
121        Self {
122            client,
123            job_id: job_id.to_string(),
124            base_url: base_url.trim_end_matches('/').to_string(),
125            last_seq: 0,
126            deduplicate: true,
127            seen_seqs: HashSet::new(),
128            max_events_per_poll: 500,
129        }
130    }
131
132    /// Set the starting sequence number.
133    pub fn with_start_seq(mut self, seq: i64) -> Self {
134        self.last_seq = seq;
135        self
136    }
137
138    /// Enable or disable deduplication.
139    pub fn with_deduplicate(mut self, dedupe: bool) -> Self {
140        self.deduplicate = dedupe;
141        self
142    }
143
144    /// Set max events per poll.
145    pub fn with_max_events(mut self, max: i32) -> Self {
146        self.max_events_per_poll = max;
147        self
148    }
149
150    /// Get the last seen sequence number.
151    pub fn last_seq(&self) -> i64 {
152        self.last_seq
153    }
154
155    /// Poll for new events.
156    ///
157    /// Returns events since the last sequence number.
158    pub async fn poll_events(&mut self) -> Result<Vec<ParsedEvent>, CoreError> {
159        let url = format!(
160            "{}/api/prompt-learning/online/jobs/{}/events",
161            self.base_url, self.job_id
162        );
163
164        let params = [
165            ("since_seq", self.last_seq.to_string()),
166            ("limit", self.max_events_per_poll.to_string()),
167        ];
168
169        let params_slice: &[(&str, &str)] = &[("since_seq", &params[0].1), ("limit", &params[1].1)];
170
171        eprintln!(
172            "[STREAM] poll_events: job={} since_seq={} limit={}",
173            self.job_id, self.last_seq, self.max_events_per_poll
174        );
175
176        let response: Value = self
177            .client
178            .get(&url, Some(params_slice))
179            .await
180            .map_err(|e| {
181                eprintln!("[STREAM] ERROR: poll_events failed: {}", e);
182                CoreError::Internal(format!("failed to fetch events: {}", e))
183            })?;
184
185        // Parse events array
186        let events_array = response
187            .get("events")
188            .and_then(|v| v.as_array())
189            .cloned()
190            .unwrap_or_default();
191
192        eprintln!(
193            "[STREAM] poll_events: received {} raw events",
194            events_array.len()
195        );
196
197        let mut parsed_events = Vec::new();
198
199        for event_value in events_array {
200            let parsed = EventParser::parse(&event_value);
201
202            // Update last_seq
203            if let Some(seq) = parsed.seq {
204                if seq > self.last_seq {
205                    self.last_seq = seq;
206                }
207
208                // Deduplication
209                if self.deduplicate {
210                    if self.seen_seqs.contains(&seq) {
211                        continue;
212                    }
213                    self.seen_seqs.insert(seq);
214
215                    // Limit seen_seqs size to prevent memory growth
216                    if self.seen_seqs.len() > 10000 {
217                        // Keep only recent sequences
218                        let threshold = self.last_seq - 5000;
219                        self.seen_seqs.retain(|&s| s > threshold);
220                    }
221                }
222            }
223
224            parsed_events.push(parsed);
225        }
226
227        if !parsed_events.is_empty() {
228            eprintln!(
229                "[STREAM] poll_events: returning {} new events (last_seq={})",
230                parsed_events.len(),
231                self.last_seq
232            );
233        }
234
235        Ok(parsed_events)
236    }
237
238    /// Stream events until a terminal condition with callback.
239    ///
240    /// # Arguments
241    ///
242    /// * `on_event` - Callback for each event
243    /// * `timeout` - Maximum time to stream
244    /// * `poll_interval` - Time between polls
245    /// * `is_terminal` - Optional check for terminal status
246    pub async fn stream_until<F, T>(
247        &mut self,
248        mut on_event: F,
249        timeout: Duration,
250        poll_interval: Duration,
251        mut is_terminal: T,
252    ) -> Result<(), CoreError>
253    where
254        F: FnMut(&ParsedEvent),
255        T: FnMut() -> bool,
256    {
257        let start = Instant::now();
258        let mut last_event_time = Instant::now();
259        let mut poll_count = 0u64;
260        let mut total_events = 0u64;
261
262        eprintln!(
263            "[STREAM] stream_until: starting job={} timeout={} poll_interval={}",
264            self.job_id,
265            format_duration(timeout),
266            format_duration(poll_interval)
267        );
268
269        loop {
270            let elapsed = start.elapsed();
271
272            // Check timeout
273            if elapsed > timeout {
274                eprintln!(
275                    "[STREAM] TIMEOUT: elapsed={} total_events={}",
276                    format_duration(elapsed),
277                    total_events
278                );
279                return Err(CoreError::Timeout(format!(
280                    "event stream timed out after {:.0} seconds",
281                    timeout.as_secs_f64()
282                )));
283            }
284
285            // Check terminal condition
286            if is_terminal() {
287                eprintln!(
288                    "[STREAM] Terminal condition reached: elapsed={} total_events={}",
289                    format_duration(elapsed),
290                    total_events
291                );
292                return Ok(());
293            }
294
295            poll_count += 1;
296
297            // Log every 10 polls or when significant time has passed
298            if poll_count % 10 == 0 {
299                eprintln!(
300                    "[STREAM] Streaming: elapsed={} polls={} events={}",
301                    format_duration(elapsed),
302                    poll_count,
303                    total_events
304                );
305            }
306
307            // Poll events
308            match self.poll_events().await {
309                Ok(events) => {
310                    if !events.is_empty() {
311                        last_event_time = Instant::now();
312                        total_events += events.len() as u64;
313                        eprintln!(
314                            "[STREAM] Received {} events (total={})",
315                            events.len(),
316                            total_events
317                        );
318                    }
319
320                    for event in &events {
321                        // Log each event summary
322                        log_event_summary(event);
323
324                        on_event(event);
325
326                        // Check for terminal events
327                        if event.category.is_terminal() {
328                            eprintln!(
329                                "[STREAM] Terminal event received: {} (elapsed={})",
330                                event.event_type,
331                                format_duration(elapsed)
332                            );
333                            return Ok(());
334                        }
335                    }
336                }
337                Err(e) => {
338                    let since_last = last_event_time.elapsed();
339                    eprintln!(
340                        "[STREAM] Poll error ({}s since last event): {}",
341                        since_last.as_secs(),
342                        e
343                    );
344                    // Allow some grace period for transient errors
345                    if since_last > Duration::from_secs(120) {
346                        eprintln!("[STREAM] ERROR: Too long since last event, giving up");
347                        return Err(e);
348                    }
349                }
350            }
351
352            // Wait before next poll
353            tokio::time::sleep(poll_interval).await;
354        }
355    }
356}
357
/// Stream configuration.
///
/// Bundles the tunables used when streaming job events.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Seconds to wait between polls.
    pub poll_interval_secs: f64,
    /// Maximum number of events fetched per poll.
    pub max_events_per_poll: i32,
    /// Whether duplicate events (by sequence number) are dropped.
    pub deduplicate: bool,
    /// Overall streaming timeout in seconds.
    pub timeout_secs: f64,
}

impl Default for StreamConfig {
    /// Defaults: poll every 5s, up to 500 events per poll,
    /// deduplication enabled, one-hour timeout.
    fn default() -> Self {
        StreamConfig {
            poll_interval_secs: 5.0,
            deduplicate: true,
            max_events_per_poll: 500,
            timeout_secs: 3600.0,
        }
    }
}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should match the documented tunables.
    #[test]
    fn test_stream_config_default() {
        let StreamConfig {
            poll_interval_secs,
            max_events_per_poll,
            deduplicate,
            timeout_secs: _,
        } = StreamConfig::default();
        assert_eq!(poll_interval_secs, 5.0);
        assert_eq!(max_events_per_poll, 500);
        assert!(deduplicate);
    }
}