1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2
3use crate::SpeedLimitSession;
4use std::io::Error;
5use std::pin::{Pin};
6use std::task::{Context, Poll};
7use pin_project::pin_project;
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
/// Boxed future obtained from the read limiter; it resolves to the number of
/// bytes the caller is allowed to read.
type ReadPermitFuture = Pin<Box<dyn Future<Output = usize> + Send + Sync + 'static>>;

/// Where a rate-limited read currently stands between calls to `poll_read`.
enum ReadState {
    /// No limiter permit is outstanding.
    Idle,
    /// Waiting on the limiter: `(permit future, bytes already read under it)`.
    Waiting(Option<(ReadPermitFuture, usize)>),
    /// A permit was granted and not yet fully used: `(granted bytes, bytes already read)`.
    Reading(Option<(usize, usize)>),
}
15
/// Boxed future obtained from the write limiter; it resolves to the number of
/// bytes the caller is allowed to write.
type WritePermitFuture = Pin<Box<dyn Future<Output = usize> + Send + Sync + 'static>>;

/// Where a rate-limited write currently stands between calls to `poll_write`.
enum WriteState {
    /// No limiter permit is outstanding.
    Idle,
    /// Waiting on the limiter: `(permit future, bytes already written under it)`.
    Waiting(Option<(WritePermitFuture, usize)>),
    /// A permit was granted and not yet fully used: `(granted bytes, bytes already written)`.
    Writing(Option<(usize, usize)>),
}
21
22#[pin_project]
23pub struct LimitStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
24 #[pin]
25 read: LimitRead<sfo_split::ReadHalf<S>>,
26 #[pin]
27 write: LimitWrite<sfo_split::WriteHalf<S>>,
28}
29
30impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> LimitStream<S> {
31 pub fn new(stream: S, read_limit: SpeedLimitSession, write_limit: SpeedLimitSession) -> Self {
32 let (read, write) = sfo_split::split(stream);
33 let limit_read = LimitRead::new(read, read_limit);
34 let limit_write = LimitWrite::new(write, write_limit);
35 LimitStream {
36 read: limit_read,
37 write: limit_write,
38 }
39 }
40 pub fn with_lock_raw_stream<R>(&mut self, f: impl FnOnce(Pin<&mut S>) -> R) -> R {
41 self.read.raw_read().with_lock(f)
42 }
43}
44
45impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for LimitStream<S> {
46 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
47 self.project().write.poll_write(cx, buf)
48 }
49
50 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
51 self.project().write.poll_flush(cx)
52 }
53
54 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
55 self.project().write.poll_shutdown(cx)
56 }
57}
58
59impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncRead for LimitStream<S> {
60 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
61 self.project().read.poll_read(cx, buf)
62 }
63}
64
65#[pin_project]
66pub struct LimitRead<S: AsyncRead + Unpin + Send> {
67 #[pin]
68 read: S,
69 read_limit: SpeedLimitSession,
70 read_state: ReadState,
71}
72
73impl<S: AsyncRead + Unpin + Send + 'static> LimitRead<S> {
74 pub fn new(read: S, read_limit: SpeedLimitSession) -> Self {
75 LimitRead {
76 read,
77 read_limit,
78 read_state: ReadState::Idle,
79 }
80 }
81
82 pub fn raw_read_mut(&mut self) -> &mut S {
83 &mut self.read
84 }
85
86 pub fn raw_read(&self) -> &S {
87 &self.read
88 }
89
90 pub fn into_raw_read(self) -> S {
91 self.read
92 }
93}
94
95impl<S: AsyncRead + Unpin + Send + 'static> AsyncRead for LimitRead<S> {
96 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
97 let this = self.project();
98 buf.initialize_unfilled();
99 match this.read_state {
100 ReadState::Idle => {
101 let mut readded_len = 0;
102
103 let read_limit: &'static mut SpeedLimitSession = unsafe {
104 std::mem::transmute(this.read_limit)
105 };
106 let mut waiting_future = Box::pin(read_limit.until_ready());
107 match Pin::new(&mut waiting_future).poll(cx) {
108 Poll::Ready(read_len) => {
109 let mut read_buf = if read_len <= buf.remaining() {
110 buf.take(read_len)
111 } else {
112 buf.take(buf.remaining())
113 };
114 match this.read.poll_read(cx, &mut read_buf) {
115 Poll::Ready(Ok(())) => {
116 let len = read_buf.filled().len();
117 readded_len += len;
118 buf.advance(len);
119 if readded_len >= read_len {
120 *this.read_state = ReadState::Idle;
121 } else {
122 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
123 }
124 Poll::Ready(Ok(()))
125 },
126 Poll::Ready(Err(e)) => {
127 *this.read_state = ReadState::Idle;
128 Poll::Ready(Err(e))
129 },
130 Poll::Pending => {
131 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
132 Poll::Pending
133 }
134 }
135 }
136 Poll::Pending => {
137 *this.read_state = ReadState::Waiting(Some((waiting_future, readded_len)));
138 Poll::Pending
139 }
140 }
141 }
142 ReadState::Waiting(state) => {
143 let (mut rx, mut readded_len) = state.take().unwrap();
144 match Pin::new(&mut rx).poll(cx) {
145 Poll::Ready(read_len) => {
146 let mut read_buf = if (read_len - readded_len) <= buf.remaining() {
147 buf.take(read_len - readded_len)
148 } else {
149 buf.take(buf.remaining())
150 };
151 match this.read.poll_read(cx, &mut read_buf) {
152 Poll::Ready(Ok(())) => {
153 let len = read_buf.filled().len();
154 readded_len += len;
155 buf.advance(len);
156 if readded_len >= read_len {
157 *this.read_state = ReadState::Idle;
158 } else {
159 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
160 }
161 Poll::Ready(Ok(()))
162 }
163 Poll::Ready(Err(e)) => {
164 *this.read_state = ReadState::Idle;
165 Poll::Ready(Err(e))
166 },
167 Poll::Pending => {
168 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
169 Poll::Pending
170 }
171 }
172 }
173 Poll::Pending => {
174 *this.read_state = ReadState::Waiting(Some((rx, readded_len)));
175 Poll::Pending
176 }
177 }
178 },
179 ReadState::Reading(state) => {
180 match state.take() {
181 Some((read_len, mut readded_len)) => {
182 let mut read_buf = if (read_len - readded_len) <= buf.remaining() {
183 buf.take(read_len - readded_len)
184 } else {
185 buf.take(buf.remaining())
186 };
187 match this.read.poll_read(cx, &mut read_buf) {
188 Poll::Ready(Ok(())) => {
189 let len = read_buf.filled().len();
190 readded_len += len;
191 buf.advance(len);
192 if readded_len >= read_len {
193 *this.read_state = ReadState::Idle;
194 } else {
195 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
196 }
197 Poll::Ready(Ok(()))
198 }
199 Poll::Ready(Err(e)) => {
200 *this.read_state = ReadState::Idle;
201 Poll::Ready(Err(e))
202 },
203 Poll::Pending => {
204 *this.read_state = ReadState::Reading(Some((read_len, readded_len)));
205 Poll::Pending
206 }
207 }
208 },
209 None => {
210 match this.read.poll_read(cx, buf) {
211 Poll::Ready(Ok(())) => {
212 *this.read_state = ReadState::Idle;
213 Poll::Ready(Ok(()))
214 },
215 Poll::Ready(Err(e)) => {
216 *this.read_state = ReadState::Idle;
217 Poll::Ready(Err(e))
218 },
219 Poll::Pending => {
220 *this.read_state = ReadState::Reading(None);
221 Poll::Pending
222 }
223 }
224 }
225 }
226
227 }
228 }
229 }
230}
231
232#[pin_project]
233pub struct LimitWrite<S: AsyncWrite + Unpin + Send> {
234 #[pin]
235 write: S,
236 write_limit: SpeedLimitSession,
237 write_state: WriteState,
238}
239
240impl<S: AsyncWrite + Unpin + Send + 'static> LimitWrite<S> {
241 pub fn new(write: S, write_limit: SpeedLimitSession) -> Self {
242 LimitWrite {
243 write,
244 write_limit,
245 write_state: WriteState::Idle,
246 }
247 }
248 pub fn raw_write_mut(&mut self) -> &mut S {
249 &mut self.write
250 }
251
252 pub fn raw_write(&self) -> &S {
253 &self.write
254 }
255
256 pub fn into_raw_write(self) -> S {
257 self.write
258 }
259}
260
261impl<S: AsyncWrite + Unpin + Send + 'static> AsyncWrite for LimitWrite<S> {
262 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
263 let this = self.project();
264 match this.write_state {
265 WriteState::Idle => {
266 let mut written_len = 0;
267 let write_limiter: &'static mut SpeedLimitSession = unsafe {
268 std::mem::transmute(this.write_limit)
269 };
270 let mut waiting_future = Box::pin(write_limiter.until_ready());
271 match Pin::new(&mut waiting_future).poll(cx) {
272 Poll::Ready(write_len) => {
273 let write_buf = if write_len <= buf.len() {
274 &buf[..write_len]
275 } else {
276 buf
277 };
278 match this.write.poll_write(cx, write_buf) {
279 Poll::Ready(Ok(len)) => {
280 written_len += len;
281 if written_len >= write_len {
282 *this.write_state = WriteState::Idle;
283 } else {
284 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
285 }
286 Poll::Ready(Ok(written_len))
287 }
288 Poll::Ready(Err(e)) => {
289 *this.write_state = WriteState::Idle;
290 Poll::Ready(Err(e))
291 }
292 Poll::Pending => {
293 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
294 Poll::Pending
295 }
296 }
297 }
298 Poll::Pending => {
299 *this.write_state = WriteState::Waiting(Some((waiting_future, written_len)));
300 Poll::Pending
301 }
302 }
303 }
304 WriteState::Waiting(state) => {
305 let (mut waiting_future, mut written_len) = state.take().unwrap();
306 match Pin::new(&mut waiting_future).poll(cx) {
307 Poll::Ready(write_len) => {
308 let write_buf = if write_len - written_len <= buf.len() {
309 &buf[..(write_len - written_len)]
310 } else {
311 buf
312 };
313 match this.write.poll_write(cx, write_buf) {
314 Poll::Ready(Ok(len)) => {
315 written_len += len;
316 if written_len >= write_len {
317 *this.write_state = WriteState::Idle;
318 } else {
319 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
320 }
321 Poll::Ready(Ok(len))
322 },
323 Poll::Ready(Err(e)) => {
324 *this.write_state = WriteState::Idle;
325 Poll::Ready(Err(e))
326 },
327 Poll::Pending => {
328 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
329 Poll::Pending
330 }
331 }
332 }
333 Poll::Pending => {
334 *this.write_state = WriteState::Waiting(Some((waiting_future, written_len)));
335 Poll::Pending
336 }
337 }
338 }
339 WriteState::Writing(state) => {
340 match state.take() {
341 Some((write_len, mut written_len)) => {
342 let write_buf = if write_len - written_len <= buf.len() {
343 &buf[..(write_len - written_len)]
344 } else {
345 buf
346 };
347 match this.write.poll_write(cx, write_buf) {
348 Poll::Ready(Ok(len)) => {
349 written_len += len;
350 if written_len >= write_len {
351 *this.write_state = WriteState::Idle;
352 } else {
353 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
354 }
355 Poll::Ready(Ok(len))
356 },
357 Poll::Ready(Err(e)) => {
358 *this.write_state = WriteState::Idle;
359 Poll::Ready(Err(e))
360 },
361 Poll::Pending => {
362 *this.write_state = WriteState::Writing(Some((write_len, written_len)));
363 Poll::Pending
364 }
365 }
366 },
367 None => {
368 match this.write.poll_write(cx, buf) {
369 Poll::Ready(Ok(len)) => {
370 *this.write_state = WriteState::Idle;
371 Poll::Ready(Ok(len))
372 },
373 Poll::Ready(Err(e)) => {
374 *this.write_state = WriteState::Idle;
375 Poll::Ready(Err(e))
376 },
377 Poll::Pending => {
378 *this.write_state = WriteState::Writing(None);
379 Poll::Pending
380 }
381 }
382 }
383 }
384 }
385 }
386 }
387
388 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
389 self.project().write.poll_flush(cx)
390 }
391
392 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
393 self.project().write.poll_shutdown(cx)
394 }
395}
396
397#[cfg(test)]
398#[cfg_attr(coverage_nightly, coverage(off))]
399mod tests {
400 use std::future::poll_fn;
401 use super::*;
402 use std::io::{Error, ErrorKind};
403 use std::pin::Pin;
404 use std::task::{Context, Poll};
405 use std::time::{Duration, Instant};
406 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
407 use futures::task::noop_waker;
408 use std::num::NonZeroU32;
409
410 struct MockStream {
412 read_data: Vec<u8>,
413 read_pos: usize,
414 read_should_pending: bool,
415 read_error: Option<Error>,
416 write_should_pending: bool,
417 write_error: Option<Error>,
418 written_data: Vec<u8>,
419 }
420
421 impl MockStream {
422 fn new(read_data: Vec<u8>) -> Self {
423 Self {
424 read_data,
425 read_pos: 0,
426 read_should_pending: false,
427 read_error: None,
428 write_should_pending: false,
429 write_error: None,
430 written_data: Vec::new(),
431 }
432 }
433
434 fn with_read_pending(mut self) -> Self {
435 self.read_should_pending = true;
436 self
437 }
438
439 fn with_read_error(mut self, error: Error) -> Self {
440 self.read_error = Some(error);
441 self
442 }
443
444 fn with_write_pending(mut self) -> Self {
445 self.write_should_pending = true;
446 self
447 }
448
449 fn with_write_error(mut self, error: Error) -> Self {
450 self.write_error = Some(error);
451 self
452 }
453
454 fn written_data(&self) -> &[u8] {
455 &self.written_data
456 }
457 }
458
459 impl AsyncRead for MockStream {
460 fn poll_read(
461 mut self: Pin<&mut Self>,
462 _cx: &mut Context<'_>,
463 buf: &mut ReadBuf<'_>
464 ) -> Poll<std::io::Result<()>> {
465 if let Some(error) = self.read_error.take() {
466 return Poll::Ready(Err(error));
467 }
468
469 if self.read_should_pending {
470 return Poll::Pending;
471 }
472
473 let remaining = self.read_data.len() - self.read_pos;
474 if remaining == 0 {
475 return Poll::Ready(Ok(()));
476 }
477
478 let to_copy = std::cmp::min(remaining, buf.remaining());
479 buf.put_slice(&self.read_data[self.read_pos..self.read_pos + to_copy]);
480 self.read_pos += to_copy;
481
482 Poll::Ready(Ok(()))
483 }
484 }
485
486 impl AsyncWrite for MockStream {
487 fn poll_write(
488 mut self: Pin<&mut Self>,
489 _cx: &mut Context<'_>,
490 buf: &[u8]
491 ) -> Poll<Result<usize, Error>> {
492 if let Some(error) = self.write_error.take() {
493 return Poll::Ready(Err(error));
494 }
495
496 if self.write_should_pending {
497 return Poll::Pending;
498 }
499
500 self.written_data.extend_from_slice(buf);
501 Poll::Ready(Ok(buf.len()))
502 }
503
504 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
505 Poll::Ready(Ok(()))
506 }
507
508 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
509 Poll::Ready(Ok(()))
510 }
511 }
512
513 #[tokio::test]
514 async fn test_read_without_limit() {
515 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
516 let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
517 let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
518 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
519
520 let mut buffer = [0u8; 10];
521 let mut read_buf = ReadBuf::new(&mut buffer);
522 let waker = noop_waker();
523 let mut cx = Context::from_waker(&waker);
524
525 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
527 assert!(result.is_ready());
528 assert_eq!(read_buf.filled(), &[1, 2, 3, 4, 5]);
529 }
530
    /// The first poll hits a mock that is forced `Pending`. After the flag is
    /// cleared, the next poll fills the 3-byte buffer, and a follow-up poll with a
    /// fresh buffer returns the remaining two bytes.
    #[tokio::test]
    async fn test_read_without_limit1() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 3];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // The mock returns `Pending` while `read_should_pending` is set.
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_pending());
        limit_stream.with_lock_raw_stream(|stream| {
            stream.get_mut().read_should_pending = false;
        });
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[1, 2, 3]);
        // Fresh buffer for the rest of the data.
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[4, 5]);
    }
557
    /// The first poll is `Pending`. Then an error is injected into the mock, and
    /// the next poll must surface that error.
    #[tokio::test]
    async fn test_read_without_limit2() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
        let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
        let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 3];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_pending());
        // Clear the pending flag and queue a one-shot read error.
        limit_stream.with_lock_raw_stream(|stream| {
            let stream = stream.get_mut();
            stream.read_should_pending = false;
            let error = Error::new(ErrorKind::Other, "read error");
            stream.read_error = Some(error);
        });
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());

        if let Poll::Ready(ret) = result {
            assert!(ret.is_err());
        }
    }
586
587 #[tokio::test]
588 async fn test_read_without_limit_err() {
589 let error = Error::new(ErrorKind::Other, "read error");
590 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_error(error);
591 let read_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
592 let write_limit = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap())).new_limit_session();
593 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
594
595 let mut buffer = [0u8; 10];
596 let mut read_buf = ReadBuf::new(&mut buffer);
597 let waker = noop_waker();
598 let mut cx = Context::from_waker(&waker);
599
600 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
602 assert!(result.is_ready());
603 match result {
604 Poll::Ready(ret) => assert!(ret.is_err()),
605 Poll::Pending => panic!("Expected ready"),
606 }
607 }
608
    /// With read limiter `(1, 1)`, each poll yields one byte. The assertions
    /// expect each later byte to take roughly one more second (presumably a rate
    /// of 1/s — confirm against `SpeedLimiter::new`).
    #[tokio::test]
    async fn test_read_with_limit() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
        let read_limit = read_limiter.new_limit_session();
        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
        let write_limit = write_limiter.new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 10];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // The first byte is available without waiting.
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[1]);

        // Later reads are driven by the runtime (`poll_fn(..).await`) so the
        // limiter's wakeups are honoured; elapsed time is measured from `start`.
        let start = Instant::now();
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() >= Duration::from_millis(900));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[2]);

        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() >= Duration::from_millis(1900));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[3]);
    }
645
    /// With read limiter `(1, 2)` and a 10-byte buffer, the assertions expect
    /// two bytes per grant and roughly one second between grants.
    #[tokio::test]
    async fn test_read_with_limit2() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
        let read_limit = read_limiter.new_limit_session();
        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
        let write_limit = write_limiter.new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 10];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[1, 2]);

        let start = Instant::now();
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() >= Duration::from_millis(900));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[3, 4]);

        // Only one byte remains in the mock.
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() >= Duration::from_millis(1900));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[5]);
    }
682
    /// With a 1-byte buffer, a grant larger than the buffer must carry over. The
    /// test expects the second byte, and the fourth byte after one wait, to arrive
    /// without an extra ~1 s wait, using leftover budget kept between calls.
    #[tokio::test]
    async fn test_read_with_limit3() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
        let read_limit = read_limiter.new_limit_session();
        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
        let write_limit = write_limiter.new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 1];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[1]);

        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[2]);

        let start = Instant::now();
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() >= Duration::from_millis(900));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[3]);

        // Upper bound: this read must not wait for a second grant.
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = poll_fn(|cx| {
            Pin::new(&mut limit_stream).poll_read(cx, &mut read_buf)
        }).await;
        assert!(start.elapsed() < Duration::from_millis(1100));
        assert!(result.is_ok());
        assert_eq!(read_buf.filled(), &[4]);
    }
723
    /// With read limiter `(2, 1)`, the third poll is `Pending`. After sleeping
    /// 600 ms, a manual re-poll (noop waker) is expected to complete the read.
    #[tokio::test]
    async fn test_read_with_limit4() {
        let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
        let read_limit = read_limiter.new_limit_session();
        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
        let write_limit = write_limiter.new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let mut buffer = [0u8; 1];
        let mut read_buf = ReadBuf::new(&mut buffer);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[1]);

        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[2]);

        // This poll parks the limiter future (`Waiting`); the re-poll resumes it.
        let mut read_buf = ReadBuf::new(&mut buffer);
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_pending());
        tokio::time::sleep(Duration::from_millis(600)).await;
        let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
        assert!(result.is_ready());
        assert_eq!(read_buf.filled(), &[3]);
    }
756
757 #[tokio::test]
758 async fn test_read_with_limit5() {
759 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
760 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
761 let read_limit = read_limiter.new_limit_session();
762 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
763 let write_limit = write_limiter.new_limit_session();
764 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
765
766 let mut buffer = [0u8; 1];
767 let mut read_buf = ReadBuf::new(&mut buffer);
768 let waker = noop_waker();
769 let mut cx = Context::from_waker(&waker);
770
771 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
772 assert!(result.is_ready());
773 assert_eq!(read_buf.filled(), &[1]);
774 let mut read_buf = ReadBuf::new(&mut buffer);
775 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
776 assert!(result.is_ready());
777 assert_eq!(read_buf.filled(), &[2]);
778
779 let mut read_buf = ReadBuf::new(&mut buffer);
780 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
781 assert!(result.is_pending());
782 tokio::time::sleep(Duration::from_millis(1100)).await;
783 limit_stream.with_lock_raw_stream(|stream| {
784 stream.get_mut().read_should_pending = true;
785 });
786 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
787 assert!(result.is_pending());
788 limit_stream.with_lock_raw_stream(|stream| {
789 stream.get_mut().read_should_pending = false;
790 });
791 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
792 assert!(result.is_ready());
793 assert_eq!(read_buf.filled(), &[3]);
794 limit_stream.with_lock_raw_stream(|stream| {
795 let error = Error::new(ErrorKind::Other, "read error");
796 stream.get_mut().read_error = Some(error);
797 });
798 let mut read_buf = ReadBuf::new(&mut buffer);
799 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
800 assert!(result.is_ready());
801
802 if let Poll::Ready(Err(e)) = result {
803 assert_eq!(e.kind(), ErrorKind::Other);
804 }
805 }
806
807 #[tokio::test]
808 async fn test_read_with_limit6() {
809 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
810 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(2).unwrap()), Some(NonZeroU32::new(1).unwrap()));
811 let read_limit = read_limiter.new_limit_session();
812 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
813 let write_limit = write_limiter.new_limit_session();
814 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
815
816 let mut buffer = [0u8; 1];
817 let mut read_buf = ReadBuf::new(&mut buffer);
818 let waker = noop_waker();
819 let mut cx = Context::from_waker(&waker);
820
821 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
822 assert!(result.is_ready());
823 assert_eq!(read_buf.filled(), &[1]);
824
825 let mut read_buf = ReadBuf::new(&mut buffer);
826 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
827 assert!(result.is_ready());
828 assert_eq!(read_buf.filled(), &[2]);
829
830 let mut read_buf = ReadBuf::new(&mut buffer);
831 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
832 assert!(result.is_pending());
833 tokio::time::sleep(Duration::from_millis(600)).await;
834 limit_stream.with_lock_raw_stream(|stream| {
835 let error = Error::new(ErrorKind::Other, "read error");
836 stream.get_mut().read_error = Some(error);
837 });
838 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
839 assert!(result.is_ready());
840
841 if let Poll::Ready(Err(e)) = result {
842 assert_eq!(e.kind(), ErrorKind::Other);
843 }
844 }
845
846 #[tokio::test]
847 async fn test_write_without_limit() {
848 let mock_stream = MockStream::new(vec![]);
849 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
850 let read_limit = read_limiter.new_limit_session();
851 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
852 let write_limit = write_limiter.new_limit_session();
853 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
854
855 let data = [1, 2, 3, 4, 5];
856 let waker = noop_waker();
857 let mut cx = Context::from_waker(&waker);
858
859 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
861 assert!(result.is_ready());
862
863 if let Poll::Ready(Ok(written)) = result {
864 assert_eq!(written, 5);
865 }
866 }
867
868 #[tokio::test]
869 async fn test_write_without_limit2() {
870 let error = Error::new(ErrorKind::Other, "write error");
871 let mock_stream = MockStream::new(vec![]).with_write_error(error);
872 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
873 let read_limit = read_limiter.new_limit_session();
874 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
875 let write_limit = write_limiter.new_limit_session();
876 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
877
878 let data = [1, 2, 3, 4, 5];
879 let waker = noop_waker();
880 let mut cx = Context::from_waker(&waker);
881
882 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
884 assert!(result.is_ready());
885
886 if let Poll::Ready(ret) = result {
887 assert!(ret.is_err());
888 if let Err(e) = ret {
889 assert_eq!(e.kind(), ErrorKind::Other);
890 }
891 }
892 }
893
894 #[tokio::test]
895 async fn test_write_without_limit3() {
896 let mock_stream = MockStream::new(vec![]).with_write_pending();
897 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
898 let read_limit = read_limiter.new_limit_session();
899 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
900 let write_limit = write_limiter.new_limit_session();
901 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
902
903 let data = [1, 2, 3, 4, 5];
904 let waker = noop_waker();
905 let mut cx = Context::from_waker(&waker);
906
907 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
908 assert!(result.is_pending());
909
910 limit_stream.with_lock_raw_stream(|stream| {
911 stream.get_mut().write_should_pending = false;
912 });
913
914 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
916 assert!(result.is_ready());
917
918 if let Poll::Ready(Ok(written)) = result {
919 assert_eq!(written, 5);
920 }
921 }
922
    /// A write whose inner writer was `Pending` surfaces an error injected before
    /// the retry.
    #[tokio::test]
    async fn test_write_without_limit4() {
        let mock_stream = MockStream::new(vec![]).with_write_pending();
        let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
        let read_limit = read_limiter.new_limit_session();
        let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1024).unwrap()));
        let write_limit = write_limiter.new_limit_session();
        let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);

        let data = [1, 2, 3, 4, 5];
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
        assert!(result.is_pending());

        // Clear the pending flag and queue a one-shot write error.
        limit_stream.with_lock_raw_stream(|stream| {
            let stream = stream.get_mut();
            stream.write_should_pending = false;
            let ererror = Error::new(ErrorKind::Other, "write error");
            stream.write_error = Some(ererror);
        });

        let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
        assert!(result.is_ready());

        if let Poll::Ready(ret) = result {
            assert!(ret.is_err());
            if let Err(e) = ret {
                assert_eq!(e.kind(), ErrorKind::Other);
            }
        }
    }
957
958 #[tokio::test]
959 async fn test_write_with_limit() {
960 let mock_stream = MockStream::new(vec![]);
961 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
962 let read_limit = read_limiter.new_limit_session();
963 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
964 let write_limit = write_limiter.new_limit_session();
965 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
966
967 let data = [1, 2, 3, 4, 5];
968 let waker = noop_waker();
969 let mut cx = Context::from_waker(&waker);
970
971 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
973 assert!(result.is_ready());
975 if let Poll::Ready(Ok(written)) = result {
976 assert_eq!(written, 1);
977 }
978
979 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
981 assert!(result.is_pending());
982 let start = Instant::now();
983 let result = poll_fn(|cx| {
984 Pin::new(&mut limit_stream).poll_write(cx, &data)
985 }).await;
986 assert!(start.elapsed() >= Duration::from_millis(900));
987 assert!(result.is_ok());
988 assert_eq!(result.unwrap(), 1);
989
990 let result = poll_fn(|cx| {
991 Pin::new(&mut limit_stream).poll_write(cx, &data)
992 }).await;
993 assert!(start.elapsed() >= Duration::from_millis(1900));
994 assert!(result.is_ok());
995 assert_eq!(result.unwrap(), 1);
996 }
997
998 #[tokio::test]
999 async fn test_write_with_limit1() {
1000 let mock_stream = MockStream::new(vec![]).with_write_pending();
1001 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1002 let read_limit = read_limiter.new_limit_session();
1003 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1004 let write_limit = write_limiter.new_limit_session();
1005 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1006
1007 let data = [1, 2, 3, 4, 5];
1008 let waker = noop_waker();
1009 let mut cx = Context::from_waker(&waker);
1010
1011 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1012 assert!(result.is_pending());
1014 limit_stream.with_lock_raw_stream(|stream| {
1015 stream.get_mut().write_should_pending = false;
1016 });
1017
1018 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1019 assert!(result.is_ready());
1020 if let Poll::Ready(ret) = result {
1021 assert!(ret.is_ok());
1022 assert_eq!(ret.unwrap(), 1);
1023 }
1024
1025 tokio::time::sleep(Duration::from_millis(1100)).await;
1026 limit_stream.with_lock_raw_stream(|stream| {
1027 stream.get_mut().write_should_pending = true;
1028 });
1029
1030 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1031 assert!(result.is_pending());
1032
1033 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1034 assert!(result.is_pending());
1035
1036 limit_stream.with_lock_raw_stream(|stream| {
1037 let stream = stream.get_mut();
1038 stream.write_should_pending = false;
1039 let ererror = Error::new(ErrorKind::Other, "write error");
1040 stream.write_error = Some(ererror);
1041 });
1042
1043 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1044 assert!(result.is_ready());
1045 if let Poll::Ready(Err(e)) = result {
1046 assert_eq!(e.kind(), ErrorKind::Other);
1047 }
1048 }
1049
1050 #[tokio::test]
1051 async fn test_write_with_limit2() {
1052 let mock_stream = MockStream::new(vec![]).with_write_pending();
1053 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1054 let read_limit = read_limiter.new_limit_session();
1055 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1056 let write_limit = write_limiter.new_limit_session();
1057 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1058
1059 let data = [1];
1060 let waker = noop_waker();
1061 let mut cx = Context::from_waker(&waker);
1062
1063 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1064 assert!(result.is_pending());
1066 limit_stream.with_lock_raw_stream(|stream| {
1067 let stream = stream.get_mut();
1068 stream.write_should_pending = false;
1069 });
1070
1071 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1072 assert!(result.is_ready());
1073 if let Poll::Ready(ret) = result {
1074 assert!(ret.is_ok());
1075 assert_eq!(ret.unwrap(), 1);
1076 }
1077 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1078 assert!(result.is_ready());
1079 if let Poll::Ready(ret) = result {
1080 assert!(ret.is_ok());
1081 assert_eq!(ret.unwrap(), 1);
1082 }
1083 }
1084
1085 #[tokio::test]
1086 async fn test_write_with_limit3() {
1087 let mock_stream = MockStream::new(vec![]);
1088 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1089 let read_limit = read_limiter.new_limit_session();
1090 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1091 let write_limit = write_limiter.new_limit_session();
1092 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1093
1094 let data = [1];
1095 let waker = noop_waker();
1096 let mut cx = Context::from_waker(&waker);
1097
1098 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1099 assert!(result.is_ready());
1100 if let Poll::Ready(ret) = result {
1101 assert!(ret.is_ok());
1102 assert_eq!(ret.unwrap(), 1);
1103 }
1104
1105 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1106 assert!(result.is_ready());
1107 if let Poll::Ready(ret) = result {
1108 assert!(ret.is_ok());
1109 assert_eq!(ret.unwrap(), 1);
1110 }
1111
1112 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1113 assert!(result.is_pending());
1115 limit_stream.with_lock_raw_stream(|stream| {
1116 let stream = stream.get_mut();
1117 stream.write_should_pending = false;
1118 });
1119
1120 tokio::time::sleep(Duration::from_millis(1100)).await;
1121 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1122 assert!(result.is_ready());
1123 if let Poll::Ready(ret) = result {
1124 assert!(ret.is_ok());
1125 assert_eq!(ret.unwrap(), 1);
1126 }
1127 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1128 assert!(result.is_ready());
1129 if let Poll::Ready(ret) = result {
1130 assert!(ret.is_ok());
1131 assert_eq!(ret.unwrap(), 1);
1132 }
1133 }
1134
1135 #[tokio::test]
1136 async fn test_write_with_limit4() {
1137 let mock_stream = MockStream::new(vec![]);
1138 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1139 let read_limit = read_limiter.new_limit_session();
1140 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1141 let write_limit = write_limiter.new_limit_session();
1142 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1143
1144 let data = [1];
1145 let waker = noop_waker();
1146 let mut cx = Context::from_waker(&waker);
1147
1148 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1149 assert!(result.is_ready());
1150 if let Poll::Ready(ret) = result {
1151 assert!(ret.is_ok());
1152 assert_eq!(ret.unwrap(), 1);
1153 }
1154
1155 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1156 assert!(result.is_ready());
1157 if let Poll::Ready(ret) = result {
1158 assert!(ret.is_ok());
1159 assert_eq!(ret.unwrap(), 1);
1160 }
1161
1162 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1163 assert!(result.is_pending());
1164
1165 tokio::time::sleep(Duration::from_millis(1100)).await;
1166 limit_stream.with_lock_raw_stream(|stream| {
1167 let stream = stream.get_mut();
1168 stream.write_should_pending = true;
1169 });
1170 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1171 assert!(result.is_pending());
1172
1173 limit_stream.with_lock_raw_stream(|stream| {
1174 let stream = stream.get_mut();
1175 stream.write_should_pending = false;
1176 let ererror = Error::new(ErrorKind::Other, "write error");
1177 stream.write_error = Some(ererror);
1178 });
1179 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1180 assert!(result.is_ready());
1181 if let Poll::Ready(ret) = result {
1182 assert!(ret.is_err());
1183 }
1184 }
1185
1186 #[tokio::test]
1187 async fn test_write_with_limit5() {
1188 let mock_stream = MockStream::new(vec![]);
1189 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1190 let read_limit = read_limiter.new_limit_session();
1191 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(2).unwrap()));
1192 let write_limit = write_limiter.new_limit_session();
1193 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1194
1195 let data = [1];
1196 let waker = noop_waker();
1197 let mut cx = Context::from_waker(&waker);
1198
1199 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1200 assert!(result.is_ready());
1201 if let Poll::Ready(ret) = result {
1202 assert!(ret.is_ok());
1203 assert_eq!(ret.unwrap(), 1);
1204 }
1205
1206 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1207 assert!(result.is_ready());
1208 if let Poll::Ready(ret) = result {
1209 assert!(ret.is_ok());
1210 assert_eq!(ret.unwrap(), 1);
1211 }
1212
1213 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1214 assert!(result.is_pending());
1215
1216 tokio::time::sleep(Duration::from_millis(1100)).await;
1217
1218 limit_stream.with_lock_raw_stream(|stream| {
1219 let stream = stream.get_mut();
1220 let ererror = Error::new(ErrorKind::Other, "write error");
1221 stream.write_error = Some(ererror);
1222 });
1223
1224 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1225 assert!(result.is_ready());
1226 if let Poll::Ready(ret) = result {
1227 assert!(ret.is_err());
1228 }
1229 }
1230
1231 #[tokio::test]
1232 async fn test_read_error_propagation() {
1233 let error = Error::new(ErrorKind::Other, "read error");
1234 let mock_stream = MockStream::new(vec![]).with_read_error(error);
1235 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1236 let read_limit = read_limiter.new_limit_session();
1237 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1238 let write_limit = write_limiter.new_limit_session();
1239 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1240
1241 let mut buffer = [0u8; 10];
1242 let mut read_buf = ReadBuf::new(&mut buffer);
1243 let waker = noop_waker();
1244 let mut cx = Context::from_waker(&waker);
1245
1246 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1247 assert!(result.is_ready());
1248
1249 if let Poll::Ready(Err(e)) = result {
1250 assert_eq!(e.kind(), ErrorKind::Other);
1251 }
1252 }
1253
1254 #[tokio::test]
1255 async fn test_write_error_propagation() {
1256 let error = Error::new(ErrorKind::Other, "write error");
1257 let mock_stream = MockStream::new(vec![]).with_write_error(error);
1258 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1259 let read_limit = read_limiter.new_limit_session();
1260 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1261 let write_limit = write_limiter.new_limit_session();
1262 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1263
1264 let data = [1, 2, 3, 4, 5];
1265 let waker = noop_waker();
1266 let mut cx = Context::from_waker(&waker);
1267
1268 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1269 assert!(result.is_ready());
1270
1271 if let Poll::Ready(Err(e)) = result {
1272 assert_eq!(e.kind(), ErrorKind::Other);
1273 }
1274 }
1275
1276 #[tokio::test]
1277 async fn test_read_limit_pending_handling() {
1278 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
1279 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1280 let read_limit = read_limiter.new_limit_session();
1281 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1282 let write_limit = write_limiter.new_limit_session();
1283 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1284
1285 let mut buffer = [0u8; 10];
1286 let mut read_buf = ReadBuf::new(&mut buffer);
1287 let waker = noop_waker();
1288 let mut cx = Context::from_waker(&waker);
1289
1290 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1292 assert!(result.is_pending());
1293 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1294 assert!(result.is_pending());
1295 }
1296
1297 #[tokio::test]
1298 async fn test_read_limit_pending_handling2() {
1299 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
1300 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1301 let read_limit = read_limiter.new_limit_session();
1302 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1303 let write_limit = write_limiter.new_limit_session();
1304 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1305
1306 let mut buffer = [0u8; 1];
1307 let mut read_buf = ReadBuf::new(&mut buffer);
1308 let waker = noop_waker();
1309 let mut cx = Context::from_waker(&waker);
1310
1311 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1313 assert!(result.is_ready());
1314 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1315 assert!(result.is_pending());
1316 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1317 assert!(result.is_pending());
1318 }
1319
1320 #[tokio::test]
1321 async fn test_read_pending_handling() {
1322 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]).with_read_pending();
1323 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1324 let read_limit = read_limiter.new_limit_session();
1325 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1326 let write_limit = write_limiter.new_limit_session();
1327 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1328
1329 let mut buffer = [0u8; 10];
1330 let mut read_buf = ReadBuf::new(&mut buffer);
1331 let waker = noop_waker();
1332 let mut cx = Context::from_waker(&waker);
1333
1334 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1336 assert!(result.is_pending());
1337 let result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1338 assert!(result.is_pending());
1339 }
1340
1341 #[tokio::test]
1342 async fn test_write_pending_handling() {
1343 let mock_stream = MockStream::new(vec![]).with_write_pending();
1344 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1345 let read_limit = read_limiter.new_limit_session();
1346 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1347 let write_limit = write_limiter.new_limit_session();
1348 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1349
1350 let data = [1, 2, 3, 4, 5];
1351 let waker = noop_waker();
1352 let mut cx = Context::from_waker(&waker);
1353
1354 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1356 assert!(result.is_pending());
1357 let result = Pin::new(&mut limit_stream).poll_write(&mut cx, &data);
1358 assert!(result.is_pending());
1359 }
1360
1361 #[tokio::test]
1362 async fn test_flush_and_shutdown() {
1363 let mock_stream = MockStream::new(vec![]);
1364 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(1).unwrap()), Some(NonZeroU32::new(1).unwrap()));
1365 let read_limit = read_limiter.new_limit_session();
1366 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::MAX), Some(NonZeroU32::new(1).unwrap()));
1367 let write_limit = write_limiter.new_limit_session();
1368 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1369 let waker = noop_waker();
1370 let mut cx = Context::from_waker(&waker);
1371
1372 let flush_result = Pin::new(&mut limit_stream).poll_flush(&mut cx);
1374 assert!(flush_result.is_ready());
1375
1376 let shutdown_result = Pin::new(&mut limit_stream).poll_shutdown(&mut cx);
1378 assert!(shutdown_result.is_ready());
1379 }
1380
1381 #[tokio::test]
1382 async fn test_mixed_read_write() {
1383 let mock_stream = MockStream::new(vec![1, 2, 3, 4, 5]);
1384 let read_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1385 let read_limit = read_limiter.new_limit_session();
1386 let write_limiter = crate::SpeedLimiter::new(None, Some(NonZeroU32::new(10).unwrap()), Some(NonZeroU32::new(10).unwrap()));
1387 let write_limit = write_limiter.new_limit_session();
1388 let mut limit_stream = LimitStream::new(mock_stream, read_limit, write_limit);
1389
1390 let waker = noop_waker();
1391 let mut cx = Context::from_waker(&waker);
1392
1393 let write_data = [6, 7, 8, 9, 10];
1395 let write_result = Pin::new(&mut limit_stream).poll_write(&mut cx, &write_data);
1396 assert!(write_result.is_ready());
1397
1398 let mut buffer = [0u8; 10];
1400 let mut read_buf = ReadBuf::new(&mut buffer);
1401 let read_result = Pin::new(&mut limit_stream).poll_read(&mut cx, &mut read_buf);
1402 assert!(read_result.is_ready());
1403 }
1404}