Skip to main content

tor_rtcompat/
general.rs

1//! Support for streams and listeners on `general::SocketAddr`.
2
3use async_trait::async_trait;
4use futures::{AsyncRead, AsyncWrite, StreamExt as _, stream};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
6use std::net;
7use std::task::Poll;
8use std::{pin::Pin, task::Context};
9use tor_general_addr::unix;
10use tracing::instrument;
11
12use crate::{NetStreamListener, NetStreamProvider, StreamOps};
13use tor_general_addr::general;
14
15pub use general::{AddrParseError, SocketAddr};
16
17/// Helper trait to allow us to create a type-erased stream.
18///
19/// (Rust doesn't allow "dyn AsyncRead + AsyncWrite")
20trait ReadAndWrite: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
21impl<T> ReadAndWrite for T where T: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
22
23/// A stream returned by a `NetStreamProvider<GeneralizedAddr>`
24pub struct Stream(Pin<Box<dyn ReadAndWrite>>);
25impl AsyncRead for Stream {
26    fn poll_read(
27        mut self: Pin<&mut Self>,
28        cx: &mut Context<'_>,
29        buf: &mut [u8],
30    ) -> Poll<IoResult<usize>> {
31        self.0.as_mut().poll_read(cx, buf)
32    }
33}
34impl AsyncWrite for Stream {
35    fn poll_write(
36        mut self: Pin<&mut Self>,
37        cx: &mut Context<'_>,
38        buf: &[u8],
39    ) -> Poll<IoResult<usize>> {
40        self.0.as_mut().poll_write(cx, buf)
41    }
42
43    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
44        self.0.as_mut().poll_flush(cx)
45    }
46
47    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
48        self.0.as_mut().poll_close(cx)
49    }
50}
51
52impl StreamOps for Stream {
53    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
54        self.0.set_tcp_notsent_lowat(notsent_lowat)
55    }
56
57    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
58        self.0.new_handle()
59    }
60}
61
62/// The type of the result from an [`IncomingStreams`].
63type StreamItem = IoResult<(Stream, general::SocketAddr)>;
64
65/// A stream of incoming connections on a [`general::Listener`](Listener).
66pub struct IncomingStreams(Pin<Box<dyn stream::Stream<Item = StreamItem> + Send + Sync>>);
67
68impl stream::Stream for IncomingStreams {
69    type Item = IoResult<(Stream, general::SocketAddr)>;
70
71    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
72        self.0.as_mut().poll_next(cx)
73    }
74}
75
76/// A listener returned by a `NetStreamProvider<general::SocketAddr>`.
77pub struct Listener {
78    /// The `futures::Stream` of incoming network streams.
79    streams: IncomingStreams,
80    /// The local address on which we're listening.
81    local_addr: general::SocketAddr,
82}
83
84impl NetStreamListener<general::SocketAddr> for Listener {
85    type Stream = Stream;
86    type Incoming = IncomingStreams;
87
88    fn incoming(self) -> IncomingStreams {
89        self.streams
90    }
91
92    fn local_addr(&self) -> IoResult<general::SocketAddr> {
93        Ok(self.local_addr.clone())
94    }
95}
96
97/// Use `provider` to launch a `NetStreamListener` at `address`, and wrap that listener
98/// as a `Listener`.
99async fn abstract_listener_on<ADDR, P>(
100    provider: &P,
101    address: &ADDR,
102    options: &P::ListenOptions,
103) -> IoResult<Listener>
104where
105    P: NetStreamProvider<ADDR>,
106    general::SocketAddr: From<ADDR>,
107{
108    let lis = provider.listen(address, options).await?;
109    let local_addr = general::SocketAddr::from(lis.local_addr()?);
110    let streams = lis.incoming().map(|result| {
111        result.map(|(socket, addr)| (Stream(Box::pin(socket)), general::SocketAddr::from(addr)))
112    });
113    let streams = IncomingStreams(Box::pin(streams));
114    Ok(Listener {
115        streams,
116        local_addr,
117    })
118}
119
120#[async_trait]
121impl<T> NetStreamProvider<general::SocketAddr> for T
122where
123    T: NetStreamProvider<net::SocketAddr> + NetStreamProvider<unix::SocketAddr>,
124{
125    type Stream = Stream;
126    type Listener = Listener;
127    // TODO: If unix sockets ever support `CommonConnectOptions`,
128    // we could accept these common options and convert to the appropriate type.
129    type ConnectOptions = ();
130    // TODO: If unix sockets ever support `CommonListenOptions`,
131    // we could accept these common options and convert to the appropriate type.
132    type ListenOptions = ();
133
134    #[instrument(skip_all, level = "trace")]
135    async fn connect(
136        &self,
137        addr: &general::SocketAddr,
138        (): &Self::ConnectOptions,
139    ) -> IoResult<Stream> {
140        use general::SocketAddr as G;
141        match addr {
142            G::Inet(a) => {
143                let options = Default::default();
144                Ok(Stream(Box::pin(self.connect(a, &options).await?)))
145            }
146            G::Unix(a) => {
147                let options = Default::default();
148                Ok(Stream(Box::pin(self.connect(a, &options).await?)))
149            }
150            other => Err(IoError::new(
151                IoErrorKind::InvalidInput,
152                UnsupportedAddress(other.clone()),
153            )),
154        }
155    }
156    async fn listen(
157        &self,
158        addr: &general::SocketAddr,
159        (): &Self::ListenOptions,
160    ) -> IoResult<Listener> {
161        use general::SocketAddr as G;
162        match addr {
163            G::Inet(a) => abstract_listener_on(self, a, &Default::default()).await,
164            G::Unix(a) => abstract_listener_on(self, a, &Default::default()).await,
165            other => Err(IoError::new(
166                IoErrorKind::InvalidInput,
167                UnsupportedAddress(other.clone()),
168            )),
169        }
170    }
171}
172
173/// Tried to use a [`general::SocketAddr`] that `tor-rtcompat` didn't understand.
174#[derive(Clone, Debug, thiserror::Error)]
175#[error("Socket address {0:?} is not supported by tor-rtcompat")]
176pub struct UnsupportedAddress(general::SocketAddr);