unistore_progress/
tracker.rs1use crate::config::ProgressConfig;
9use crate::deps::*;
10#[allow(unused_imports)]
11use crate::error::{ProgressError, ProgressResult};
12use crate::event::ProgressEvent;
13
/// Mutable tracker state kept behind the tracker's `RwLock` (everything that
/// does not fit in an atomic).
struct TrackerState {
    /// Rolling window of rate samples, in units per second since tracking
    /// started, used to smooth the ETA.
    rate_samples: Vec<f64>,
    /// When the most recent event was broadcast; `None` until the first one.
    last_notify: Option<Instant>,
    /// Free-form status text copied into every snapshot/event.
    message: String,
}
23
24pub struct ProgressTracker {
39 total: u64,
41 current: AtomicU64,
43 finished: AtomicBool,
45 started_at: Instant,
47 config: ProgressConfig,
49 state: RwLock<TrackerState>,
51 sender: broadcast::Sender<ProgressEvent>,
53}
54
55impl ProgressTracker {
56 pub fn new(total: u64) -> Self {
64 Self::with_config(total, ProgressConfig::default())
65 }
66
67 pub fn with_config(total: u64, config: ProgressConfig) -> Self {
69 assert!(total > 0, "total must be greater than 0");
70
71 let (sender, _) = broadcast::channel(config.channel_capacity);
72
73 Self {
74 total,
75 current: AtomicU64::new(0),
76 finished: AtomicBool::new(false),
77 started_at: Instant::now(),
78 config,
79 state: RwLock::new(TrackerState {
80 message: String::new(),
81 last_notify: None,
82 rate_samples: Vec::with_capacity(10),
83 }),
84 sender,
85 }
86 }
87
88 pub fn total(&self) -> u64 {
90 self.total
91 }
92
93 pub fn current(&self) -> u64 {
95 self.current.load(Ordering::Relaxed)
96 }
97
98 pub fn percentage(&self) -> f64 {
100 (self.current() as f64 / self.total as f64) * 100.0
101 }
102
103 pub fn is_finished(&self) -> bool {
105 self.finished.load(Ordering::Relaxed)
106 }
107
108 pub fn elapsed(&self) -> Duration {
110 self.started_at.elapsed()
111 }
112
113 pub fn advance(&self, delta: u64) {
118 if self.is_finished() {
119 return;
120 }
121
122 let new_value = self.current.fetch_add(delta, Ordering::Relaxed) + delta;
123
124 let ratio = new_value as f64 / self.total as f64;
126 if ratio >= self.config.auto_finish_threshold {
127 self.finish();
128 } else {
129 self.maybe_notify();
130 }
131 }
132
133 pub fn set(&self, value: u64) {
135 if self.is_finished() {
136 return;
137 }
138
139 let value = value.min(self.total);
140 self.current.store(value, Ordering::Relaxed);
141
142 let ratio = value as f64 / self.total as f64;
143 if ratio >= self.config.auto_finish_threshold {
144 self.finish();
145 } else {
146 self.maybe_notify();
147 }
148 }
149
150 pub fn set_message(&self, message: impl Into<String>) {
152 let mut state = self.state.write();
153 state.message = message.into();
154 drop(state);
155 self.maybe_notify();
156 }
157
158 pub fn finish(&self) {
160 if self.finished.swap(true, Ordering::Relaxed) {
161 return; }
163 self.current.store(self.total, Ordering::Relaxed);
164 self.notify_now();
165 }
166
167 pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
169 self.sender.subscribe()
170 }
171
172 pub fn snapshot(&self) -> ProgressEvent {
174 let state = self.state.read();
175 let current = self.current();
176 let elapsed = self.elapsed();
177
178 ProgressEvent {
179 current,
180 total: self.total,
181 message: state.message.clone(),
182 elapsed,
183 eta: self.calculate_eta(current, elapsed),
184 finished: self.is_finished(),
185 }
186 }
187
188 fn maybe_notify(&self) {
190 let now = Instant::now();
191 let should_notify = {
192 let state = self.state.read();
193 match state.last_notify {
194 Some(last) => now.duration_since(last) >= self.config.debounce_interval,
195 None => true,
196 }
197 };
198
199 if should_notify {
200 self.notify_now();
201 }
202 }
203
204 fn notify_now(&self) {
206 let event = self.snapshot();
207
208 {
210 let mut state = self.state.write();
211 state.last_notify = Some(Instant::now());
212
213 let elapsed_secs = event.elapsed.as_secs_f64();
215 if elapsed_secs > 0.0 {
216 let rate = event.current as f64 / elapsed_secs;
217 state.rate_samples.push(rate);
218 if state.rate_samples.len() > 10 {
219 state.rate_samples.remove(0);
220 }
221 }
222 }
223
224 let _ = self.sender.send(event);
226 }
227
228 fn calculate_eta(&self, current: u64, elapsed: Duration) -> Option<Duration> {
230 if current == 0 {
231 return None;
232 }
233
234 let state = self.state.read();
235 if state.rate_samples.len() < self.config.eta_min_samples {
236 let rate = current as f64 / elapsed.as_secs_f64();
238 if rate > 0.0 {
239 let remaining = (self.total - current) as f64;
240 return Some(Duration::from_secs_f64(remaining / rate));
241 }
242 return None;
243 }
244
245 let smoothed_rate = state.rate_samples.iter().rev().fold(0.0, |acc, &rate| {
247 acc * (1.0 - self.config.eta_smoothing_factor) + rate * self.config.eta_smoothing_factor
248 });
249
250 if smoothed_rate > 0.0 {
251 let remaining = (self.total - current) as f64;
252 Some(Duration::from_secs_f64(remaining / smoothed_rate))
253 } else {
254 None
255 }
256 }
257}
258
259impl std::fmt::Debug for ProgressTracker {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("ProgressTracker")
262 .field("total", &self.total)
263 .field("current", &self.current())
264 .field("finished", &self.is_finished())
265 .field("elapsed", &self.elapsed())
266 .finish()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let tracker = ProgressTracker::new(100);
        assert_eq!(tracker.total(), 100);
        assert_eq!(tracker.current(), 0);
        assert!(!tracker.is_finished());
    }

    #[test]
    fn test_advance() {
        let tracker = ProgressTracker::new(100);
        tracker.advance(10);
        assert_eq!(tracker.current(), 10);
        tracker.advance(20);
        assert_eq!(tracker.current(), 30);
    }

    /// Overshooting `total` must leave the counter at `total`, never above it.
    #[test]
    fn test_advance_past_total_saturates() {
        let tracker = ProgressTracker::new(100);
        tracker.advance(150);
        assert_eq!(tracker.current(), 100);
        assert!(tracker.is_finished());
        assert!((tracker.percentage() - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_set() {
        let tracker = ProgressTracker::new(100);
        tracker.set(50);
        assert_eq!(tracker.current(), 50);
    }

    #[test]
    fn test_set_clamps_to_total() {
        let tracker = ProgressTracker::new(100);
        tracker.set(500);
        assert_eq!(tracker.current(), 100);
        assert!(tracker.is_finished());
    }

    #[test]
    fn test_finish() {
        let tracker = ProgressTracker::new(100);
        tracker.advance(50);
        tracker.finish();
        assert!(tracker.is_finished());
        assert_eq!(tracker.current(), 100);
    }

    /// Once finished, further updates (and repeated `finish` calls) change nothing.
    #[test]
    fn test_updates_after_finish_are_ignored() {
        let tracker = ProgressTracker::new(100);
        tracker.finish();
        tracker.advance(5);
        tracker.set(10);
        tracker.finish();
        assert!(tracker.is_finished());
        assert_eq!(tracker.current(), 100);
    }

    #[test]
    fn test_percentage() {
        let tracker = ProgressTracker::new(100);
        tracker.set(25);
        assert!((tracker.percentage() - 25.0).abs() < 0.001);
    }

    #[test]
    fn test_set_message() {
        let tracker = ProgressTracker::new(100);
        tracker.set_message("Processing...");
        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.message, "Processing...");
    }

    /// With no progress yet there is no rate, so no ETA can be estimated.
    #[test]
    fn test_snapshot_before_progress() {
        let tracker = ProgressTracker::new(100);
        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.current, 0);
        assert_eq!(snapshot.total, 100);
        assert!(snapshot.eta.is_none());
        assert!(!snapshot.finished);
    }

    #[test]
    fn test_auto_finish() {
        let tracker = ProgressTracker::new(100);
        tracker.set(100);
        assert!(tracker.is_finished());
    }

    #[tokio::test]
    async fn test_subscribe() {
        let tracker = ProgressTracker::with_config(100, ProgressConfig::default().no_debounce());
        let mut rx = tracker.subscribe();

        tracker.advance(10);

        let event = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("recv error");

        assert_eq!(event.current, 10);
    }

    /// `finish` bypasses debouncing and always broadcasts a final event.
    #[tokio::test]
    async fn test_finish_broadcasts_final_event() {
        let tracker = ProgressTracker::new(100);
        let mut rx = tracker.subscribe();

        tracker.finish();

        let event = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("recv error");

        assert!(event.finished);
        assert_eq!(event.current, 100);
    }
}