sfo_io/
limit_stream.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2
3use crate::SpeedLimitSession;
4use std::io::Error;
5use std::pin::{Pin};
6use std::task::{Context, Poll};
7use pin_project::pin_project;
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
/// Progress of a rate-limited read, carried between `poll_read` calls.
enum ReadState {
    /// No quota is outstanding; the next poll asks the limiter for one.
    Idle,
    /// The limiter has not granted a quota yet.
    ///
    /// Holds the boxed future that resolves to the granted byte count,
    /// together with the number of bytes already read against that grant.
    /// The `Option` lets the poller move the pair out and put it back.
    Waiting(
        Option<(
            Pin<Box<dyn Future<Output = usize> + Send + Sync + 'static>>,
            usize,
        )>,
    ),
    /// A quota was granted and is not used up yet.
    ///
    /// `Some((granted, consumed))` tracks the grant; with `None` the read is
    /// forwarded to the inner reader without any quota bookkeeping.
    Reading(Option<(usize, usize)>),
}
15
/// Progress of a rate-limited write, carried between `poll_write` calls.
enum WriteState {
    /// No quota is outstanding; the next poll asks the limiter for one.
    Idle,
    /// The limiter has not granted a quota yet.
    ///
    /// Holds the boxed future that resolves to the granted byte count,
    /// together with the number of bytes already written against that grant.
    /// The `Option` lets the poller move the pair out and put it back.
    Waiting(
        Option<(
            Pin<Box<dyn Future<Output = usize> + Send + Sync + 'static>>,
            usize,
        )>,
    ),
    /// A quota was granted and is not used up yet.
    ///
    /// `Some((granted, consumed))` tracks the grant; with `None` the write is
    /// forwarded to the inner writer without any quota bookkeeping.
    Writing(Option<(usize, usize)>),
}
21
22#[pin_project]
23pub struct LimitStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
24    #[pin]
25    read: LimitRead<sfo_split::ReadHalf<S>>,
26    #[pin]
27    write: LimitWrite<sfo_split::WriteHalf<S>>,
28}
29
30impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> LimitStream<S> {
31    pub fn new(stream: S, read_limit: SpeedLimitSession, write_limit: SpeedLimitSession) -> Self {
32        let (read, write) = sfo_split::split(stream);
33        let limit_read = LimitRead::new(read, read_limit);
34        let limit_write = LimitWrite::new(write, write_limit);
35        LimitStream {
36            read: limit_read,
37            write: limit_write,
38        }
39    }
40    pub fn with_lock_raw_stream<R>(&mut self, f: impl FnOnce(Pin<&mut S>) -> R) -> R {
41        self.read.raw_read().with_lock(f)
42    }
43}
44
45impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for LimitStream<S> {
46    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
47        self.project().write.poll_write(cx, buf)
48    }
49
50    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
51        self.project().write.poll_flush(cx)
52    }
53
54    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
55        self.project().write.poll_shutdown(cx)
56    }
57}
58
59impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncRead for LimitStream<S> {
60    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
61        self.project().read.poll_read(cx, buf)
62    }
63}
64
65#[pin_project]
66pub struct LimitRead<S: AsyncRead + Unpin + Send> {
67    #[pin]
68    read: S,
69    read_limit: SpeedLimitSession,
70    read_state: ReadState,
71}
72
73impl<S: AsyncRead + Unpin + Send + 'static> LimitRead<S> {
74    pub fn new(read: S, read_limit: SpeedLimitSession) -> Self {
75        LimitRead {
76            read,
77            read_limit,
78            read_state: ReadState::Idle,
79        }
80    }
81
82    pub fn raw_read_mut(&mut self) -> &mut S {
83        &mut self.read
84    }
85
86    pub fn raw_read(&self) -> &S {
87        &self.read
88    }
89
90    pub fn into_raw_read(self) -> S {
91        self.read
92    }
93}
94
95impl<S: AsyncRead + Unpin + Send + 'static> AsyncRead for LimitRead<S> {
96    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
97        let this = self.project();
98        buf.initialize_unfilled();
99        match this.read_state {
100            ReadState::Idle => {
101                let mut readded_len = 0;
102
103                let read_limit: &'static mut SpeedLimitSession = unsafe {
104                    std::mem::transmute(this.read_limit)
105                };
106                let mut waiting_future = Box::pin(read_limit.until_ready());
107                match Pin::new(&mut waiting_future).poll(cx) {
108                    Poll::Ready(read_len) => {
109                        let mut read_buf = if read_len <= buf.remaining() {
110                            buf.take(read_len)
111                        } else {
112                            buf.take(buf.remaining())
113                        };
114                        match this.read.poll_read(cx, &mut read_buf) {
115                            Poll::Ready(Ok(())) => {
116                                let len = read_buf.filled().len();
117                                readded_len += len;
118                                buf.advance(len);
119                                if readded_len >= read_len {
120                                    *this.read_state = ReadState::Idle;
121                                } else {
122                                    *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
123                                }
124                                Poll::Ready(Ok(()))
125                            },
126                            Poll::Ready(Err(e)) => {
127                                *this.read_state = ReadState::Idle;
128                                Poll::Ready(Err(e))
129                            },
130                            Poll::Pending => {
131                                *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
132                                Poll::Pending
133                            }
134                        }
135                    }
136                    Poll::Pending => {
137                        *this.read_state = ReadState::Waiting(Some((waiting_future, readded_len)));
138                        Poll::Pending
139                    }
140                }
141            }
142            ReadState::Waiting(state) => {
143                let (mut rx, mut readded_len) = state.take().unwrap();
144                match Pin::new(&mut rx).poll(cx) {
145                    Poll::Ready(read_len) => {
146                        let mut read_buf = if (read_len - readded_len) <= buf.remaining() {
147                            buf.take(read_len - readded_len)
148                        } else {
149                            buf.take(buf.remaining())
150                        };
151                        match this.read.poll_read(cx, &mut read_buf) {
152                            Poll::Ready(Ok(())) => {
153                                let len = read_buf.filled().len();
154                                readded_len += len;
155                                buf.advance(len);
156                                if readded_len >= read_len {
157                                    *this.read_state = ReadState::Idle;
158                                } else {
159                                    *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
160                                }
161                                Poll::Ready(Ok(()))
162                            }
163                            Poll::Ready(Err(e)) => {
164                                *this.read_state = ReadState::Idle;
165                                Poll::Ready(Err(e))
166                            },
167                            Poll::Pending => {
168                                *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
169                                Poll::Pending
170                            }
171                        }
172                    }
173                    Poll::Pending => {
174                        *this.read_state = ReadState::Waiting(Some((rx, readded_len)));
175                        Poll::Pending
176                    }
177                }
178            },
179            ReadState::Reading(state) => {
180                match state.take() {
181                    Some((read_len, mut readded_len)) => {
182                        let mut read_buf = if (read_len - readded_len) <= buf.remaining() {
183                            buf.take(read_len - readded_len)
184                        } else {
185                            buf.take(buf.remaining())
186                        };
187                        match this.read.poll_read(cx, &mut read_buf) {
188                            Poll::Ready(Ok(())) => {
189                                let len = read_buf.filled().len();
190                                readded_len += len;
191                                buf.advance(len);
192                                if readded_len >= read_len {
193                                    *this.read_state = ReadState::Idle;
194                                } else {
195                                    *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
196                                }
197                                Poll::Ready(Ok(()))
198                            }
199                            Poll::Ready(Err(e)) => {
200                                *this.read_state = ReadState::Idle;
201                                Poll::Ready(Err(e))
202                            },
203                            Poll::Pending => {
204                                *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
205                                Poll::Pending
206                            }
207                        }
208                    },
209                    None => {
210                        match this.read.poll_read(cx, buf) {
211                            Poll::Ready(Ok(())) => {
212                                *this.read_state = ReadState::Idle;
213                                Poll::Ready(Ok(()))
214                            },
215                            Poll::Ready(Err(e)) => {
216                                *this.read_state = ReadState::Idle;
217                                Poll::Ready(Err(e))
218                            },
219                            Poll::Pending => {
220                                *this.read_state = ReadState::Reading(None);
221                                Poll::Pending
222                            }
223                        }
224                    }
225                }
226
227            }
228        }
229    }
230}
231
232#[pin_project]
233pub struct LimitWrite<S: AsyncWrite + Unpin + Send> {
234    #[pin]
235    write: S,
236    write_limit: SpeedLimitSession,
237    write_state: WriteState,
238}
239
240impl<S: AsyncWrite + Unpin + Send + 'static> LimitWrite<S> {
241    pub fn new(write: S, write_limit: SpeedLimitSession) -> Self {
242        LimitWrite {
243            write,
244            write_limit,
245            write_state: WriteState::Idle,
246        }
247    }
248    pub fn raw_write_mut(&mut self) -> &mut S {
249        &mut self.write
250    }
251
252    pub fn raw_write(&self) -> &S {
253        &self.write
254    }
255
256    pub fn into_raw_write(self) -> S {
257        self.write
258    }
259}
260
261impl<S: AsyncWrite + Unpin + Send + 'static> AsyncWrite for LimitWrite<S> {
262    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
263        let this = self.project();
264        match this.write_state {
265            WriteState::Idle => {
266                let mut written_len = 0;
267                let write_limiter: &'static mut SpeedLimitSession = unsafe {
268                    std::mem::transmute(this.write_limit)
269                };
270                let mut waiting_future = Box::pin(write_limiter.until_ready());
271                match Pin::new(&mut waiting_future).poll(cx) {
272                    Poll::Ready(write_len) => {
273                        let write_buf = if write_len <= buf.len() {
274                            &buf[..write_len]
275                        } else {
276                            buf
277                        };
278                        match this.write.poll_write(cx, write_buf) {
279                            Poll::Ready(Ok(len)) => {
280                                written_len += len;
281                                if written_len >= write_len {
282                                    *this.write_state = WriteState::Idle;
283                                } else {
284                                    *this.write_state = WriteState::Writing(Some((write_len, written_len)));
285                                }
286                                Poll::Ready(Ok(written_len))
287                            }
288                            Poll::Ready(Err(e)) => {
289                                *this.write_state = WriteState::Idle;
290                                Poll::Ready(Err(e))
291                            }
292                            Poll::Pending => {
293                                *this.write_state = WriteState::Writing(Some((write_len, written_len)));
294                                Poll::Pending
295                            }
296                        }
297                    }
298                    Poll::Pending => {
299                        *this.write_state = WriteState::Waiting(Some((waiting_future, written_len)));
300                        Poll::Pending
301                    }
302                }
303            }
304            WriteState::Waiting(state) => {
305                let (mut waiting_future, mut written_len) = state.take().unwrap();
306                match Pin::new(&mut waiting_future).poll(cx) {
307                    Poll::Ready(write_len) => {
308                        let write_buf = if write_len - written_len <= buf.len() {
309                            &buf[..(write_len - written_len)]
310                        } else {
311                            buf
312                        };
313                        match this.write.poll_write(cx, write_buf) {
314                            Poll::Ready(Ok(len)) => {
315                                written_len += len;
316                                if written_len >= write_len {
317                                    *this.write_state = WriteState::Idle;
318                                } else {
319                                    *this.write_state = WriteState::Writing(Some((write_len, written_len)));
320                                }
321                                Poll::Ready(Ok(len))
322                            },
323                            Poll::Ready(Err(e)) => {
324                                *this.write_state = WriteState::Idle;
325                                Poll::Ready(Err(e))
326                            },
327                            Poll::Pending => {
328                                *this.write_state = WriteState::Writing(Some((write_len, written_len)));
329                                Poll::Pending
330                            }
331                        }
332                    }
333                    Poll::Pending => {
334                        *this.write_state = WriteState::Waiting(Some((waiting_future, written_len)));
335                        Poll::Pending
336                    }
337                }
338            }
339            WriteState::Writing(state) => {
340                match state.take() {
341                    Some((write_len, mut written_len)) => {
342                        let write_buf = if write_len - written_len <= buf.len() {
343                            &buf[..(write_len - written_len)]
344                        } else {
345                            buf
346                        };
347                        match this.write.poll_write(cx, write_buf) {
348                            Poll::Ready(Ok(len)) => {
349                                written_len += len;
350                                if written_len >= write_len {
351                                    *this.write_state = WriteState::Idle;
352                                } else {
353                                    *this.write_state = WriteState::Writing(Some((write_len, written_len)));
354                                }
355                                Poll::Ready(Ok(len))
356                            },
357                            Poll::Ready(Err(e)) => {
358                                *this.write_state = WriteState::Idle;
359                                Poll::Ready(Err(e))
360                            },
361                            Poll::Pending => {
362                                *this.write_state = WriteState::Writing(Some((write_len, written_len)));
363                                Poll::Pending
364                            }
365                        }
366                    },
367                    None => {
368                        match this.write.poll_write(cx, buf) {
369                            Poll::Ready(Ok(len)) => {
370                                *this.write_state = WriteState::Idle;
371                                Poll::Ready(Ok(len))
372                            },
373                            Poll::Ready(Err(e)) => {
374                                *this.write_state = WriteState::Idle;
375                                Poll::Ready(Err(e))
376                            },
377                            Poll::Pending => {
378                                *this.write_state = WriteState::Writing(None);
379                                Poll::Pending
380                            }
381                        }
382                    }
383                }
384            }
385        }
386    }
387
388    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
389        self.project().write.poll_flush(cx)
390    }
391
392    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
393        self.project().write.poll_shutdown(cx)
394    }
395}
396
397#[cfg(test)]
398#[cfg_attr(coverage_nightly, coverage(off))]
399mod tests {
400    use std::future::poll_fn;
401    use super::*;
402    use std::io::{Error, ErrorKind};
403    use std::pin::Pin;
404    use std::task::{Context, Poll};
405    use std::time::{Duration, Instant};
406    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
407    use futures::task::noop_waker;
408    use std::num::NonZeroU32;
409
410    // Mock stream implementation for testing
411    struct MockStream {
412        read_data: Vec<u8>,
413        read_pos: usize,
414        read_should_pending: bool,
415        read_error: Option<Error>,
416        write_should_pending: bool,
417        write_error: Option<Error>,
418        written_data: Vec<u8>,
419    }
420
421    impl MockStream {
422        fn new(read_data: Vec<u8>) -> Self {
423            Self {
424                read_data,
425                read_pos: 0,
426                read_should_pending: false,
427                read_error: None,
428                write_should_pending: false,
429                write_error: None,
430                written_data: Vec::new(),
431            }
432        }
433
434        fn with_read_pending(mut self) -> Self {
435            self.read_should_pending = true;
436            self
437        }
438
439        fn with_read_error(mut self, error: Error) -> Self {
440            self.read_error = Some(error);
441            self
442        }
443
444        fn with_write_pending(mut self) -> Self {
445            self.write_should_pending = true;
446            self
447        }
448
449        fn with_write_error(mut self, error: Error) -> Self {
450            self.write_error = Some(error);
451            self
452        }
453
454        fn written_data(&self) -> &[u8] {
455            &self.written_data
456        }
457    }
458
459    impl AsyncRead for MockStream {
460        fn poll_read(
461            mut self: Pin<&mut Self>,
462            _cx: &mut Context<'_>,
463            buf: &mut ReadBuf<'_>
464        ) -> Poll<std::io::Result<()>> {
465            if let Some(error) = self.read_error.take() {
466                return Poll::Ready(Err(error));
467            }
468
469            if self.read_should_pending {
470                return Poll::Pending;
471            }
472
473            let remaining = self.read_data.len() - self.read_pos;
474            if remaining == 0 {
475                return Poll::Ready(Ok(()));
476            }
477
478            let to_copy = std::cmp::min(remaining, buf.remaining());
479            buf.put_slice(&self.read_data[self.read_pos..self.read_pos + to_copy]);
480            self.read_pos += to_copy;
481
482            Poll::Ready(Ok(()))
483        }
484    }
485
486    impl AsyncWrite for MockStream {
487        fn poll_write(
488            mut self: Pin<&mut Self>,
489            _cx: &mut Context<'_>,
490            buf: &[u8]
491        ) -> Poll<Result<usize, Error>> {
492            if let Some(error) = self.write_error.take() {
493                return Poll::Ready(Err(error));
494            }
495
496            if self.write_should_pending {
497                return Poll::Pending;
498            }
499
500            self.written_data.extend_from_slice(buf);
501            Poll::Ready(Ok(buf.len()))
502        }
503
504        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
505            Poll::Ready(Ok(()))
506        }
507
508        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
509            Poll::Ready(Ok(()))
510        }
511    }
512
513    #[tokio::test]
514    async fn test_read_without_limit() {
515        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
516        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
517        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
518        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
519
520        let mut buffer = [0u8; 10];
521        let mut read_buf = ReadBuf::new(&mut buffer);
522        let waker = noop_waker();
523        let mut cx = Context::from_waker(&waker);
524
525        // 测试无限制读取
526        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
527        assert!(result.is_ready());
528        assert_eq!(read_buf.filled(), &[1, 2, 3, 4, 5]);
529    }
530
531    #[tokio::test]
532    async fn test_read_without_limit1() {
533        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
534        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
535        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
536        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
537
538        let mut buffer = [0u8; 3];
539        let mut read_buf = ReadBuf::new(&mut buffer);
540        let waker = noop_waker();
541        let mut cx = Context::from_waker(&waker);
542
543        // 测试无限制读取
544        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
545        assert!(result.is_pending());
546        limit_stream.with_lock_raw_stream(|stream| {
547            stream.get_mut().read_should_pending = false;
548        });
549        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
550        assert!(result.is_ready());
551        assert_eq!(read_buf.filled(), &[1, 2, 3]);
552        let mut read_buf = ReadBuf::new(&mut buffer);
553        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
554        assert!(result.is_ready());
555        assert_eq!(read_buf.filled(), &[4, 5]);
556    }
557
558    #[tokio::test]
559    async fn test_read_without_limit2() {
560        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
561        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
562        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
563        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
564
565        let mut buffer = [0u8; 3];
566        let mut read_buf = ReadBuf::new(&mut buffer);
567        let waker = noop_waker();
568        let mut cx = Context::from_waker(&waker);
569
570        // 测试无限制读取
571        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
572        assert!(result.is_pending());
573        limit_stream.with_lock_raw_stream(|stream| {
574            let stream = stream.get_mut();
575            stream.read_should_pending = false;
576            let error = Error::new(ErrorKind::Other, "read error");
577            stream.read_error = Some(error);
578        });
579        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
580        assert!(result.is_ready());
581
582        if let Poll::Ready(ret) = result {
583            assert!(ret.is_err());
584        }
585    }
586
587    #[tokio::test]
588    async fn test_read_without_limit_err() {
589        let error = Error::new(ErrorKind::Other, "read error");
590        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_error(error);
591        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
592        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
593        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
594
595        let mut buffer = [0u8; 10];
596        let mut read_buf = ReadBuf::new(&mut buffer);
597        let waker = noop_waker();
598        let mut cx = Context::from_waker(&waker);
599
600        // 测试无限制读取
601        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
602        assert!(result.is_ready());
603        match result {
604            Poll::Ready(ret) => assert!(ret.is_err()),
605            Poll::Pending => panic!("Expected ready"),
606        }
607    }
608
609    #[tokio::test]
610    async fn test_read_with_limit() {
611        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
612        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
613        let read_limit = read_limiter.new_limit_session();
614        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
615        let write_limit = write_limiter.new_limit_session();
616        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
617
618        let mut buffer = [0u8; 10];
619        let mut read_buf = ReadBuf::new(&mut buffer);
620        let waker = noop_waker();
621        let mut cx = Context::from_waker(&waker);
622
623        // 第一次读取应该等待令牌
624        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
625        assert!(result.is_ready());
626        assert_eq!(read_buf.filled(), &[1]);
627
628        let start = Instant::now();
629        let mut read_buf = ReadBuf::new(&mut buffer);
630        let result = poll_fn(|cx| {
631            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
632        }).await;
633        assert!(start.elapsed() >= Duration::from_millis(900));
634        assert!(result.is_ok());
635        assert_eq!(read_buf.filled(), &[2]);
636
637        let mut read_buf = ReadBuf::new(&mut buffer);
638        let result = poll_fn(|cx| {
639            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
640        }).await;
641        assert!(start.elapsed() >= Duration::from_millis(1900));
642        assert!(result.is_ok());
643        assert_eq!(read_buf.filled(), &[3]);
644    }
645
646    #[tokio::test]
647    async fn test_read_with_limit2() {
648        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
649        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
650        let read_limit = read_limiter.new_limit_session();
651        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
652        let write_limit = write_limiter.new_limit_session();
653        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
654
655        let mut buffer = [0u8; 10];
656        let mut read_buf = ReadBuf::new(&mut buffer);
657        let waker = noop_waker();
658        let mut cx = Context::from_waker(&waker);
659
660        // 第一次读取应该等待令牌
661        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
662        assert!(result.is_ready());
663        assert_eq!(read_buf.filled(), &[1, 2]);
664
665        let start = Instant::now();
666        let mut read_buf = ReadBuf::new(&mut buffer);
667        let result = poll_fn(|cx| {
668            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
669        }).await;
670        assert!(start.elapsed() >= Duration::from_millis(900));
671        assert!(result.is_ok());
672        assert_eq!(read_buf.filled(), &[3, 4]);
673
674        let mut read_buf = ReadBuf::new(&mut buffer);
675        let result = poll_fn(|cx| {
676            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
677        }).await;
678        assert!(start.elapsed() >= Duration::from_millis(1900));
679        assert!(result.is_ok());
680        assert_eq!(read_buf.filled(), &[5]);
681    }
682
683    #[tokio::test]
684    async fn test_read_with_limit3() {
685        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
686        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
687        let read_limit = read_limiter.new_limit_session();
688        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
689        let write_limit = write_limiter.new_limit_session();
690        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
691
692        let mut buffer = [0u8; 1];
693        let mut read_buf = ReadBuf::new(&mut buffer);
694        let waker = noop_waker();
695        let mut cx = Context::from_waker(&waker);
696
697        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
698        assert!(result.is_ready());
699        assert_eq!(read_buf.filled(), &[1]);
700
701        let mut read_buf = ReadBuf::new(&mut buffer);
702        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
703        assert!(result.is_ready());
704        assert_eq!(read_buf.filled(), &[2]);
705
706        let start = Instant::now();
707        let mut read_buf = ReadBuf::new(&mut buffer);
708        let result = poll_fn(|cx| {
709            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
710        }).await;
711        assert!(start.elapsed() >= Duration::from_millis(900));
712        assert!(result.is_ok());
713        assert_eq!(read_buf.filled(), &[3]);
714
715        let mut read_buf = ReadBuf::new(&mut buffer);
716        let result = poll_fn(|cx| {
717            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
718        }).await;
719        assert!(start.elapsed() < Duration::from_millis(1100));
720        assert!(result.is_ok());
721        assert_eq!(read_buf.filled(), &[4]);
722    }
723
724    #[tokio::test]
725    async fn test_read_with_limit4() {
726        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
727        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
728        let read_limit = read_limiter.new_limit_session();
729        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
730        let write_limit = write_limiter.new_limit_session();
731        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
732
733        let mut buffer = [0u8; 1];
734        let mut read_buf = ReadBuf::new(&mut buffer);
735        let waker = noop_waker();
736        let mut cx = Context::from_waker(&waker);
737
738        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
739        assert!(result.is_ready());
740        assert_eq!(read_buf.filled(), &[1]);
741
742        let mut read_buf = ReadBuf::new(&mut buffer);
743        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
744        assert!(result.is_ready());
745        assert_eq!(read_buf.filled(), &[2]);
746
747        let mut read_buf = ReadBuf::new(&mut buffer);
748        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
749        assert!(result.is_pending());
750        tokio::time::sleep(Duration::from_millis(600)).await;
751        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
752        assert!(result.is_ready());
753        assert_eq!(read_buf.filled(), &[3]);
754
755    }
756
757    #[tokio::test]
758    async fn test_read_with_limit5() {
759        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
760        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
761        let read_limit = read_limiter.new_limit_session();
762        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
763        let write_limit = write_limiter.new_limit_session();
764        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
765
766        let mut buffer = [0u8; 1];
767        let mut read_buf = ReadBuf::new(&mut buffer);
768        let waker = noop_waker();
769        let mut cx = Context::from_waker(&waker);
770
771        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
772        assert!(result.is_ready());
773        assert_eq!(read_buf.filled(), &[1]);
774        let mut read_buf = ReadBuf::new(&mut buffer);
775        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
776        assert!(result.is_ready());
777        assert_eq!(read_buf.filled(), &[2]);
778
779        let mut read_buf = ReadBuf::new(&mut buffer);
780        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
781        assert!(result.is_pending());
782        tokio::time::sleep(Duration::from_millis(1100)).await;
783        limit_stream.with_lock_raw_stream(|stream| {
784            stream.get_mut().read_should_pending = true;
785        });
786        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
787        assert!(result.is_pending());
788        limit_stream.with_lock_raw_stream(|stream| {
789            stream.get_mut().read_should_pending = false;
790        });
791        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
792        assert!(result.is_ready());
793        assert_eq!(read_buf.filled(), &[3]);
794        limit_stream.with_lock_raw_stream(|stream| {
795            let error = Error::new(ErrorKind::Other, "read error");
796            stream.get_mut().read_error = Some(error);
797        });
798        let mut read_buf = ReadBuf::new(&mut buffer);
799        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
800        assert!(result.is_ready());
801
802        if let Poll::Ready(Err(e)) = result {
803            assert_eq!(e.kind(), ErrorKind::Other);
804        }
805    }
806
807    #[tokio::test]
808    async fn test_read_with_limit6() {
809        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
810        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
811        let read_limit = read_limiter.new_limit_session();
812        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
813        let write_limit = write_limiter.new_limit_session();
814        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
815
816        let mut buffer = [0u8; 1];
817        let mut read_buf = ReadBuf::new(&mut buffer);
818        let waker = noop_waker();
819        let mut cx = Context::from_waker(&waker);
820
821        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
822        assert!(result.is_ready());
823        assert_eq!(read_buf.filled(), &[1]);
824
825        let mut read_buf = ReadBuf::new(&mut buffer);
826        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
827        assert!(result.is_ready());
828        assert_eq!(read_buf.filled(), &[2]);
829
830        let mut read_buf = ReadBuf::new(&mut buffer);
831        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
832        assert!(result.is_pending());
833        tokio::time::sleep(Duration::from_millis(600)).await;
834        limit_stream.with_lock_raw_stream(|stream| {
835            let error = Error::new(ErrorKind::Other, "read error");
836            stream.get_mut().read_error = Some(error);
837        });
838        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
839        assert!(result.is_ready());
840
841        if let Poll::Ready(Err(e)) = result {
842            assert_eq!(e.kind(), ErrorKind::Other);
843        }
844    }
845
846    #[tokio::test]
847    async fn test_write_without_limit() {
848        let mock_stream = MockStream::new(vec![]);
849        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
850        let read_limit = read_limiter.new_limit_session();
851        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
852        let write_limit = write_limiter.new_limit_session();
853        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
854
855        let data = [1, 2, 3, 4, 5];
856        let waker = noop_waker();
857        let mut cx = Context::from_waker(&waker);
858
859        // 测试无限制写入
860        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
861        assert!(result.is_ready());
862
863        if let Poll::Ready(Ok(written)) = result {
864            assert_eq!(written, 5);
865        }
866    }
867
868    #[tokio::test]
869    async fn test_write_without_limit2() {
870        let error = Error::new(ErrorKind::Other, "write error");
871        let mock_stream = MockStream::new(vec![]).with_write_error(error);
872        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
873        let read_limit = read_limiter.new_limit_session();
874        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
875        let write_limit = write_limiter.new_limit_session();
876        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
877
878        let data = [1, 2, 3, 4, 5];
879        let waker = noop_waker();
880        let mut cx = Context::from_waker(&waker);
881
882        // 测试无限制写入
883        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
884        assert!(result.is_ready());
885
886        if let Poll::Ready(ret) = result {
887            assert!(ret.is_err());
888            if let Err(e) = ret {
889                assert_eq!(e.kind(), ErrorKind::Other);
890            }
891        }
892    }
893
894    #[tokio::test]
895    async fn test_write_without_limit3() {
896        let mock_stream = MockStream::new(vec![]).with_write_pending();
897        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
898        let read_limit = read_limiter.new_limit_session();
899        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
900        let write_limit = write_limiter.new_limit_session();
901        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
902
903        let data = [1, 2, 3, 4, 5];
904        let waker = noop_waker();
905        let mut cx = Context::from_waker(&waker);
906
907        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
908        assert!(result.is_pending());
909
910        limit_stream.with_lock_raw_stream(|stream| {
911            stream.get_mut().write_should_pending = false;
912        });
913
914        // 测试无限制写入
915        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
916        assert!(result.is_ready());
917
918        if let Poll::Ready(Ok(written)) = result {
919            assert_eq!(written, 5);
920        }
921    }
922
923    #[tokio::test]
924    async fn test_write_without_limit4() {
925        let mock_stream = MockStream::new(vec![]).with_write_pending();
926        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
927        let read_limit = read_limiter.new_limit_session();
928        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
929        let write_limit = write_limiter.new_limit_session();
930        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
931
932        let data = [1, 2, 3, 4, 5];
933        let waker = noop_waker();
934        let mut cx = Context::from_waker(&waker);
935
936        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
937        assert!(result.is_pending());
938
939        limit_stream.with_lock_raw_stream(|stream| {
940            let stream = stream.get_mut();
941            stream.write_should_pending = false;
942            let ererror = Error::new(ErrorKind::Other, "write error");
943            stream.write_error = Some(ererror);
944        });
945
946        // 测试无限制写入
947        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
948        assert!(result.is_ready());
949
950        if let Poll::Ready(ret) = result {
951            assert!(ret.is_err());
952            if let Err(e) = ret {
953                assert_eq!(e.kind(), ErrorKind::Other);
954            }
955        }
956    }
957
958    #[tokio::test]
959    async fn test_write_with_limit() {
960        let mock_stream = MockStream::new(vec![]);
961        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
962        let read_limit = read_limiter.new_limit_session();
963        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
964        let write_limit = write_limiter.new_limit_session();
965        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
966
967        let data = [1, 2, 3, 4, 5];
968        let waker = noop_waker();
969        let mut cx = Context::from_waker(&waker);
970
971        // 第一次写入应该等待令牌
972        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
973        // 由于使用了实际的SpeedLimiter,可能返回Pending或Ready
974        assert!(result.is_ready());
975        if let Poll::Ready(Ok(written)) = result {
976            assert_eq!(written, 1);
977        }
978
979        // 第一次写入应该等待令牌
980        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
981        assert!(result.is_pending());
982        let start = Instant::now();
983        let result = poll_fn(|cx| {
984            Pin::new(&mut limit_stream).poll_write(cx, &data)
985        }).await;
986        assert!(start.elapsed() >= Duration::from_millis(900));
987        assert!(result.is_ok());
988        assert_eq!(result.unwrap(), 1);
989
990        let result = poll_fn(|cx| {
991            Pin::new(&mut limit_stream).poll_write(cx, &data)
992        }).await;
993        assert!(start.elapsed() >= Duration::from_millis(1900));
994        assert!(result.is_ok());
995        assert_eq!(result.unwrap(), 1);
996    }
997
998    #[tokio::test]
999    async fn test_write_with_limit1() {
1000        let mock_stream = MockStream::new(vec![]).with_write_pending();
1001        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1002        let read_limit = read_limiter.new_limit_session();
1003        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1004        let write_limit = write_limiter.new_limit_session();
1005        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1006
1007        let data = [1, 2, 3, 4, 5];
1008        let waker = noop_waker();
1009        let mut cx = Context::from_waker(&waker);
1010
1011        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1012        // 由于使用了实际的SpeedLimiter,可能返回Pending或Ready
1013        assert!(result.is_pending());
1014        limit_stream.with_lock_raw_stream(|stream| {
1015            stream.get_mut().write_should_pending = false;
1016        });
1017
1018        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1019        assert!(result.is_ready());
1020        if let Poll::Ready(ret) = result {
1021            assert!(ret.is_ok());
1022            assert_eq!(ret.unwrap(), 1);
1023        }
1024
1025        tokio::time::sleep(Duration::from_millis(1100)).await;
1026        limit_stream.with_lock_raw_stream(|stream| {
1027            stream.get_mut().write_should_pending = true;
1028        });
1029
1030        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1031        assert!(result.is_pending());
1032
1033        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1034        assert!(result.is_pending());
1035
1036        limit_stream.with_lock_raw_stream(|stream| {
1037            let stream = stream.get_mut();
1038            stream.write_should_pending = false;
1039            let ererror = Error::new(ErrorKind::Other, "write error");
1040            stream.write_error = Some(ererror);
1041        });
1042
1043        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1044        assert!(result.is_ready());
1045        if let Poll::Ready(Err(e)) = result {
1046            assert_eq!(e.kind(), ErrorKind::Other);
1047        }
1048    }
1049
1050    #[tokio::test]
1051    async fn test_write_with_limit2() {
1052        let mock_stream = MockStream::new(vec![]).with_write_pending();
1053        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1054        let read_limit = read_limiter.new_limit_session();
1055        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1056        let write_limit = write_limiter.new_limit_session();
1057        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1058
1059        let data = [1];
1060        let waker = noop_waker();
1061        let mut cx = Context::from_waker(&waker);
1062
1063        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1064        // 由于使用了实际的SpeedLimiter,可能返回Pending或Ready
1065        assert!(result.is_pending());
1066        limit_stream.with_lock_raw_stream(|stream| {
1067            let stream = stream.get_mut();
1068            stream.write_should_pending = false;
1069        });
1070
1071        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1072        assert!(result.is_ready());
1073        if let Poll::Ready(ret) = result {
1074            assert!(ret.is_ok());
1075            assert_eq!(ret.unwrap(), 1);
1076        }
1077        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1078        assert!(result.is_ready());
1079        if let Poll::Ready(ret) = result {
1080            assert!(ret.is_ok());
1081            assert_eq!(ret.unwrap(), 1);
1082        }
1083    }
1084
1085    #[tokio::test]
1086    async fn test_write_with_limit3() {
1087        let mock_stream = MockStream::new(vec![]);
1088        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1089        let read_limit = read_limiter.new_limit_session();
1090        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1091        let write_limit = write_limiter.new_limit_session();
1092        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1093
1094        let data = [1];
1095        let waker = noop_waker();
1096        let mut cx = Context::from_waker(&waker);
1097
1098        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1099        assert!(result.is_ready());
1100        if let Poll::Ready(ret) = result {
1101            assert!(ret.is_ok());
1102            assert_eq!(ret.unwrap(), 1);
1103        }
1104
1105        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1106        assert!(result.is_ready());
1107        if let Poll::Ready(ret) = result {
1108            assert!(ret.is_ok());
1109            assert_eq!(ret.unwrap(), 1);
1110        }
1111
1112        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1113        // 由于使用了实际的SpeedLimiter,可能返回Pending或Ready
1114        assert!(result.is_pending());
1115        limit_stream.with_lock_raw_stream(|stream| {
1116            let stream = stream.get_mut();
1117            stream.write_should_pending = false;
1118        });
1119
1120        tokio::time::sleep(Duration::from_millis(1100)).await;
1121        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1122        assert!(result.is_ready());
1123        if let Poll::Ready(ret) = result {
1124            assert!(ret.is_ok());
1125            assert_eq!(ret.unwrap(), 1);
1126        }
1127        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1128        assert!(result.is_ready());
1129        if let Poll::Ready(ret) = result {
1130            assert!(ret.is_ok());
1131            assert_eq!(ret.unwrap(), 1);
1132        }
1133    }
1134
1135    #[tokio::test]
1136    async fn test_write_with_limit4() {
1137        let mock_stream = MockStream::new(vec![]);
1138        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1139        let read_limit = read_limiter.new_limit_session();
1140        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1141        let write_limit = write_limiter.new_limit_session();
1142        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1143
1144        let data = [1];
1145        let waker = noop_waker();
1146        let mut cx = Context::from_waker(&waker);
1147
1148        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1149        assert!(result.is_ready());
1150        if let Poll::Ready(ret) = result {
1151            assert!(ret.is_ok());
1152            assert_eq!(ret.unwrap(), 1);
1153        }
1154
1155        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1156        assert!(result.is_ready());
1157        if let Poll::Ready(ret) = result {
1158            assert!(ret.is_ok());
1159            assert_eq!(ret.unwrap(), 1);
1160        }
1161
1162        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1163        assert!(result.is_pending());
1164
1165        tokio::time::sleep(Duration::from_millis(1100)).await;
1166        limit_stream.with_lock_raw_stream(|stream| {
1167            let stream = stream.get_mut();
1168            stream.write_should_pending = true;
1169        });
1170        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1171        assert!(result.is_pending());
1172
1173        limit_stream.with_lock_raw_stream(|stream| {
1174            let stream = stream.get_mut();
1175            stream.write_should_pending = false;
1176            let ererror = Error::new(ErrorKind::Other, "write error");
1177            stream.write_error = Some(ererror);
1178        });
1179        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1180        assert!(result.is_ready());
1181        if let Poll::Ready(ret) = result {
1182            assert!(ret.is_err());
1183        }
1184    }
1185
1186    #[tokio::test]
1187    async fn test_write_with_limit5() {
1188        let mock_stream = MockStream::new(vec![]);
1189        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1190        let read_limit = read_limiter.new_limit_session();
1191        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1192        let write_limit = write_limiter.new_limit_session();
1193        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1194
1195        let data = [1];
1196        let waker = noop_waker();
1197        let mut cx = Context::from_waker(&waker);
1198
1199        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1200        assert!(result.is_ready());
1201        if let Poll::Ready(ret) = result {
1202            assert!(ret.is_ok());
1203            assert_eq!(ret.unwrap(), 1);
1204        }
1205
1206        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1207        assert!(result.is_ready());
1208        if let Poll::Ready(ret) = result {
1209            assert!(ret.is_ok());
1210            assert_eq!(ret.unwrap(), 1);
1211        }
1212
1213        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1214        assert!(result.is_pending());
1215
1216        tokio::time::sleep(Duration::from_millis(1100)).await;
1217
1218        limit_stream.with_lock_raw_stream(|stream| {
1219            let stream = stream.get_mut();
1220            let ererror = Error::new(ErrorKind::Other, "write error");
1221            stream.write_error = Some(ererror);
1222        });
1223
1224        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1225        assert!(result.is_ready());
1226        if let Poll::Ready(ret) = result {
1227            assert!(ret.is_err());
1228        }
1229    }
1230
1231    #[tokio::test]
1232    async fn test_read_error_propagation() {
1233        let error = Error::new(ErrorKind::Other, "read error");
1234        let mock_stream = MockStream::new(vec![]).with_read_error(error);
1235        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1236        let read_limit = read_limiter.new_limit_session();
1237        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1238        let write_limit = write_limiter.new_limit_session();
1239        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1240
1241        let mut buffer = [0u8; 10];
1242        let mut read_buf = ReadBuf::new(&mut buffer);
1243        let waker = noop_waker();
1244        let mut cx = Context::from_waker(&waker);
1245
1246        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1247        assert!(result.is_ready());
1248
1249        if let Poll::Ready(Err(e)) = result {
1250            assert_eq!(e.kind(), ErrorKind::Other);
1251        }
1252    }
1253
1254    #[tokio::test]
1255    async fn test_write_error_propagation() {
1256        let error = Error::new(ErrorKind::Other, "write error");
1257        let mock_stream = MockStream::new(vec![]).with_write_error(error);
1258        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1259        let read_limit = read_limiter.new_limit_session();
1260        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1261        let write_limit = write_limiter.new_limit_session();
1262        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1263
1264        let data = [1, 2, 3, 4, 5];
1265        let waker = noop_waker();
1266        let mut cx = Context::from_waker(&waker);
1267
1268        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1269        assert!(result.is_ready());
1270
1271        if let Poll::Ready(Err(e)) = result {
1272            assert_eq!(e.kind(), ErrorKind::Other);
1273        }
1274    }
1275
1276    #[tokio::test]
1277    async fn test_read_limit_pending_handling() {
1278        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
1279        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1280        let read_limit = read_limiter.new_limit_session();
1281        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1282        let write_limit = write_limiter.new_limit_session();
1283        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1284
1285        let mut buffer = [0u8; 10];
1286        let mut read_buf = ReadBuf::new(&mut buffer);
1287        let waker = noop_waker();
1288        let mut cx = Context::from_waker(&waker);
1289
1290        // 第一次应该返回Pending
1291        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1292        assert!(result.is_pending());
1293        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1294        assert!(result.is_pending());
1295    }
1296
1297    #[tokio::test]
1298    async fn test_read_limit_pending_handling2() {
1299        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
1300        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1301        let read_limit = read_limiter.new_limit_session();
1302        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1303        let write_limit = write_limiter.new_limit_session();
1304        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1305
1306        let mut buffer = [0u8; 1];
1307        let mut read_buf = ReadBuf::new(&mut buffer);
1308        let waker = noop_waker();
1309        let mut cx = Context::from_waker(&waker);
1310
1311        // 第一次应该返回Pending
1312        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1313        assert!(result.is_ready());
1314        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1315        assert!(result.is_pending());
1316        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1317        assert!(result.is_pending());
1318    }
1319
1320    #[tokio::test]
1321    async fn test_read_pending_handling() {
1322        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
1323        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1324        let read_limit = read_limiter.new_limit_session();
1325        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1326        let write_limit = write_limiter.new_limit_session();
1327        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1328
1329        let mut buffer = [0u8; 10];
1330        let mut read_buf = ReadBuf::new(&mut buffer);
1331        let waker = noop_waker();
1332        let mut cx = Context::from_waker(&waker);
1333
1334        // 第一次应该返回Pending
1335        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1336        assert!(result.is_pending());
1337        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1338        assert!(result.is_pending());
1339    }
1340
1341    #[tokio::test]
1342    async fn test_write_pending_handling() {
1343        let mock_stream = MockStream::new(vec![]).with_write_pending();
1344        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1345        let read_limit = read_limiter.new_limit_session();
1346        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1347        let write_limit = write_limiter.new_limit_session();
1348        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1349
1350        let data = [1, 2, 3, 4, 5];
1351        let waker = noop_waker();
1352        let mut cx = Context::from_waker(&waker);
1353
1354        // 第一次应该返回Pending
1355        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1356        assert!(result.is_pending());
1357        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1358        assert!(result.is_pending());
1359    }
1360
1361    #[tokio::test]
1362    async fn test_flush_and_shutdown() {
1363        let mock_stream = MockStream::new(vec![]);
1364        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1365        let read_limit = read_limiter.new_limit_session();
1366        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1367        let write_limit = write_limiter.new_limit_session();
1368        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1369        let waker = noop_waker();
1370        let mut cx = Context::from_waker(&waker);
1371
1372        // 测试flush
1373        let flush_result = Pin::new(&mut limit_stream).poll_flush(&mut cx);
1374        assert!(flush_result.is_ready());
1375
1376        // 测试shutdown
1377        let shutdown_result = Pin::new(&mut limit_stream).poll_shutdown(&mut cx);
1378        assert!(shutdown_result.is_ready());
1379    }
1380
1381    #[tokio::test]
1382    async fn test_mixed_read_write() {
1383        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
1384        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1385        let read_limit = read_limiter.new_limit_session();
1386        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1387        let write_limit = write_limiter.new_limit_session();
1388        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1389
1390        let waker = noop_waker();
1391        let mut cx = Context::from_waker(&waker);
1392
1393        // 先写入数据
1394        let write_data = [6, 7, 8, 9, 10];
1395        let write_result = Pin::new(&mut limit_stream).poll_write(&mut cx, &write_data);
1396        assert!(write_result.is_ready());
1397
1398        // 再读取数据
1399        let mut buffer = [0u8; 10];
1400        let mut read_buf = ReadBuf::new(&mut buffer);
1401        let read_result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1402        assert!(read_result.is_ready());
1403    }
1404}