Skip to main content

vtcode_core/llm/providers/
streaming_progress.rs

//! Provider-agnostic streaming timeout progress tracking.
//!
//! This module provides a unified interface for tracking how much of a streaming
//! request's timeout budget has elapsed, across all LLM providers (OpenAI, Anthropic,
//! Gemini, Ollama, etc.).
5
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, Ordering};
8use std::time::{Duration, Instant};
9use tracing::warn;
10
/// Boxed, thread-safe observer for streaming timeout progress updates.
///
/// Its single `f32` argument is the fraction of the timeout budget consumed so far
/// (`elapsed / total_timeout`), expressed in the range `0.0..=1.0`.
pub type StreamingProgressCallback = Box<dyn Fn(f32) + Sync + Send>;
14
15/// Unified streaming progress tracker for all LLM providers
16#[derive(Clone)]
17pub struct StreamingProgressTracker {
18    callback: Option<Arc<StreamingProgressCallback>>,
19    warning_threshold: f32,
20    total_timeout: Duration,
21    start_time: Arc<Instant>,
22    last_reported_progress: Arc<AtomicU8>,
23}
24
25impl StreamingProgressTracker {
26    /// Create a new streaming progress tracker
27    pub fn new(total_timeout: Duration) -> Self {
28        Self {
29            callback: None,
30            warning_threshold: 0.8,
31            total_timeout,
32            start_time: Arc::new(Instant::now()),
33            last_reported_progress: Arc::new(AtomicU8::new(0)),
34        }
35    }
36
37    /// Set a progress callback
38    pub fn with_callback(mut self, callback: StreamingProgressCallback) -> Self {
39        self.callback = Some(Arc::new(callback));
40        self
41    }
42
43    /// Set the warning threshold (0.0-1.0)
44    pub fn with_warning_threshold(mut self, threshold: f32) -> Self {
45        self.warning_threshold = threshold.clamp(0.0, 1.0);
46        self
47    }
48
49    /// Report that the first chunk has been received
50    pub fn report_first_chunk(&self) {
51        self.report_progress(0.1);
52    }
53
54    /// Report progress with elapsed time
55    pub fn report_chunk_received(&self) {
56        let elapsed = self.start_time.elapsed();
57        self.report_progress_with_elapsed(elapsed);
58    }
59
60    /// Report progress at a specific elapsed duration
61    pub fn report_progress_with_elapsed(&self, elapsed: Duration) {
62        if self.total_timeout.as_secs() == 0 {
63            return;
64        }
65
66        let progress = elapsed.as_secs_f32() / self.total_timeout.as_secs_f32();
67        self.report_progress(progress.min(0.99)); // Cap at 99%
68    }
69
70    /// Report error or timeout (100% progress)
71    pub fn report_error(&self) {
72        self.report_progress(1.0);
73    }
74
75    /// Get current progress as percentage (0-100)
76    pub fn progress_percent(&self) -> u8 {
77        self.last_reported_progress.load(Ordering::Relaxed)
78    }
79
80    /// Get elapsed time since start
81    pub fn elapsed(&self) -> Duration {
82        self.start_time.elapsed()
83    }
84
85    /// Check if warning threshold has been exceeded
86    pub fn is_approaching_timeout(&self) -> bool {
87        let elapsed = self.start_time.elapsed();
88        if self.total_timeout.as_secs() == 0 {
89            return false;
90        }
91
92        let elapsed_progress = elapsed.as_secs_f32() / self.total_timeout.as_secs_f32();
93        let reported_progress =
94            f32::from(self.last_reported_progress.load(Ordering::Relaxed)) / 100.0;
95        elapsed_progress.max(reported_progress) >= self.warning_threshold
96    }
97
98    // Private: Report progress with clamping and threshold checking
99    fn report_progress(&self, progress: f32) {
100        let progress_clamped = progress.clamp(0.0, 1.0);
101        let percent = (progress_clamped * 100.0) as u8;
102
103        // Only update if progress changed by at least 1%
104        let last_percent = self.last_reported_progress.load(Ordering::Relaxed);
105        if percent <= last_percent {
106            return;
107        }
108
109        self.last_reported_progress
110            .store(percent, Ordering::Relaxed);
111
112        // Call the callback if set
113        if let Some(ref callback) = self.callback {
114            callback(progress_clamped);
115        }
116
117        // Warn if approaching threshold
118        if progress_clamped >= self.warning_threshold && progress_clamped < 1.0 {
119            warn!(
120                "Streaming operation at {:.0}% of timeout limit ({:?}/{:?} elapsed). Approaching timeout.",
121                progress_clamped * 100.0,
122                self.elapsed(),
123                self.total_timeout
124            );
125        }
126    }
127}
128
129/// Builder for creating streaming progress trackers with fluent API
130pub struct StreamingProgressBuilder {
131    total_timeout: Duration,
132    callback: Option<StreamingProgressCallback>,
133    warning_threshold: f32,
134}
135
136impl StreamingProgressBuilder {
137    /// Create a new builder with total timeout in seconds
138    pub fn new(timeout_secs: u64) -> Self {
139        Self {
140            total_timeout: Duration::from_secs(timeout_secs),
141            callback: None,
142            warning_threshold: 0.8,
143        }
144    }
145
146    /// Create a new builder with a specific duration
147    pub fn with_duration(duration: Duration) -> Self {
148        Self {
149            total_timeout: duration,
150            callback: None,
151            warning_threshold: 0.8,
152        }
153    }
154
155    /// Set the progress callback
156    pub fn callback(mut self, callback: StreamingProgressCallback) -> Self {
157        self.callback = Some(callback);
158        self
159    }
160
161    /// Set the warning threshold (0.0-1.0)
162    pub fn warning_threshold(mut self, threshold: f32) -> Self {
163        self.warning_threshold = threshold.clamp(0.0, 1.0);
164        self
165    }
166
167    /// Build the tracker
168    pub fn build(self) -> StreamingProgressTracker {
169        let mut tracker = StreamingProgressTracker::new(self.total_timeout);
170        if let Some(callback) = self.callback {
171            tracker.callback = Some(Arc::new(callback));
172        }
173        tracker.warning_threshold = self.warning_threshold;
174        tracker
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Builds a callback that appends every reported fraction to a shared log.
    fn recording_callback() -> (Arc<Mutex<Vec<f32>>>, StreamingProgressCallback) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let callback: StreamingProgressCallback = Box::new(move |progress: f32| {
            sink.lock().unwrap().push(progress);
        });
        (log, callback)
    }

    #[test]
    fn test_progress_tracker_creation() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(600));
        assert_eq!(tracker.progress_percent(), 0);
        assert!(!tracker.is_approaching_timeout());
    }

    #[test]
    fn test_progress_reporting() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));

        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        assert_eq!(tracker.progress_percent(), 30);

        tracker.report_progress_with_elapsed(Duration::from_secs(80));
        assert_eq!(tracker.progress_percent(), 80);
    }

    #[test]
    fn test_first_chunk_reports_ten_percent() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        tracker.report_first_chunk();
        assert_eq!(tracker.progress_percent(), 10);
    }

    #[test]
    fn test_progress_is_monotonic() {
        let (log, callback) = recording_callback();
        let tracker =
            StreamingProgressTracker::new(Duration::from_secs(100)).with_callback(callback);

        tracker.report_progress_with_elapsed(Duration::from_secs(50));
        tracker.report_progress_with_elapsed(Duration::from_secs(20)); // lower: ignored
        tracker.report_progress_with_elapsed(Duration::from_secs(50)); // equal: ignored

        assert_eq!(tracker.progress_percent(), 50);
        assert_eq!(*log.lock().unwrap(), vec![0.5]);
    }

    #[test]
    fn test_clones_share_progress() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        let clone = tracker.clone();
        clone.report_progress_with_elapsed(Duration::from_secs(40));
        assert_eq!(tracker.progress_percent(), 40);
    }

    #[test]
    fn test_warning_threshold() {
        let tracker =
            StreamingProgressTracker::new(Duration::from_secs(100)).with_warning_threshold(0.8);

        tracker.report_progress_with_elapsed(Duration::from_secs(50));
        assert!(!tracker.is_approaching_timeout());

        tracker.report_progress_with_elapsed(Duration::from_secs(85));
        assert!(tracker.is_approaching_timeout());
    }

    #[test]
    fn test_callback_execution() {
        let (log, callback) = recording_callback();
        let tracker =
            StreamingProgressTracker::new(Duration::from_secs(100)).with_callback(callback);

        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        tracker.report_progress_with_elapsed(Duration::from_secs(60));
        tracker.report_progress_with_elapsed(Duration::from_secs(90));

        // Each report raises the percentage, so the callback fires once per call with
        // exactly elapsed / total.
        assert_eq!(*log.lock().unwrap(), vec![0.3, 0.6, 0.9]);
    }

    #[test]
    fn test_builder_pattern() {
        let tracker = StreamingProgressBuilder::new(300)
            .warning_threshold(0.75)
            .build();

        assert_eq!(tracker.total_timeout.as_secs(), 300);
        assert_eq!(tracker.warning_threshold, 0.75);

        // Default threshold and clamping of out-of-range thresholds.
        assert_eq!(StreamingProgressBuilder::new(1).build().warning_threshold, 0.8);
        let clamped = StreamingProgressBuilder::new(1).warning_threshold(1.5).build();
        assert_eq!(clamped.warning_threshold, 1.0);
    }

    #[test]
    fn test_builder_installs_callback() {
        let (log, callback) = recording_callback();
        let tracker = StreamingProgressBuilder::with_duration(Duration::from_secs(10))
            .callback(callback)
            .build();

        tracker.report_error();
        assert_eq!(tracker.progress_percent(), 100);
        assert_eq!(*log.lock().unwrap(), vec![1.0]);
    }

    #[test]
    fn test_zero_timeout_safety() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(0));
        tracker.report_chunk_received(); // Should not panic
        assert_eq!(tracker.progress_percent(), 0);
        assert!(!tracker.is_approaching_timeout());
    }

    #[test]
    fn test_progress_clamping() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));

        tracker.report_progress_with_elapsed(Duration::from_secs(150)); // Beyond timeout
        assert_eq!(tracker.progress_percent(), 99); // Clamped at 99%

        tracker.report_error();
        assert_eq!(tracker.progress_percent(), 100);
    }
}