sfo_io/
stat_stream.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2
3use std::io;
4use std::io::Error;
5use std::marker::PhantomData;
6use std::num::NonZeroU64;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use std::time::SystemTime;
11use nonzero_ext::nonzero;
12use pin_project::pin_project;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
/// Read-only access to the transfer statistics collected for a stream.
///
/// "Write" refers to outgoing (uploaded) data and "read" to incoming (downloaded) data.
pub trait SpeedStat: Send + Sync + 'static {
    /// Current outgoing transfer rate as computed by the implementation.
    fn get_write_speed(&self) -> u64;
    /// Total number of bytes written so far.
    fn get_write_sum_size(&self) -> u64;
    /// Current incoming transfer rate as computed by the implementation.
    fn get_read_speed(&self) -> u64;
    /// Total number of bytes read so far.
    fn get_read_sum_size(&self) -> u64;
}

/// A [`SpeedStat`] that can also be fed with new byte counts.
pub trait SpeedTracker: SpeedStat {
    /// Records that `size` bytes were written.
    fn add_write_data_size(&self, size: u64);
    /// Records that `size` bytes were read.
    fn add_read_data_size(&self, size: u64);
}
26
/// Source of the current time, in milliseconds.
///
/// `now` is an associated function rather than a method so that the clock can be selected
/// purely through a type parameter (`SpeedState<T>`); the unit tests plug in a mock clock
/// this way.
pub trait TimePicker: 'static + Sync + Send {
    /// Returns the current time in milliseconds.
    fn now() -> u128;
}

/// [`TimePicker`] backed by the system wall clock, in milliseconds since the Unix epoch.
pub struct SystemTimePicker;

impl TimePicker for SystemTimePicker {
    /// Milliseconds since the Unix epoch.
    ///
    /// Returns `0` instead of panicking if the system clock is set before the epoch.
    fn now() -> u128 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_millis())
    }
}

/// Bytes attributed to one whole second of time.
struct DataItem {
    /// Number of bytes attributed to this second.
    size: u64,
    /// The second this bucket covers, i.e. `milliseconds / 1000`.
    time: u64,
}

/// Transfer statistics for one direction: a running byte total plus per-second buckets from
/// which an average speed over a sliding window of `speed_duration` seconds is computed.
pub(crate) struct SpeedState<T: TimePicker> {
    /// Total bytes ever recorded.
    sum_size: u64,
    /// Time (ms, from `T::now`) of the previous `add_data` call, or of construction.
    last_time: u128,
    /// Window length, in seconds.
    speed_duration: NonZeroU64,
    /// Per-second buckets, oldest first.
    data_items: Vec<DataItem>,
    _time_picker: PhantomData<T>,
}

impl<T: TimePicker> SpeedState<T> {
    /// Creates an empty state whose first interval starts at the current time.
    fn new(speed_duration: NonZeroU64) -> SpeedState<T> {
        SpeedState {
            sum_size: 0,
            last_time: T::now(),
            speed_duration,
            data_items: vec![],
            _time_picker: Default::default(),
        }
    }

    /// Records `size` bytes transferred at the current time.
    ///
    /// The bytes are attributed to the interval since the previous call. If both calls fall in
    /// the same second, everything goes to that second. Otherwise the bytes are spread over the
    /// seconds the interval covers, proportionally to the milliseconds spent in each.
    ///
    /// The clock does not have to be monotonic (`SystemTime` is not): if it moved backwards
    /// since the previous call there is no positive interval, so all bytes go to the current
    /// second.
    pub fn add_data(&mut self, size: u64) {
        self.sum_size += size;
        let now = T::now();
        self.clear_invalid_item(now);

        if now / 1000 == self.last_time / 1000 || now < self.last_time {
            self.add_to_bucket((now / 1000) as u64, size);
        } else {
            self.distribute(now, size);
        }
        self.last_time = now;
    }

    /// Adds `size` bytes to the bucket for `sec`. The bytes are merged into the most recent
    /// bucket when it covers the same second; otherwise a new bucket is appended.
    fn add_to_bucket(&mut self, sec: u64, size: u64) {
        match self.data_items.last_mut() {
            Some(last) if last.time == sec => last.size += size,
            _ => self.data_items.push(DataItem { size, time: sec }),
        }
    }

    /// Spreads `size` bytes over the half-open interval `[self.last_time, now)`, proportionally
    /// to the milliseconds of the interval that fall into each second.
    ///
    /// Requires `now > self.last_time`.
    fn distribute(&mut self, now: u128, size: u64) {
        let start = self.last_time;
        let elapsed = now - start;
        let now_sec = (now / 1000) as u64;
        // With a non-decreasing clock, buckets older than the window would be dropped by the
        // next `clear_invalid_item` and never counted by `get_speed`, so they are not created.
        // This also bounds the loop after a long idle period or a forward clock jump.
        let oldest_kept = now_sec.saturating_sub(self.speed_duration.get());
        let first_sec = ((start / 1000) as u64).max(oldest_kept);
        // The interval is half-open, so its last millisecond is `now - 1`.
        let last_sec = ((now - 1) / 1000) as u64;
        // Bytes attributed to the first `offset_ms` milliseconds of the interval. Taking
        // differences of this cumulative value makes the per-second shares add up to exactly
        // `size`, instead of dropping each second's truncated remainder.
        let bytes_until = |offset_ms: u128| (u128::from(size) * offset_ms / elapsed) as u64;

        for sec in first_sec..=last_sec {
            let sec_start = (u128::from(sec) * 1000).max(start);
            let sec_end = ((u128::from(sec) + 1) * 1000).min(now);
            let share = bytes_until(sec_end - start) - bytes_until(sec_start - start);
            self.add_to_bucket(sec, share);
        }
    }

    /// Drops buckets that are more than `speed_duration` seconds older than `now` (ms).
    ///
    /// Buckets stamped after `now` (the clock moved backwards) are dropped as well, instead of
    /// underflowing the age computation.
    pub fn clear_invalid_item(&mut self, now: u128) {
        let now_sec = (now / 1000) as u64;
        let window = self.speed_duration.get();
        self.data_items
            .retain(|item| now_sec.checked_sub(item.time).is_some_and(|age| age <= window));
    }

    /// Average speed over the last `speed_duration` complete seconds.
    ///
    /// The current, still incomplete second is excluded. The sum is still divided by the full
    /// window length, so the value ramps up during the first seconds of a transfer.
    pub fn get_speed(&self) -> u64 {
        let now_sec = (T::now() / 1000) as u64;
        let window = self.speed_duration.get();
        let total: u64 = self
            .data_items
            .iter()
            .filter(|item| {
                now_sec
                    .checked_sub(item.time)
                    .is_some_and(|age| (1..=window).contains(&age))
            })
            .map(|item| item.size)
            .sum();

        total / self.speed_duration
    }

    /// Total bytes recorded since creation.
    pub fn get_sum_size(&self) -> u64 {
        self.sum_size
    }
}
144
145pub struct SfoSpeedStat<T: TimePicker = SystemTimePicker> {
146    upload_state: Mutex<SpeedState<T>>,
147    download_state: Mutex<SpeedState<T>>,
148}
149
150impl SfoSpeedStat {
151    pub fn new() -> SfoSpeedStat {
152        Self {
153            upload_state: Mutex::new(SpeedState::new(nonzero!(5u64))),
154            download_state: Mutex::new(SpeedState::new(nonzero!(5u64))),
155        }
156    }
157
158
159    /// Creates a new SfoSpeedStat instance with the specified duration
160    ///
161    /// # Parameters
162    /// * `duration` - The duration for statistics, in seconds
163    ///
164    /// # Returns
165    /// Returns a new SfoSpeedStat instance containing initialized upload and download states
166    pub fn new_with_duration(duration: u64) -> SfoSpeedStat {
167        SfoSpeedStat {
168            upload_state: Mutex::new(SpeedState::new(NonZeroU64::new(duration).unwrap())),
169            download_state: Mutex::new(SpeedState::new(NonZeroU64::new(duration).unwrap())),
170        }
171    }
172}
173
174impl<T: TimePicker> SfoSpeedStat<T> {
175    pub(crate) fn new_with_time_picker() -> SfoSpeedStat<T> {
176        SfoSpeedStat {
177            upload_state: Mutex::new(SpeedState::new(nonzero!(5u64))),
178            download_state: Mutex::new(SpeedState::new(nonzero!(5u64))),
179        }
180    }
181}
182
183impl<T: TimePicker> SpeedTracker for SfoSpeedStat<T> {
184    fn add_write_data_size(&self, size: u64) {
185        self.upload_state.lock().unwrap().add_data(size);
186    }
187
188    fn add_read_data_size(&self, size: u64) {
189        self.download_state.lock().unwrap().add_data(size);
190    }
191}
192
193impl<T: TimePicker> SpeedStat for SfoSpeedStat<T> {
194    fn get_write_speed(&self) -> u64 {
195        self.upload_state.lock().unwrap().get_speed()
196    }
197
198    fn get_write_sum_size(&self) -> u64 {
199        self.upload_state.lock().unwrap().get_sum_size()
200    }
201
202    fn get_read_speed(&self) -> u64 {
203        self.download_state.lock().unwrap().get_speed()
204    }
205
206    fn get_read_sum_size(&self) -> u64 {
207        self.download_state.lock().unwrap().get_sum_size()
208    }
209}
210
211#[pin_project]
212pub struct StatStream<T: AsyncRead + AsyncWrite + Send + 'static> {
213    #[pin]
214    stream: T,
215    stat: Arc<dyn SpeedTracker>,
216}
217
218impl<T: AsyncRead + AsyncWrite + Send + 'static> StatStream<T> {
219    pub fn new(stream: T) -> StatStream<T> {
220        StatStream {
221            stream,
222            stat: Arc::new(SfoSpeedStat::new()),
223        }
224    }
225
226    pub fn new_with_tracker(stream: T, tracker: Arc<dyn SpeedTracker>) -> StatStream<T> {
227        StatStream {
228            stream,
229            stat: tracker,
230        }
231    }
232}
233
234impl<T: AsyncRead + AsyncWrite + Send + 'static> StatStream<T> {
235    pub(crate) fn new_test<S: TimePicker>(stream: T) -> StatStream<T> {
236        StatStream {
237            stream,
238            stat: Arc::new(SfoSpeedStat::<S>::new_with_time_picker()),
239        }
240    }
241
242    pub fn get_speed_stat(&self) -> Arc<dyn SpeedStat> {
243        self.stat.clone()
244    }
245
246    pub fn raw_stream(&mut self) -> &mut T {
247        &mut self.stream
248    }
249}
250
251impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncRead for StatStream<T> {
252    fn poll_read(
253        self: Pin<&mut Self>,
254        cx: &mut Context<'_>,
255        buf: &mut ReadBuf<'_>,
256    ) -> Poll<io::Result<()>> {
257        let this = self.project();
258        match this.stream.poll_read(cx, buf) {
259            Poll::Ready(res) => {
260                if res.is_ok() {
261                    this.stat.add_read_data_size(buf.filled().len() as u64);
262                }
263                Poll::Ready(res)
264            },
265            Poll::Pending => Poll::Pending,
266        }
267    }
268}
269
270impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for StatStream<T> {
271    fn poll_write(
272        self: Pin<&mut Self>,
273        cx: &mut Context<'_>,
274        buf: &[u8],
275    ) -> Poll<Result<usize, io::Error>> {
276        let this = self.project();
277        match this.stream.poll_write(cx, buf) {
278            Poll::Ready(res) => {
279                if res.is_ok() {
280                    this.stat.add_write_data_size(buf.len() as u64);
281                }
282                Poll::Ready(res)
283            },
284            Poll::Pending => Poll::Pending,
285        }
286    }
287
288    fn poll_flush(
289        self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291    ) -> Poll<Result<(), io::Error>> {
292        self.project().stream.poll_flush(cx)
293    }
294
295    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
296        self.project().stream.poll_shutdown(cx)
297    }
298}
299
300#[pin_project]
301pub struct StatRead<T: AsyncRead + Send + 'static> {
302    #[pin]
303    reader: T,
304    stat: Arc<dyn SpeedTracker>,
305}
306
307impl<T: AsyncRead + Send + 'static> StatRead<T> {
308    pub fn new(reader: T) -> StatRead<T> {
309        StatRead {
310            reader,
311            stat: Arc::new(SfoSpeedStat::new()),
312        }
313    }
314
315    pub fn new_with_tracker(reader: T, tracker: Arc<dyn SpeedTracker>) -> StatRead<T> {
316        StatRead {
317            reader,
318            stat: tracker,
319        }
320    }
321}
322
323impl<T: AsyncRead + Send + 'static> StatRead<T> {
324    pub(crate) fn new_test<S: TimePicker>(reader: T) -> StatRead<T> {
325        StatRead {
326            reader,
327            stat: Arc::new(SfoSpeedStat::<S>::new_with_time_picker()),
328        }
329    }
330
331    pub fn get_speed_stat(&self) -> Arc<dyn SpeedStat> {
332        self.stat.clone()
333    }
334
335    pub fn raw_reader(&mut self) -> &mut T {
336        &mut self.reader
337    }
338}
339
340impl<T: AsyncRead + Unpin + Send + 'static> AsyncRead for StatRead<T> {
341    fn poll_read(
342        self: Pin<&mut Self>,
343        cx: &mut Context<'_>,
344        buf: &mut ReadBuf<'_>,
345    ) -> Poll<io::Result<()>> {
346        let this = self.project();
347        match this.reader.poll_read(cx, buf) {
348            Poll::Ready(res) => {
349                if res.is_ok() {
350                    this.stat.add_read_data_size(buf.filled().len() as u64);
351                }
352                Poll::Ready(res)
353            },
354            Poll::Pending => Poll::Pending,
355        }
356    }
357}
358
359#[pin_project]
360pub struct StatWrite<T: AsyncWrite + Send + 'static> {
361    #[pin]
362    writer: T,
363    stat: Arc<dyn SpeedTracker>,
364}
365
366impl<T: AsyncWrite + Send + 'static> StatWrite<T> {
367    pub fn new(writer: T) -> StatWrite<T> {
368        StatWrite {
369            writer,
370            stat: Arc::new(SfoSpeedStat::new()),
371        }
372    }
373
374    pub fn new_with_tracker(writer: T, tracker: Arc<dyn SpeedTracker>) -> StatWrite<T> {
375        StatWrite {
376            writer,
377            stat: tracker,
378        }
379    }
380}
381
382impl<T: AsyncWrite + Send + 'static> StatWrite<T> {
383    pub(crate) fn new_test<S: TimePicker>(writer: T) -> StatWrite<T> {
384        StatWrite {
385            writer,
386            stat: Arc::new(SfoSpeedStat::<S>::new_with_time_picker()),
387        }
388    }
389
390    pub fn get_speed_stat(&self) -> Arc<dyn SpeedStat> {
391        self.stat.clone()
392    }
393
394    pub fn raw_writer(&mut self) -> &mut T {
395        &mut self.writer
396    }
397}
398
399impl<T: AsyncWrite + Unpin + Send + 'static> AsyncWrite for StatWrite<T> {
400    fn poll_write(
401        self: Pin<&mut Self>,
402        cx: &mut Context<'_>,
403        buf: &[u8],
404    ) -> Poll<Result<usize, io::Error>> {
405        let this = self.project();
406        match this.writer.poll_write(cx, buf) {
407            Poll::Ready(res) => {
408                if res.is_ok() {
409                    this.stat.add_write_data_size(buf.len() as u64);
410                }
411                Poll::Ready(res)
412            },
413            Poll::Pending => Poll::Pending,
414        }
415    }
416
417    fn poll_flush(
418        self: Pin<&mut Self>,
419        cx: &mut Context<'_>,
420    ) -> Poll<Result<(), io::Error>> {
421        self.project().writer.poll_flush(cx)
422    }
423
424    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
425        self.project().writer.poll_shutdown(cx)
426    }
427}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    // `Future` is only in the prelude from the 2024 edition on; import it explicitly so the
    // mock streams below (`dyn Future`, `.poll(..)`) compile on every edition.
    use std::future::Future;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[test]
    fn test_speed_state_new() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        set_mock_time(1000);
        let state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));

        assert_eq!(state.sum_size, 0);
        assert_eq!(state.last_time, 1000);
        assert_eq!(state.data_items.len(), 0);
    }

    #[test]
    fn test_add_data_same_second() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        set_mock_time(1500);
        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));

        state.add_data(100);
        assert_eq!(state.sum_size, 100);
        assert_eq!(state.data_items.len(), 1);
        assert_eq!(state.data_items[0].size, 100);
        assert_eq!(state.data_items[0].time, 1); // 1500 / 1000 = 1

        state.add_data(200);
        assert_eq!(state.sum_size, 300);
        assert_eq!(state.data_items.len(), 1);
        assert_eq!(state.data_items[0].size, 300); // merged into the same second
    }

    #[test]
    fn test_add_data_different_seconds() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        set_mock_time(1000);
        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));

        state.add_data(100);
        advance_mock_time(2000); // time advances to 3000 ms
        // The 200 bytes cover 1000..3000 ms and are split evenly between seconds 1 and 2.
        state.add_data(200);

        assert_eq!(state.sum_size, 300);
        assert_eq!(state.data_items.len(), 2);
        assert_eq!(state.data_items[0].size, 200);
        assert_eq!(state.data_items[0].time, 1);
        assert_eq!(state.data_items[1].size, 100);
        assert_eq!(state.data_items[1].time, 2);

        // (200 + 100) / 5
        assert_eq!(state.get_speed(), 60);
        advance_mock_time(500);
        assert_eq!(state.get_speed(), 60);
        // The current second (3) is not counted until it has passed.
        state.add_data(300);
        assert_eq!(state.sum_size, 600);
        assert_eq!(state.get_speed(), 60);
        advance_mock_time(500);
        assert_eq!(state.get_speed(), 120);
    }

    #[test]
    fn test_add_data_cross_seconds_distribution() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        // Test how data spanning a second boundary is distributed across seconds
        set_mock_time(1500); // 1.5 s
        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));
        advance_mock_time(1500);

        // From 1500 ms to 3000 ms: 1500 ms elapse, touching seconds 1 and 2
        // (1500..2000 ms is 500 ms of second 1, 2000..3000 ms is all 1000 ms of second 2).
        // 300 bytes are added over those 1500 ms.
        state.add_data(300);
        advance_mock_time(1500);

        // Two buckets are expected:
        // second 1 gets 300 * 500/1500 = 100 bytes,
        // second 2 gets 300 * 1000/1500 = 200 bytes.
        assert_eq!(state.data_items.len(), 2);
        assert_eq!(state.data_items[0].time, 1);
        assert_eq!(state.data_items[0].size, 100);
        assert_eq!(state.data_items[1].time, 2);
        assert_eq!(state.data_items[1].size, 200);

        // Same again, but over an interval that touches three seconds
        set_mock_time(1500); // 1.5 s
        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));
        advance_mock_time(2000);

        // From 1500 ms to 3500 ms: 2000 ms elapse over seconds 1, 2 and 3
        // (500 ms + 1000 ms + 500 ms). 400 bytes are added over those 2000 ms.
        state.add_data(400);
        advance_mock_time(1500);

        // Three buckets are expected:
        // second 1 gets 400 * 500/2000 = 100 bytes,
        // second 2 gets 400 * 1000/2000 = 200 bytes,
        // second 3 gets 400 * 500/2000 = 100 bytes.
        assert_eq!(state.data_items.len(), 3);
        assert_eq!(state.data_items[0].time, 1);
        assert_eq!(state.data_items[0].size, 100);
        assert_eq!(state.data_items[1].time, 2);
        assert_eq!(state.data_items[1].size, 200);
        assert_eq!(state.data_items[2].time, 3);
        assert_eq!(state.data_items[2].size, 100);
    }

    #[test]
    fn test_clear_invalid_item() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));

        // Add data items stamped with several different seconds
        state.data_items.push(DataItem { size: 100, time: 5 }); // second 5: 6 s old, should be removed
        state.data_items.push(DataItem { size: 200, time: 7 }); // second 7: should be kept
        state.data_items.push(DataItem { size: 300, time: 8 }); // second 8: should be kept

        set_mock_time(11000); // current time: second 11
        state.clear_invalid_item(MockTimePicker::now());

        assert_eq!(state.data_items.len(), 2);
        assert_eq!(state.data_items[0].time, 7);
        assert_eq!(state.data_items[1].time, 8);
    }

    #[test]
    fn test_get_speed() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        set_mock_time(10000); // 10 s
        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(5u64));

        state.add_data(100);
        advance_mock_time(1000);
        state.add_data(200);
        advance_mock_time(1000);
        state.add_data(300);
        advance_mock_time(1000);
        state.add_data(400);
        advance_mock_time(1000);
        state.add_data(500);
        advance_mock_time(1000);
        state.add_data(600);
        advance_mock_time(1000);
        state.add_data(700);

        let speed = state.get_speed();
        // Each add is attributed to the preceding second. At second 16 the window holds seconds
        // 11..=15 (the current second is excluded):
        // 300 + 400 + 500 + 600 + 700 = 2500 bytes over 5 seconds => 2500 / 5 = 500 bytes/sec
        assert_eq!(speed, 500);
    }

    #[test]
    fn test_get_sum_size() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        let mut state: SpeedState<MockTimePicker> = SpeedState::new(nonzero!(10u64));
        state.add_data(100);
        state.add_data(200);
        assert_eq!(state.get_sum_size(), 300);
    }

    #[test]
    fn test_speed_stat_impl() {
        // Mock TimePicker for testing
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to set mock time
        fn set_mock_time(time_ms: u64) {
            MOCK_TIME.store(time_ms, Ordering::Relaxed);
        }

        let stat: SfoSpeedStat<MockTimePicker> = SfoSpeedStat::new_with_time_picker();

        stat.add_write_data_size(100);
        stat.add_read_data_size(200);

        // No time has passed, so the speed is 0
        assert_eq!(stat.get_write_speed(), 0);
        assert_eq!(stat.get_read_speed(), 0);

        // Check again after simulated time has passed
        set_mock_time(5000);
        stat.add_write_data_size(500);
        stat.add_read_data_size(1000);
        set_mock_time(6000);
        stat.add_write_data_size(0);
        stat.add_read_data_size(0);

        // Now the speed should be non-zero
        assert!(stat.get_write_speed() > 0);
        assert!(stat.get_read_speed() > 0);
    }

    // Note: the StatStream tests need a tokio runtime; complex AsyncRead/AsyncWrite mocks are
    // omitted here in favour of the minimal mocks below.
    #[tokio::test]
    async fn test_stat_stream_creation() {
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        // Create a simple mock stream for testing
        struct MockStream {
            future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
        }

        impl AsyncRead for MockStream {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                if self.future.is_none() {
                    self.future = Some(Box::pin(tokio::time::sleep(Duration::from_millis(10))));
                }
                match Pin::new(self.future.as_mut().unwrap()).poll(_cx) {
                    Poll::Ready(_) => {
                        self.future = None;
                        _buf.set_filled(10);
                        Poll::Ready(Ok(()))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }

        impl AsyncWrite for MockStream {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, io::Error>> {
                if self.future.is_none() {
                    self.future = Some(Box::pin(tokio::time::sleep(Duration::from_millis(10))));
                }
                match Pin::new(self.future.as_mut().unwrap()).poll(_cx) {
                    Poll::Ready(_) => {
                        self.future = None;
                        Poll::Ready(Ok(buf.len()))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Result<(), io::Error>> {
                Poll::Ready(Ok(()))
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>
            ) -> Poll<Result<(), Error>> {
                Poll::Ready(Ok(()))
            }
        }

        impl Unpin for MockStream {}

        let stream = MockStream{
            future: None,
        };
        let mut stat_stream = StatStream::new_test::<MockTimePicker>(stream);
        let speed_stat = stat_stream.get_speed_stat();
        let mut upload_size = 0;
        let mut download_size = 0;
        let mut buf = vec![0u8; 4096];
        advance_mock_time(500);
        for i in 0..100 {
            let size = stat_stream.write(&buf).await.unwrap();
            stat_stream.flush().await.unwrap();
            upload_size += size;
            let size = stat_stream.read(&mut buf).await.unwrap();
            download_size += size;
            advance_mock_time(1000);
            if i < 5 {
                assert_eq!(speed_stat.get_write_speed(), (upload_size / 5) as u64);
                assert_eq!(speed_stat.get_read_speed(), (download_size / 5) as u64);
            } else {
                assert_eq!(speed_stat.get_write_sum_size(), upload_size as u64);
                assert_eq!(speed_stat.get_read_sum_size(), download_size as u64);
                // Each transfer happens mid-second and is split evenly over the two seconds it
                // spans. The newest second in the window has therefore only received the first
                // half of a transfer, while the four older seconds hold a full one each.
                assert_eq!(speed_stat.get_write_speed(), (4096 * 5 - 2048)/5);
                assert_eq!(speed_stat.get_read_speed(), (10 * 5 - 5)/5);
            }
        }
        stat_stream.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_stat_read_creation() {
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        // Create a simple mock reader for testing
        struct MockReader {
            future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
        }

        impl AsyncRead for MockReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                if self.future.is_none() {
                    self.future = Some(Box::pin(tokio::time::sleep(Duration::from_millis(10))));
                }
                match Pin::new(self.future.as_mut().unwrap()).poll(_cx) {
                    Poll::Ready(_) => {
                        self.future = None;
                        _buf.set_filled(10);
                        Poll::Ready(Ok(()))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }

        impl Unpin for MockReader {}

        let reader = MockReader{
            future: None,
        };
        let mut stat_reader = StatRead::new_test::<MockTimePicker>(reader);
        let speed_stat = stat_reader.get_speed_stat();
        let mut download_size = 0;
        let mut buf = vec![0u8; 4096];
        advance_mock_time(500);
        for i in 0..100 {
            let size = stat_reader.read(&mut buf).await.unwrap();
            download_size += size;
            advance_mock_time(1000);
            if i < 5 {
                assert_eq!(speed_stat.get_read_speed(), (download_size / 5) as u64);
            } else {
                assert_eq!(speed_stat.get_read_sum_size(), download_size as u64);
                // The newest second in the window holds only half of a 10-byte read.
                assert_eq!(speed_stat.get_read_speed(), (10 * 5 - 5)/5);
            }
        }
    }

    #[tokio::test]
    async fn test_stat_write_creation() {
        static MOCK_TIME: AtomicU64 = AtomicU64::new(0);

        struct MockTimePicker;

        impl TimePicker for MockTimePicker {
            fn now() -> u128 {
                MOCK_TIME.load(Ordering::Relaxed) as u128
            }
        }

        // Helper function to advance mock time
        fn advance_mock_time(delta_ms: u64) {
            MOCK_TIME.fetch_add(delta_ms, Ordering::Relaxed);
        }

        // Create a simple mock writer for testing
        struct MockWriter {
            future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
        }

        impl AsyncWrite for MockWriter {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, io::Error>> {
                if self.future.is_none() {
                    self.future = Some(Box::pin(tokio::time::sleep(Duration::from_millis(10))));
                }
                match Pin::new(self.future.as_mut().unwrap()).poll(_cx) {
                    Poll::Ready(_) => {
                        self.future = None;
                        Poll::Ready(Ok(buf.len()))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Result<(), io::Error>> {
                Poll::Ready(Ok(()))
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>
            ) -> Poll<Result<(), Error>> {
                Poll::Ready(Ok(()))
            }
        }

        impl Unpin for MockWriter {}

        let writer = MockWriter{
            future: None,
        };
        let mut stat_writer = StatWrite::new_test::<MockTimePicker>(writer);
        let speed_stat = stat_writer.get_speed_stat();
        let mut upload_size = 0;
        let buf = vec![0u8; 4096];
        advance_mock_time(500);
        for i in 0..100 {
            let size = stat_writer.write(&buf).await.unwrap();
            stat_writer.flush().await.unwrap();
            upload_size += size;
            advance_mock_time(1000);
            if i < 5 {
                assert_eq!(speed_stat.get_write_speed(), (upload_size / 5) as u64);
            } else {
                assert_eq!(speed_stat.get_write_sum_size(), upload_size as u64);
                // The newest second in the window holds only half of a 4096-byte write.
                assert_eq!(speed_stat.get_write_speed(), (4096 * 5 - 2048)/5);
            }
        }
        stat_writer.shutdown().await.unwrap();
    }
}