// tor_rtcompat/impls/smol.rs

//! Re-exports of the smol runtime for use with arti.
//!
//! This crate defines a slim API around our async runtime so that we
//! can swap it out easily; this module provides the smol-based
//! implementation of that API.
4
5/// Types used for networking (smol implementation).
6pub(crate) mod net {
7    use super::SmolRuntime;
8    use crate::{impls, traits};
9    use async_trait::async_trait;
10    use futures::future::Future;
11    use futures::stream::Stream;
12    use paste::paste;
13    use smol::Async;
14    #[cfg(unix)]
15    use smol::net::unix::{UnixListener, UnixStream};
16    use smol::net::{TcpListener, TcpStream, UdpSocket as SmolUdpSocket};
17    use std::io::Result as IoResult;
18    use std::net::SocketAddr;
19    use std::pin::Pin;
20    use std::task::{Context, Poll};
21    use tor_general_addr::unix;
22    use tracing::instrument;
23
24    /// Provide wrapper for different stream types
25    /// (e.g async_net::TcpStream and async_net::unix::UnixStream).
26    macro_rules! impl_stream {
27        { $kind:ident, $addr:ty } => { paste! {
28
29            /// A `Stream` of incoming streams.
30            pub struct [<Incoming $kind Streams>] {
31                /// A state object, stored in an Option so we can take ownership of it
32                // TODO: Currently this is an Option so we can take ownership of it using `.take()`.
33                // Check if this approach can be improved once supporting Rust 2024.
34                state: Option<[<Incoming $kind StreamsState>]>,
35            }
36
37            /// The result type returned by `take_and_poll_*`.
38            type [<$kind FResult>] = (IoResult<([<$kind Stream>], $addr)>, [<$kind Listener>]);
39
40            /// Helper to implement `Incoming*Streams`
41            async fn [<take_and_poll_ $kind:lower>](lis: [<$kind Listener>]) -> [<$kind FResult>] {
42                let result = lis.accept().await;
43                (result, lis)
44            }
45
46            /// The possible states for an `Incoming*Streams`.
47            enum [<Incoming $kind StreamsState>] {
48                /// We're ready to call `accept` on the listener again.
49                Ready([<$kind Listener>]),
50
51                /// We've called `accept` on the listener, and we're waiting
52                Accepting(Pin<Box<dyn Future<Output = [<$kind FResult>]> + Send + Sync>>),
53            }
54
55            impl [<Incoming $kind Streams>] {
56                /// Create a new `Incoming*Streams` from a listener.
57                pub fn from_listener(lis: [<$kind Listener>]) -> Self {
58                    Self { state: Some([<Incoming $kind StreamsState>]::Ready(lis)) }
59                }
60            }
61
62            impl Stream for [<Incoming $kind Streams>] {
63                type Item = IoResult<([<$kind Stream>], $addr)>;
64
65                fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
66                    use [<Incoming $kind StreamsState>] as St;
67                    let state = self.state.take().expect("No valid state!");
68                    let mut future = match state {
69                        St::Ready(lis) => Box::pin([<take_and_poll_ $kind:lower>](lis)),
70                        St::Accepting(fut) => fut,
71                    };
72                    match future.as_mut().poll(cx) {
73                        Poll::Ready((val, lis)) => {
74                            self.state = Some(St::Ready(lis));
75                            Poll::Ready(Some(val))
76                        }
77                        Poll::Pending => {
78                            self.state = Some(St::Accepting(future));
79                            Poll::Pending
80                        }
81                    }
82                }
83            }
84
85            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
86                type Stream = [<$kind Stream>];
87                type Incoming = [<Incoming $kind Streams>];
88
89                fn incoming(self) -> Self::Incoming {
90                    [<Incoming $kind Streams>]::from_listener(self)
91                }
92
93                fn local_addr(&self) -> IoResult<$addr> {
94                    [<$kind Listener>]::local_addr(self)
95                }
96            }
97        }}
98    }
99
100    impl_stream! { Tcp, SocketAddr }
101    #[cfg(unix)]
102    impl_stream! { Unix, unix::SocketAddr }
103
104    #[async_trait]
105    impl traits::NetStreamProvider<SocketAddr> for SmolRuntime {
106        type Stream = TcpStream;
107        type Listener = TcpListener;
108
109        #[instrument(skip_all, level = "trace")]
110        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
111            TcpStream::connect(addr).await
112        }
113
114        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
115            // Use an implementation that's the same across all runtimes.
116            // The socket is already non-blocking, so `Async` doesn't need to set as non-blocking
117            // again. If it *were* to be blocking, then I/O operations would block in async
118            // contexts, which would lead to deadlocks.
119            Ok(Async::new_nonblocking(impls::tcp_listen(addr)?)?.into())
120        }
121    }
122
123    #[cfg(unix)]
124    #[async_trait]
125    impl traits::NetStreamProvider<unix::SocketAddr> for SmolRuntime {
126        type Stream = UnixStream;
127        type Listener = UnixListener;
128
129        #[instrument(skip_all, level = "trace")]
130        async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
131            let path = addr
132                .as_pathname()
133                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
134            UnixStream::connect(path).await
135        }
136
137        async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
138            let path = addr
139                .as_pathname()
140                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
141            UnixListener::bind(path)
142        }
143    }
144
145    #[cfg(not(unix))]
146    crate::impls::impl_unix_non_provider! { SmolRuntime }
147
148    #[async_trait]
149    impl traits::UdpProvider for SmolRuntime {
150        type UdpSocket = UdpSocket;
151
152        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
153            SmolUdpSocket::bind(addr)
154                .await
155                .map(|socket| UdpSocket { socket })
156        }
157    }
158
159    /// Wrapper for `SmolUdpSocket`.
160    // Required to implement `traits::UdpSocket`.
161    pub struct UdpSocket {
162        /// The underlying socket.
163        socket: SmolUdpSocket,
164    }
165
166    #[async_trait]
167    impl traits::UdpSocket for UdpSocket {
168        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
169            self.socket.recv_from(buf).await
170        }
171
172        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
173            self.socket.send_to(buf, target).await
174        }
175
176        fn local_addr(&self) -> IoResult<SocketAddr> {
177            self.socket.local_addr()
178        }
179    }
180
181    impl traits::StreamOps for TcpStream {
182        fn set_tcp_notsent_lowat(&self, lowat: u32) -> IoResult<()> {
183            impls::streamops::set_tcp_notsent_lowat(self, lowat)
184        }
185
186        #[cfg(target_os = "linux")]
187        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
188            Box::new(impls::streamops::TcpSockFd::from_fd(self))
189        }
190    }
191
192    #[cfg(unix)]
193    impl traits::StreamOps for UnixStream {
194        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
195            Err(traits::UnsupportedStreamOp::new(
196                "set_tcp_notsent_lowat",
197                "unsupported on Unix streams",
198            )
199            .into())
200        }
201    }
202}
203
204// ==============================
205
206use crate::traits::*;
207use futures::task::{FutureObj, Spawn, SpawnError};
208use futures::{Future, FutureExt};
209use std::pin::Pin;
210use std::time::Duration;
211
212/// Type to wrap `smol::Executor`.
213#[derive(Clone)]
214pub struct SmolRuntime {
215    /// Instance of the smol executor we own.
216    executor: std::sync::Arc<smol::Executor<'static>>,
217}
218
219/// Construct new instance of the smol runtime.
220//
221// TODO: Make SmolRuntime multi-threaded.
222pub fn create_runtime() -> SmolRuntime {
223    SmolRuntime {
224        executor: std::sync::Arc::new(smol::Executor::new()),
225    }
226}
227
228impl SleepProvider for SmolRuntime {
229    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
230    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
231        Box::pin(async_io::Timer::after(duration).map(|_| ()))
232    }
233}
234
235impl ToplevelBlockOn for SmolRuntime {
236    fn block_on<F: Future>(&self, f: F) -> F::Output {
237        smol::block_on(self.executor.run(f))
238    }
239}
240
241impl Blocking for SmolRuntime {
242    type ThreadHandle<T: Send + 'static> = blocking::Task<T>;
243
244    fn spawn_blocking<F, T>(&self, f: F) -> blocking::Task<T>
245    where
246        F: FnOnce() -> T + Send + 'static,
247        T: Send + 'static,
248    {
249        smol::unblock(f)
250    }
251
252    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
253        smol::block_on(self.executor.run(f))
254    }
255}
256
257impl Spawn for SmolRuntime {
258    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
259        self.executor.spawn(future).detach();
260        Ok(())
261    }
262}