tor_rtmock/
io.rs

1//! Mocking helpers for testing with futures::io types.
2//!
3//! Note that some of this code might be of general use, but for now
4//! we're only trying it for testing.
5
6#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
7
8use crate::util::mpsc_channel;
9use futures::channel::mpsc;
10use futures::io::{AsyncRead, AsyncWrite};
11use futures::sink::{Sink, SinkExt};
12use futures::stream::Stream;
13use std::io::{Error as IoError, ErrorKind, Result as IoResult};
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use tor_rtcompat::{StreamOps, UnsupportedStreamOp};
17
/// Channel capacity for our internal MPSC channels.
///
/// We keep this intentionally low to make sure that some blocking
/// will occur.
const CAPACITY: usize = 4;

/// Maximum number of bytes that a single write will queue as one chunk
/// on the channel to the linked stream.
///
/// This size is deliberately weird, to try to find errors.
const CHUNKSZ: usize = 213;
28
29/// Construct a new pair of linked LocalStream objects.
30///
31/// Any bytes written to one will be readable on the other, and vice
32/// versa.  These streams will behave more or less like a socketpair,
33/// except without actually going through the operating system.
34///
35/// Note that this implementation is intended for testing only, and
36/// isn't optimized.
37pub fn stream_pair() -> (LocalStream, LocalStream) {
38    let (w1, r2) = mpsc_channel(CAPACITY);
39    let (w2, r1) = mpsc_channel(CAPACITY);
40    let s1 = LocalStream {
41        w: w1,
42        r: r1,
43        pending_bytes: Vec::new(),
44        tls_cert: None,
45    };
46    let s2 = LocalStream {
47        w: w2,
48        r: r2,
49        pending_bytes: Vec::new(),
50        tls_cert: None,
51    };
52    (s1, s2)
53}
54
55/// One half of a pair of linked streams returned by [`stream_pair`].
56//
57// Implementation notes: linked streams are made out a pair of mpsc
58// channels.  There's one channel for sending bytes in each direction.
59// Bytes are sent as IoResult<Vec<u8>>: sending an error causes an error
60// to occur on the other side.
61pub struct LocalStream {
62    /// The writing side of the channel that we use to implement this
63    /// stream.
64    ///
65    /// The reading side is held by the other linked stream.
66    w: mpsc::Sender<IoResult<Vec<u8>>>,
67    /// The reading side of the channel that we use to implement this
68    /// stream.
69    ///
70    /// The writing side is held by the other linked stream.
71    r: mpsc::Receiver<IoResult<Vec<u8>>>,
72    /// Bytes that we have read from `r` but not yet delivered.
73    pending_bytes: Vec<u8>,
74    /// Data about the other side of this stream's fake TLS certificate, if any.
75    /// If this is present, I/O operations will fail with an error.
76    ///
77    /// How this is intended to work: things that return `LocalStream`s that could potentially
78    /// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper
79    /// type would clear this field (after checking its contents are as expected).
80    ///
81    /// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise
82    pub(crate) tls_cert: Option<Vec<u8>>,
83}
84
/// Helper: move bytes off the front of `pending_bytes` and into `buf`.
///
/// Copies `min(buf.len(), pending_bytes.len())` bytes into the start of
/// `buf`, removes those bytes from the front of `pending_bytes`, and
/// returns how many bytes were moved.
fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
    let n_to_drain = buf.len().min(pending_bytes.len());
    buf[..n_to_drain].copy_from_slice(&pending_bytes[..n_to_drain]);
    // Discard the delivered prefix; the remainder shifts to the front.
    pending_bytes.drain(..n_to_drain);
    n_to_drain
}
93
94impl AsyncRead for LocalStream {
95    fn poll_read(
96        mut self: Pin<&mut Self>,
97        cx: &mut Context<'_>,
98        buf: &mut [u8],
99    ) -> Poll<IoResult<usize>> {
100        if buf.is_empty() {
101            return Poll::Ready(Ok(0));
102        }
103        if self.tls_cert.is_some() {
104            return Poll::Ready(Err(std::io::Error::new(
105                std::io::ErrorKind::Other,
106                "attempted to treat a TLS stream as non-TLS!",
107            )));
108        }
109        if !self.pending_bytes.is_empty() {
110            return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
111        }
112
113        match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
114            Some(Err(e)) => Poll::Ready(Err(e)),
115            Some(Ok(bytes)) => {
116                self.pending_bytes = bytes;
117                let n = drain_helper(buf, &mut self.pending_bytes);
118                Poll::Ready(Ok(n))
119            }
120            None => Poll::Ready(Ok(0)), // This is an EOF
121        }
122    }
123}
124
125impl AsyncWrite for LocalStream {
126    fn poll_write(
127        mut self: Pin<&mut Self>,
128        cx: &mut Context<'_>,
129        buf: &[u8],
130    ) -> Poll<IoResult<usize>> {
131        if self.tls_cert.is_some() {
132            return Poll::Ready(Err(std::io::Error::new(
133                std::io::ErrorKind::Other,
134                "attempted to treat a TLS stream as non-TLS!",
135            )));
136        }
137
138        match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
139            Ok(()) => (),
140            Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
141        }
142
143        let buf = if buf.len() > CHUNKSZ {
144            &buf[..CHUNKSZ]
145        } else {
146            buf
147        };
148        let len = buf.len();
149        match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) {
150            Ok(()) => Poll::Ready(Ok(len)),
151            Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
152        }
153    }
154    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
155        Pin::new(&mut self.w)
156            .poll_flush(cx)
157            .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
158    }
159    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
160        Pin::new(&mut self.w)
161            .poll_close(cx)
162            .map_err(|e| IoError::new(ErrorKind::Other, e))
163    }
164}
165
166impl StreamOps for LocalStream {
167    fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
168        Err(
169            UnsupportedStreamOp::new("set_tcp_notsent_lowat", "unsupported on local streams")
170                .into(),
171        )
172    }
173}
174
/// The error wrapped inside the I/O errors produced by [`LocalStream::send_err`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
185
186impl LocalStream {
187    /// Send an error to the other linked local stream.
188    ///
189    /// When the other stream reads this message, it will generate a
190    /// [`std::io::Error`] with the provided `ErrorKind`.
191    pub async fn send_err(&mut self, kind: ErrorKind) {
192        let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
193    }
194}
195
#[cfg(all(test, not(miri)))] // These tests are very slow under miri
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::Rng;
    use tor_basic_utils::test_rng::testing_rng;

    #[async_test]
    async fn basic_rw() {
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        testing_rng().fill(&mut text1[..]);
        let n_writes: usize = 10;

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0..n_writes {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // Every byte must arrive: without this, a truncated (or
                // empty) stream would pass the chunk comparison below.
                assert_eq!(text2.len(), text1.len() * n_writes);
                for ch in text2[..].chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );

        v1.unwrap();
        v2.unwrap();
    }

    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut received: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    match s2.read(&mut buf[..]).await {
                        Ok(0) => break,
                        Ok(n) => received.extend(&buf[..n]),
                        Err(e) => {
                            // Data written before the error must still be delivered.
                            assert_eq!(&received[..], b"hello world");
                            return Err(e);
                        }
                    }
                }
                Ok(())
            }
        );

        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );

        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}