1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
//! This crate provides `ThrottledReader`, a proxy-type for `io::Read` that limits how many times
//! the underlying reader can be read from. If the read budget is exceeded,
//! `io::ErrorKind::WouldBlock` is returned instead. This type can be useful to enforce fairness
//! when reading from many (potentially asynchronous) input streams with highly varying load. If
//! one stream always has data available, a worker may continue consuming its input forever,
//! neglecting the other streams.
//!
//! # Examples
//!
//! ```
//! # use std::io;
//! # use std::io::prelude::*;
//! # use throttled_reader::ThrottledReader;
//! let mut buf = [0];
//! let mut stream = ThrottledReader::new(io::empty());
//!
//! // initially no limit
//! assert!(stream.read(&mut buf).is_ok());
//! assert!(stream.read(&mut buf).is_ok());
//!
//! // set a limit
//! stream.set_limit(2);
//! assert!(stream.read(&mut buf).is_ok()); // first is allowed through
//! assert!(stream.read(&mut buf).is_ok()); // second is also allowed through
//! // but now the limit is reached, and the underlying stream is no longer accessible
//! assert_eq!(
//!     stream.read(&mut buf).unwrap_err().kind(),
//!     io::ErrorKind::WouldBlock
//! );
//!
//! // we can then unthrottle it again after checking other streams
//! stream.unthrottle();
//! assert!(stream.read(&mut buf).is_ok());
//! assert!(stream.read(&mut buf).is_ok());
//! ```
#![deny(missing_docs)]
use std::io;

/// `ThrottledReader` proxies an `io::Read`, but enforces a budget on how many `read` calls can be
/// made to the underlying reader.
///
/// While a budget is set, every `read` that is passed through decrements it by one. Once it
/// reaches zero, `read` returns an `io::ErrorKind::WouldBlock` error without touching the
/// underlying reader.
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ThrottledReader<R> {
    /// The wrapped reader that permitted `read` calls are forwarded to.
    reader: R,
    /// Number of `read` calls still allowed through; `None` means no limit is in effect.
    read_budget: Option<usize>,
}

impl<R> ThrottledReader<R> {
    /// Construct a new throttler that wraps the given reader.
    ///
    /// The new `ThrottledReader` initially has no limit.
    pub fn new(reader: R) -> Self {
        ThrottledReader {
            reader,
            read_budget: None,
        }
    }

    /// Set the number of `read` calls allowed to the underlying reader.
    pub fn set_limit(&mut self, limit: usize) {
        self.read_budget = Some(limit);
    }

    /// Remove the limit on how many `read` calls can be issued to the underlying reader.
    pub fn unthrottle(&mut self) {
        self.read_budget = None;
    }

    /// Check how many more `read` calls may be issued to the underlying reader.
    ///
    /// Returns `None` if the reader is not currently throttled.
    pub fn remaining(&self) -> Option<usize> {
        self.read_budget
    }

    /// Extract the underlying reader.
    pub fn into_inner(self) -> R {
        self.reader
    }
}

impl<R: io::Read> io::Read for ThrottledReader<R> {
    /// Forward `read` to the underlying reader, charging one unit of the budget if one is set.
    ///
    /// With no budget in effect the call is passed straight through. With an exhausted budget
    /// the underlying reader is not touched and an `io::ErrorKind::WouldBlock` error is
    /// returned instead.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.read_budget {
            // budget used up: refuse without consulting the inner reader
            Some(0) => {
                return Err(io::Error::new(io::ErrorKind::WouldBlock, "read throttled"));
            }
            // still within budget: pay for this call up front
            Some(left) => self.read_budget = Some(left - 1),
            // unthrottled: nothing to account for
            None => {}
        }

        self.reader.read(buf)
    }
}

impl<R> From<R> for ThrottledReader<R> {
    fn from(reader: R) -> Self {
        ThrottledReader::new(reader)
    }
}

impl<R> Default for ThrottledReader<R>
where
    R: Default,
{
    fn default() -> Self {
        ThrottledReader {
            reader: R::default(),
            read_budget: None,
        }
    }
}

use std::ops::{Deref, DerefMut};
impl<R> Deref for ThrottledReader<R> {
    type Target = R;

    /// Borrow the wrapped reader.
    fn deref(&self) -> &R {
        let Self { reader, .. } = self;
        reader
    }
}

impl<R> DerefMut for ThrottledReader<R> {
    /// Mutably borrow the wrapped reader.
    fn deref_mut(&mut self) -> &mut R {
        let Self { reader, .. } = self;
        reader
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::prelude::*;

    /// Issue one read that must reach the inner reader; `io::empty()` always yields 0 bytes.
    fn read_passes(s: &mut ThrottledReader<io::Empty>) {
        let mut byte = [0];
        assert_eq!(s.read(&mut byte).unwrap(), 0);
    }

    /// Issue one read that must be rejected with `WouldBlock`, leaving the budget at zero.
    fn read_is_throttled(s: &mut ThrottledReader<io::Empty>) {
        let mut byte = [0];
        let kind = s.read(&mut byte).unwrap_err().kind();
        assert_eq!(kind, io::ErrorKind::WouldBlock);
        assert_eq!(s.remaining(), Some(0));
    }

    #[test]
    fn it_works() {
        let mut s = ThrottledReader::new(io::empty());

        // no budget has been set yet, so nothing is throttled
        for _ in 0..3 {
            read_passes(&mut s);
        }

        // a budget of two admits exactly two reads, counting down as it goes
        s.set_limit(2);
        for left in (0..2).rev() {
            read_passes(&mut s);
            assert_eq!(s.remaining(), Some(left));
        }

        // the third and fourth reads are both refused
        for _ in 0..2 {
            read_is_throttled(&mut s);
        }

        // lifting the budget lets reads through again
        s.unthrottle();
        for _ in 0..3 {
            read_passes(&mut s);
        }
    }
}