1#![forbid(unsafe_code)] use crate::util::mpsc_channel;
9use futures::channel::mpsc;
10use futures::io::{AsyncRead, AsyncWrite};
11use futures::sink::{Sink, SinkExt};
12use futures::stream::Stream;
13use std::io::{Error as IoError, ErrorKind, Result as IoResult};
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use tor_rtcompat::{StreamOps, UnsupportedStreamOp};
17
/// Capacity passed to `mpsc_channel` for each direction of a stream pair
/// created by [`stream_pair`].
const CAPACITY: usize = 4;

/// Largest number of bytes that a single `poll_write` on a [`LocalStream`]
/// will accept; longer buffers are only partially written and the caller
/// must write the remainder.
const CHUNKSZ: usize = 213;
28
29pub fn stream_pair() -> (LocalStream, LocalStream) {
38 let (w1, r2) = mpsc_channel(CAPACITY);
39 let (w2, r1) = mpsc_channel(CAPACITY);
40 let s1 = LocalStream {
41 w: w1,
42 r: r1,
43 pending_bytes: Vec::new(),
44 tls_cert: None,
45 };
46 let s2 = LocalStream {
47 w: w2,
48 r: r2,
49 pending_bytes: Vec::new(),
50 tls_cert: None,
51 };
52 (s1, s2)
53}
54
/// One half of an in-memory, bidirectional byte stream created by [`stream_pair`].
///
/// Data travels as chunks of bytes over two channels: this stream's `w` feeds
/// the peer's `r`, and the peer's `w` feeds this stream's `r`.
pub struct LocalStream {
    /// Sending side of the channel that the peer stream reads from.
    w: mpsc::Sender<IoResult<Vec<u8>>>,
    /// Receiving side of the channel that the peer stream writes to.
    r: mpsc::Receiver<IoResult<Vec<u8>>>,
    /// Bytes already received from `r` but not yet returned by `poll_read`.
    pending_bytes: Vec<u8>,
    /// When `Some`, plain reads and writes on this stream fail with an error.
    /// NOTE(review): presumably set for streams that must be accessed through a
    /// fake-TLS wrapper — confirm with the code that sets this field.
    pub(crate) tls_cert: Option<Vec<u8>>,
}
84
/// Move as many bytes as fit from the front of `pending_bytes` into `buf`.
///
/// Copied bytes are removed from `pending_bytes`; anything that does not fit
/// stays queued for a later call. Returns the number of bytes copied, which is
/// the smaller of the two lengths.
fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
    let count = buf.len().min(pending_bytes.len());
    buf.iter_mut()
        .zip(pending_bytes.drain(..count))
        .for_each(|(dst, byte)| *dst = byte);
    count
}
93
94impl AsyncRead for LocalStream {
95 fn poll_read(
96 mut self: Pin<&mut Self>,
97 cx: &mut Context<'_>,
98 buf: &mut [u8],
99 ) -> Poll<IoResult<usize>> {
100 if buf.is_empty() {
101 return Poll::Ready(Ok(0));
102 }
103 if self.tls_cert.is_some() {
104 return Poll::Ready(Err(std::io::Error::new(
105 std::io::ErrorKind::Other,
106 "attempted to treat a TLS stream as non-TLS!",
107 )));
108 }
109 if !self.pending_bytes.is_empty() {
110 return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
111 }
112
113 match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
114 Some(Err(e)) => Poll::Ready(Err(e)),
115 Some(Ok(bytes)) => {
116 self.pending_bytes = bytes;
117 let n = drain_helper(buf, &mut self.pending_bytes);
118 Poll::Ready(Ok(n))
119 }
120 None => Poll::Ready(Ok(0)), }
122 }
123}
124
125impl AsyncWrite for LocalStream {
126 fn poll_write(
127 mut self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 buf: &[u8],
130 ) -> Poll<IoResult<usize>> {
131 if self.tls_cert.is_some() {
132 return Poll::Ready(Err(std::io::Error::new(
133 std::io::ErrorKind::Other,
134 "attempted to treat a TLS stream as non-TLS!",
135 )));
136 }
137
138 match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
139 Ok(()) => (),
140 Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
141 }
142
143 let buf = if buf.len() > CHUNKSZ {
144 &buf[..CHUNKSZ]
145 } else {
146 buf
147 };
148 let len = buf.len();
149 match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) {
150 Ok(()) => Poll::Ready(Ok(len)),
151 Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
152 }
153 }
154 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
155 Pin::new(&mut self.w)
156 .poll_flush(cx)
157 .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
158 }
159 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
160 Pin::new(&mut self.w)
161 .poll_close(cx)
162 .map_err(|e| IoError::new(ErrorKind::Other, e))
163 }
164}
165
166impl StreamOps for LocalStream {
167 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
168 Err(
169 UnsupportedStreamOp::new("set_tcp_notsent_lowat", "unsupported on local streams")
170 .into(),
171 )
172 }
173}
174
/// Placeholder error payload carried inside the I/O errors injected by
/// [`LocalStream::send_err`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
185
186impl LocalStream {
187 pub async fn send_err(&mut self, kind: ErrorKind) {
192 let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
193 }
194}
195
#[cfg(all(test, not(miri)))]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::Rng;
    use tor_basic_utils::test_rng::testing_rng;

    /// Write ten copies of a random 9999-byte buffer and close; the peer,
    /// reading 33 bytes at a time until EOF, must receive exactly those ten
    /// copies.
    #[async_test]
    async fn basic_rw() {
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        testing_rng().fill(&mut text1[..]);

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u8..10 {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // Without this, the per-copy comparison below would pass
                // vacuously if EOF arrived early at a copy boundary (or at once).
                assert_eq!(text2.len(), text1.len() * 10);
                for ch in text2[..].chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );

        v1.unwrap();
        v2.unwrap();
    }

    /// An error injected with `send_err` after some data reaches the reader
    /// with the requested kind and a `SyntheticError` payload.
    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                }
                Ok(())
            }
        );

        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    /// Dropping the reading half makes the writer's writes fail with
    /// `BrokenPipe`.
    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );

        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}