Skip to main content

tor_rtcompat/impls/
smol.rs

1//! Re-exports of the smol runtime for use with arti.
2//! This crate defines a slim API around our async runtime so that we
3//! can swap it out easily.
4
5/// Types used for networking (smol implementation).
6pub(crate) mod net {
7    use super::SmolRuntime;
8    use crate::{impls, traits};
9    use async_trait::async_trait;
10    use futures::stream::{self, Stream};
11    use paste::paste;
12    use smol::Async;
13    #[cfg(unix)]
14    use smol::net::unix::{UnixListener, UnixStream};
15    use smol::net::{TcpListener, TcpStream, UdpSocket as SmolUdpSocket};
16    use std::io::Result as IoResult;
17    use std::net::SocketAddr;
18    use std::pin::Pin;
19    use std::task::{Context, Poll};
20    use tor_general_addr::unix;
21    use tracing::instrument;
22
23    /// Provide wrapper for different stream types
24    /// (e.g async_net::TcpStream and async_net::unix::UnixStream).
25    macro_rules! impl_stream {
26        { $kind:ident, $addr:ty } => { paste! {
27
28            /// A `Stream` of incoming streams.
29            pub struct [<Incoming $kind Streams>] {
30                /// Underlying stream of incoming connections.
31                inner: Pin<Box<dyn Stream<Item = IoResult<([<$kind Stream>], $addr)>> + Send + Sync>>,
32            }
33
34            impl [<Incoming $kind Streams>] {
35                /// Create a new `Incoming*Streams` from a listener.
36                pub fn from_listener(lis: [<$kind Listener>]) -> Self {
37                    let stream = stream::unfold(lis, |lis| async move {
38                        let result = lis.accept().await;
39                        Some((result, lis))
40                    });
41                    Self {
42                        inner: Box::pin(stream),
43                    }
44                }
45            }
46
47            impl Stream for [<Incoming $kind Streams>] {
48                type Item = IoResult<([<$kind Stream>], $addr)>;
49
50                fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
51                    self.inner.as_mut().poll_next(cx)
52                }
53            }
54
55            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
56                type Stream = [<$kind Stream>];
57                type Incoming = [<Incoming $kind Streams>];
58
59                fn incoming(self) -> Self::Incoming {
60                    [<Incoming $kind Streams>]::from_listener(self)
61                }
62
63                fn local_addr(&self) -> IoResult<$addr> {
64                    [<$kind Listener>]::local_addr(self)
65                }
66            }
67        }}
68    }
69
70    impl_stream! { Tcp, SocketAddr }
71    #[cfg(unix)]
72    impl_stream! { Unix, unix::SocketAddr }
73
74    #[async_trait]
75    impl traits::NetStreamProvider<SocketAddr> for SmolRuntime {
76        type Stream = TcpStream;
77        type Listener = TcpListener;
78
79        #[instrument(skip_all, level = "trace")]
80        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
81            TcpStream::connect(addr).await
82        }
83
84        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
85            // Use an implementation that's the same across all runtimes.
86            // The socket is already non-blocking, so `Async` doesn't need to set as non-blocking
87            // again. If it *were* to be blocking, then I/O operations would block in async
88            // contexts, which would lead to deadlocks.
89            Ok(Async::new_nonblocking(impls::tcp_listen(addr)?)?.into())
90        }
91    }
92
93    #[cfg(unix)]
94    #[async_trait]
95    impl traits::NetStreamProvider<unix::SocketAddr> for SmolRuntime {
96        type Stream = UnixStream;
97        type Listener = UnixListener;
98
99        #[instrument(skip_all, level = "trace")]
100        async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
101            let path = addr
102                .as_pathname()
103                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
104            UnixStream::connect(path).await
105        }
106
107        async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
108            let path = addr
109                .as_pathname()
110                .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
111            UnixListener::bind(path)
112        }
113    }
114
115    #[cfg(not(unix))]
116    crate::impls::impl_unix_non_provider! { SmolRuntime }
117
118    #[async_trait]
119    impl traits::UdpProvider for SmolRuntime {
120        type UdpSocket = UdpSocket;
121
122        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
123            SmolUdpSocket::bind(addr)
124                .await
125                .map(|socket| UdpSocket { socket })
126        }
127    }
128
129    /// Wrapper for `SmolUdpSocket`.
130    // Required to implement `traits::UdpSocket`.
131    pub struct UdpSocket {
132        /// The underlying socket.
133        socket: SmolUdpSocket,
134    }
135
136    #[async_trait]
137    impl traits::UdpSocket for UdpSocket {
138        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
139            self.socket.recv_from(buf).await
140        }
141
142        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
143            self.socket.send_to(buf, target).await
144        }
145
146        fn local_addr(&self) -> IoResult<SocketAddr> {
147            self.socket.local_addr()
148        }
149    }
150
151    impl traits::StreamOps for TcpStream {
152        fn set_tcp_notsent_lowat(&self, lowat: u32) -> IoResult<()> {
153            impls::streamops::set_tcp_notsent_lowat(self, lowat)
154        }
155
156        #[cfg(target_os = "linux")]
157        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
158            Box::new(impls::streamops::TcpSockFd::from_fd(self))
159        }
160    }
161
162    #[cfg(unix)]
163    impl traits::StreamOps for UnixStream {
164        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
165            Err(traits::UnsupportedStreamOp::new(
166                "set_tcp_notsent_lowat",
167                "unsupported on Unix streams",
168            )
169            .into())
170        }
171    }
172}
173
174// ==============================
175
176use crate::traits::*;
177use futures::task::{FutureObj, Spawn, SpawnError};
178use futures::{Future, FutureExt};
179use std::pin::Pin;
180use std::time::Duration;
181
182/// Type to wrap `smol::Executor`.
183#[derive(Clone)]
184pub struct SmolRuntime {
185    /// Instance of the smol executor we own.
186    executor: std::sync::Arc<smol::Executor<'static>>,
187}
188
189/// Construct new instance of the smol runtime.
190//
191// TODO: Make SmolRuntime multi-threaded.
192pub fn create_runtime() -> SmolRuntime {
193    SmolRuntime {
194        executor: std::sync::Arc::new(smol::Executor::new()),
195    }
196}
197
198impl SleepProvider for SmolRuntime {
199    type SleepFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
200    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
201        Box::pin(async_io::Timer::after(duration).map(|_| ()))
202    }
203}
204
205impl ToplevelBlockOn for SmolRuntime {
206    fn block_on<F: Future>(&self, f: F) -> F::Output {
207        smol::block_on(self.executor.run(f))
208    }
209}
210
211impl Blocking for SmolRuntime {
212    type ThreadHandle<T: Send + 'static> = blocking::Task<T>;
213
214    fn spawn_blocking<F, T>(&self, f: F) -> blocking::Task<T>
215    where
216        F: FnOnce() -> T + Send + 'static,
217        T: Send + 'static,
218    {
219        smol::unblock(f)
220    }
221
222    fn reenter_block_on<F: Future>(&self, f: F) -> F::Output {
223        smol::block_on(self.executor.run(f))
224    }
225}
226
227impl Spawn for SmolRuntime {
228    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
229        self.executor.spawn(future).detach();
230        Ok(())
231    }
232}