unistore_progress/
tracker.rs

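//! Progress tracker — core implementation.
//!
//! Responsibilities:
//! - track task progress
//! - estimate the remaining time (ETA)
//! - publish progress events to subscribers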
7
8use crate::config::ProgressConfig;
9use crate::deps::*;
10#[allow(unused_imports)]
11use crate::error::{ProgressError, ProgressResult};
12use crate::event::ProgressEvent;
13
/// Mutable tracker state kept behind the tracker's lock.
struct TrackerState {
    /// Human-readable message attached to the current progress.
    message: String,
    /// When the last event was published; `None` until the first one (drives debouncing).
    last_notify: Option<Instant>,
    /// Recorded throughput samples, oldest first, used to smooth the ETA.
    rate_samples: Vec<f64>,
}
23
24/// 进度追踪器
25///
26/// 线程安全的进度追踪器,支持多个订阅者。
27///
28/// # Example
29///
30/// ```rust
31/// use unistore_progress::ProgressTracker;
32///
33/// let tracker = ProgressTracker::new(100);
34/// tracker.advance(10);
35/// tracker.set_message("处理中...");
36/// assert_eq!(tracker.current(), 10);
37/// ```
38pub struct ProgressTracker {
39    /// 总数
40    total: u64,
41    /// 当前完成数(原子操作)
42    current: AtomicU64,
43    /// 是否已完成
44    finished: AtomicBool,
45    /// 开始时间
46    started_at: Instant,
47    /// 配置
48    config: ProgressConfig,
49    /// 可变状态
50    state: RwLock<TrackerState>,
51    /// 事件广播通道
52    sender: broadcast::Sender<ProgressEvent>,
53}
54
55impl ProgressTracker {
56    /// 创建新的进度追踪器
57    ///
58    /// # Arguments
59    /// * `total` - 总任务数
60    ///
61    /// # Panics
62    /// 如果 total 为 0
63    pub fn new(total: u64) -> Self {
64        Self::with_config(total, ProgressConfig::default())
65    }
66
67    /// 使用自定义配置创建进度追踪器
68    pub fn with_config(total: u64, config: ProgressConfig) -> Self {
69        assert!(total > 0, "total must be greater than 0");
70
71        let (sender, _) = broadcast::channel(config.channel_capacity);
72
73        Self {
74            total,
75            current: AtomicU64::new(0),
76            finished: AtomicBool::new(false),
77            started_at: Instant::now(),
78            config,
79            state: RwLock::new(TrackerState {
80                message: String::new(),
81                last_notify: None,
82                rate_samples: Vec::with_capacity(10),
83            }),
84            sender,
85        }
86    }
87
88    /// 获取总数
89    pub fn total(&self) -> u64 {
90        self.total
91    }
92
93    /// 获取当前完成数
94    pub fn current(&self) -> u64 {
95        self.current.load(Ordering::Relaxed)
96    }
97
98    /// 获取完成百分比(0.0 - 100.0)
99    pub fn percentage(&self) -> f64 {
100        (self.current() as f64 / self.total as f64) * 100.0
101    }
102
103    /// 是否已完成
104    pub fn is_finished(&self) -> bool {
105        self.finished.load(Ordering::Relaxed)
106    }
107
108    /// 获取已用时间
109    pub fn elapsed(&self) -> Duration {
110        self.started_at.elapsed()
111    }
112
113    /// 增加进度
114    ///
115    /// # Arguments
116    /// * `delta` - 增加的数量
117    pub fn advance(&self, delta: u64) {
118        if self.is_finished() {
119            return;
120        }
121
122        let new_value = self.current.fetch_add(delta, Ordering::Relaxed) + delta;
123
124        // 检查是否达到自动完成阈值
125        let ratio = new_value as f64 / self.total as f64;
126        if ratio >= self.config.auto_finish_threshold {
127            self.finish();
128        } else {
129            self.maybe_notify();
130        }
131    }
132
133    /// 设置进度到指定值
134    pub fn set(&self, value: u64) {
135        if self.is_finished() {
136            return;
137        }
138
139        let value = value.min(self.total);
140        self.current.store(value, Ordering::Relaxed);
141
142        let ratio = value as f64 / self.total as f64;
143        if ratio >= self.config.auto_finish_threshold {
144            self.finish();
145        } else {
146            self.maybe_notify();
147        }
148    }
149
150    /// 设置进度消息
151    pub fn set_message(&self, message: impl Into<String>) {
152        let mut state = self.state.write();
153        state.message = message.into();
154        drop(state);
155        self.maybe_notify();
156    }
157
158    /// 标记为完成
159    pub fn finish(&self) {
160        if self.finished.swap(true, Ordering::Relaxed) {
161            return; // 已经完成过了
162        }
163        self.current.store(self.total, Ordering::Relaxed);
164        self.notify_now();
165    }
166
167    /// 订阅进度更新
168    pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
169        self.sender.subscribe()
170    }
171
172    /// 获取当前快照
173    pub fn snapshot(&self) -> ProgressEvent {
174        let state = self.state.read();
175        let current = self.current();
176        let elapsed = self.elapsed();
177
178        ProgressEvent {
179            current,
180            total: self.total,
181            message: state.message.clone(),
182            elapsed,
183            eta: self.calculate_eta(current, elapsed),
184            finished: self.is_finished(),
185        }
186    }
187
188    /// 检查是否应该通知(去抖动)
189    fn maybe_notify(&self) {
190        let now = Instant::now();
191        let should_notify = {
192            let state = self.state.read();
193            match state.last_notify {
194                Some(last) => now.duration_since(last) >= self.config.debounce_interval,
195                None => true,
196            }
197        };
198
199        if should_notify {
200            self.notify_now();
201        }
202    }
203
204    /// 立即发送通知
205    fn notify_now(&self) {
206        let event = self.snapshot();
207
208        // 更新最后通知时间和速率样本
209        {
210            let mut state = self.state.write();
211            state.last_notify = Some(Instant::now());
212
213            // 记录速率样本
214            let elapsed_secs = event.elapsed.as_secs_f64();
215            if elapsed_secs > 0.0 {
216                let rate = event.current as f64 / elapsed_secs;
217                state.rate_samples.push(rate);
218                if state.rate_samples.len() > 10 {
219                    state.rate_samples.remove(0);
220                }
221            }
222        }
223
224        // 发送事件(忽略没有订阅者的情况)
225        let _ = self.sender.send(event);
226    }
227
228    /// 计算 ETA
229    fn calculate_eta(&self, current: u64, elapsed: Duration) -> Option<Duration> {
230        if current == 0 {
231            return None;
232        }
233
234        let state = self.state.read();
235        if state.rate_samples.len() < self.config.eta_min_samples {
236            // 样本不足,使用简单计算
237            let rate = current as f64 / elapsed.as_secs_f64();
238            if rate > 0.0 {
239                let remaining = (self.total - current) as f64;
240                return Some(Duration::from_secs_f64(remaining / rate));
241            }
242            return None;
243        }
244
245        // 使用指数移动平均计算平滑速率
246        let smoothed_rate = state.rate_samples.iter().rev().fold(0.0, |acc, &rate| {
247            acc * (1.0 - self.config.eta_smoothing_factor) + rate * self.config.eta_smoothing_factor
248        });
249
250        if smoothed_rate > 0.0 {
251            let remaining = (self.total - current) as f64;
252            Some(Duration::from_secs_f64(remaining / smoothed_rate))
253        } else {
254            None
255        }
256    }
257}
258
259impl std::fmt::Debug for ProgressTracker {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        f.debug_struct("ProgressTracker")
262            .field("total", &self.total)
263            .field("current", &self.current())
264            .field("finished", &self.is_finished())
265            .field("elapsed", &self.elapsed())
266            .finish()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let tracker = ProgressTracker::new(100);
        assert_eq!(tracker.total(), 100);
        assert_eq!(tracker.current(), 0);
        assert!(!tracker.is_finished());
    }

    /// `new` documents a panic for a zero total.
    #[test]
    #[should_panic(expected = "total must be greater than 0")]
    fn test_new_zero_total_panics() {
        let _ = ProgressTracker::new(0);
    }

    #[test]
    fn test_advance() {
        let tracker = ProgressTracker::new(100);
        tracker.advance(10);
        assert_eq!(tracker.current(), 10);
        tracker.advance(20);
        assert_eq!(tracker.current(), 30);
    }

    #[test]
    fn test_set() {
        let tracker = ProgressTracker::new(100);
        tracker.set(50);
        assert_eq!(tracker.current(), 50);
    }

    /// Values above `total` are clamped, and reaching `total` finishes the tracker.
    #[test]
    fn test_set_clamps_to_total() {
        let tracker = ProgressTracker::new(100);
        tracker.set(150);
        assert_eq!(tracker.current(), 100);
        assert!(tracker.is_finished());
    }

    #[test]
    fn test_finish() {
        let tracker = ProgressTracker::new(100);
        tracker.advance(50);
        tracker.finish();
        assert!(tracker.is_finished());
        assert_eq!(tracker.current(), 100);
    }

    /// Once finished, further `advance`/`set` calls are ignored.
    #[test]
    fn test_updates_after_finish_are_ignored() {
        let tracker = ProgressTracker::new(100);
        tracker.finish();
        tracker.advance(10);
        tracker.set(5);
        assert_eq!(tracker.current(), 100);
        assert!(tracker.is_finished());
    }

    /// A second `finish` is a no-op and must not publish another event.
    #[test]
    fn test_finish_publishes_once() {
        let tracker = ProgressTracker::new(100);
        let mut rx = tracker.subscribe();
        tracker.finish();
        tracker.finish();

        let event = rx.try_recv().expect("finish should publish an event");
        assert!(event.finished);
        assert_eq!(event.current, 100);
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_percentage() {
        let tracker = ProgressTracker::new(100);
        tracker.set(25);
        assert!((tracker.percentage() - 25.0).abs() < 0.001);
    }

    #[test]
    fn test_set_message() {
        let tracker = ProgressTracker::new(100);
        tracker.set_message("Processing...");
        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.message, "Processing...");
    }

    #[test]
    fn test_auto_finish() {
        let tracker = ProgressTracker::new(100);
        tracker.set(100);
        assert!(tracker.is_finished());
    }

    #[tokio::test]
    async fn test_subscribe() {
        let tracker = ProgressTracker::with_config(
            100,
            ProgressConfig::default().no_debounce(),
        );
        let mut rx = tracker.subscribe();

        tracker.advance(10);

        // Wait for the published event.
        let event = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("recv error");

        assert_eq!(event.current, 10);
    }
}