vtcode_core/llm/providers/
streaming_progress.rs1use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, Ordering};
8use std::time::{Duration, Instant};
9use tracing::warn;
10
/// Observer invoked with the progress of a streaming operation, expressed as a
/// fraction in `0.0..=1.0`.
///
/// The `Send + Sync` bounds allow the boxed closure to be shared across threads.
pub type StreamingProgressCallback = Box<dyn Fn(f32) + Sync + Send>;

15#[derive(Clone)]
17pub struct StreamingProgressTracker {
18 callback: Option<Arc<StreamingProgressCallback>>,
19 warning_threshold: f32,
20 total_timeout: Duration,
21 start_time: Arc<Instant>,
22 last_reported_progress: Arc<AtomicU8>,
23}
24
25impl StreamingProgressTracker {
26 pub fn new(total_timeout: Duration) -> Self {
28 Self {
29 callback: None,
30 warning_threshold: 0.8,
31 total_timeout,
32 start_time: Arc::new(Instant::now()),
33 last_reported_progress: Arc::new(AtomicU8::new(0)),
34 }
35 }
36
37 pub fn with_callback(mut self, callback: StreamingProgressCallback) -> Self {
39 self.callback = Some(Arc::new(callback));
40 self
41 }
42
43 pub fn with_warning_threshold(mut self, threshold: f32) -> Self {
45 self.warning_threshold = threshold.clamp(0.0, 1.0);
46 self
47 }
48
49 pub fn report_first_chunk(&self) {
51 self.report_progress(0.1);
52 }
53
54 pub fn report_chunk_received(&self) {
56 let elapsed = self.start_time.elapsed();
57 self.report_progress_with_elapsed(elapsed);
58 }
59
60 pub fn report_progress_with_elapsed(&self, elapsed: Duration) {
62 if self.total_timeout.as_secs() == 0 {
63 return;
64 }
65
66 let progress = elapsed.as_secs_f32() / self.total_timeout.as_secs_f32();
67 self.report_progress(progress.min(0.99)); }
69
70 pub fn report_error(&self) {
72 self.report_progress(1.0);
73 }
74
75 pub fn progress_percent(&self) -> u8 {
77 self.last_reported_progress.load(Ordering::Relaxed)
78 }
79
80 pub fn elapsed(&self) -> Duration {
82 self.start_time.elapsed()
83 }
84
85 pub fn is_approaching_timeout(&self) -> bool {
87 let elapsed = self.start_time.elapsed();
88 if self.total_timeout.as_secs() == 0 {
89 return false;
90 }
91
92 let elapsed_progress = elapsed.as_secs_f32() / self.total_timeout.as_secs_f32();
93 let reported_progress =
94 f32::from(self.last_reported_progress.load(Ordering::Relaxed)) / 100.0;
95 elapsed_progress.max(reported_progress) >= self.warning_threshold
96 }
97
98 fn report_progress(&self, progress: f32) {
100 let progress_clamped = progress.clamp(0.0, 1.0);
101 let percent = (progress_clamped * 100.0) as u8;
102
103 let last_percent = self.last_reported_progress.load(Ordering::Relaxed);
105 if percent <= last_percent {
106 return;
107 }
108
109 self.last_reported_progress
110 .store(percent, Ordering::Relaxed);
111
112 if let Some(ref callback) = self.callback {
114 callback(progress_clamped);
115 }
116
117 if progress_clamped >= self.warning_threshold && progress_clamped < 1.0 {
119 warn!(
120 "Streaming operation at {:.0}% of timeout limit ({:?}/{:?} elapsed). Approaching timeout.",
121 progress_clamped * 100.0,
122 self.elapsed(),
123 self.total_timeout
124 );
125 }
126 }
127}
128
129pub struct StreamingProgressBuilder {
131 total_timeout: Duration,
132 callback: Option<StreamingProgressCallback>,
133 warning_threshold: f32,
134}
135
136impl StreamingProgressBuilder {
137 pub fn new(timeout_secs: u64) -> Self {
139 Self {
140 total_timeout: Duration::from_secs(timeout_secs),
141 callback: None,
142 warning_threshold: 0.8,
143 }
144 }
145
146 pub fn with_duration(duration: Duration) -> Self {
148 Self {
149 total_timeout: duration,
150 callback: None,
151 warning_threshold: 0.8,
152 }
153 }
154
155 pub fn callback(mut self, callback: StreamingProgressCallback) -> Self {
157 self.callback = Some(callback);
158 self
159 }
160
161 pub fn warning_threshold(mut self, threshold: f32) -> Self {
163 self.warning_threshold = threshold.clamp(0.0, 1.0);
164 self
165 }
166
167 pub fn build(self) -> StreamingProgressTracker {
169 let mut tracker = StreamingProgressTracker::new(self.total_timeout);
170 if let Some(callback) = self.callback {
171 tracker.callback = Some(Arc::new(callback));
172 }
173 tracker.warning_threshold = self.warning_threshold;
174 tracker
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Shorthand for a tracker whose timeout is `secs` whole seconds.
    fn tracker_for(secs: u64) -> StreamingProgressTracker {
        StreamingProgressTracker::new(Duration::from_secs(secs))
    }

    #[test]
    fn test_progress_tracker_creation() {
        let fresh = tracker_for(600);
        assert_eq!(fresh.progress_percent(), 0);
        assert!(!fresh.is_approaching_timeout());
    }

    #[test]
    fn test_progress_reporting() {
        let tracker = tracker_for(100);
        for (elapsed_secs, expected_percent) in [(30, 30), (80, 80)] {
            tracker.report_progress_with_elapsed(Duration::from_secs(elapsed_secs));
            assert_eq!(tracker.progress_percent(), expected_percent);
        }
    }

    #[test]
    fn test_warning_threshold() {
        let tracker = tracker_for(100).with_warning_threshold(0.8);

        // Half way: still below the 80% threshold.
        tracker.report_progress_with_elapsed(Duration::from_secs(50));
        assert!(!tracker.is_approaching_timeout());

        // 85%: past the threshold.
        tracker.report_progress_with_elapsed(Duration::from_secs(85));
        assert!(tracker.is_approaching_timeout());
    }

    #[test]
    fn test_callback_execution() {
        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let tracker = tracker_for(100).with_callback(Box::new(move |progress: f32| {
            sink.lock().unwrap().push(progress);
        }));

        for secs in [30, 60, 90] {
            tracker.report_progress_with_elapsed(Duration::from_secs(secs));
        }

        let entries = recorded.lock().unwrap();
        assert!(!entries.is_empty());
        assert!(entries.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }

    #[test]
    fn test_builder_pattern() {
        let built = StreamingProgressBuilder::new(300)
            .warning_threshold(0.75)
            .build();

        assert_eq!(built.total_timeout.as_secs(), 300);
        assert_eq!(built.warning_threshold, 0.75);
    }

    #[test]
    fn test_zero_timeout_safety() {
        let tracker = tracker_for(0);
        tracker.report_chunk_received();
        assert!(!tracker.is_approaching_timeout());
    }

    #[test]
    fn test_progress_clamping() {
        let tracker = tracker_for(100);

        // Elapsed time beyond the timeout is capped below 100%...
        tracker.report_progress_with_elapsed(Duration::from_secs(150));
        assert_eq!(tracker.progress_percent(), 99);

        // ...while an error report pushes it all the way to 100%.
        tracker.report_error();
        assert_eq!(tracker.progress_percent(), 100);
    }
}