tor_rtcompat/
compound.rs

//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::io::Result as IoResult;
12use std::time::{Instant, SystemTime};
13use tor_general_addr::unix;
14
15/// A runtime made of several parts, each of which implements one trait-group.
16///
17/// The `SpawnR` component should implements [`Spawn`] and [`BlockOn`];
18/// the `SleepR` component should implement [`SleepProvider`];
19/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
20/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
21/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
22/// and
23/// the `TlsR` component should implement [`TlsProvider`].
24///
25/// You can use this structure to create new runtimes in two ways: either by
26/// overriding a single part of an existing runtime, or by building an entirely
27/// new runtime from pieces.
28#[derive(Educe)]
29#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
30pub struct CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
31    /// The actual collection of Runtime objects.
32    ///
33    /// We wrap this in an Arc rather than requiring that each item implement
34    /// Clone, though we could change our minds later on.
35    inner: Arc<Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
36}
37
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Shared (behind an `Arc`) by every clone of a [`CompoundRuntime`].
struct Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn`, `SpawnBlocking`, and `BlockOn` implementation.
    spawn: SpawnR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for whichever stream types `S` it supports.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
55
56impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
57    CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
58{
59    /// Construct a new CompoundRuntime from its components.
60    pub fn new(
61        spawn: SpawnR,
62        sleep: SleepR,
63        coarse_time: CoarseTimeR,
64        tcp: TcpR,
65        unix: UnixR,
66        tls: TlsR,
67        udp: UdpR,
68    ) -> Self {
69        #[allow(clippy::arc_with_non_send_sync)]
70        CompoundRuntime {
71            inner: Arc::new(Inner {
72                spawn,
73                sleep,
74                coarse_time,
75                tcp,
76                unix,
77                tls,
78                udp,
79            }),
80        }
81    }
82}
83
84impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
85    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
86where
87    SpawnR: Spawn,
88{
89    #[inline]
90    #[track_caller]
91    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
92        self.inner.spawn.spawn_obj(future)
93    }
94}
95
96impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SpawnBlocking
97    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
98where
99    SpawnR: SpawnBlocking,
100    SleepR: Clone + Send + Sync + 'static,
101    CoarseTimeR: Clone + Send + Sync + 'static,
102    TcpR: Clone + Send + Sync + 'static,
103    UnixR: Clone + Send + Sync + 'static,
104    TlsR: Clone + Send + Sync + 'static,
105    UdpR: Clone + Send + Sync + 'static,
106{
107    type Handle<T: Send + 'static> = SpawnR::Handle<T>;
108
109    #[inline]
110    #[track_caller]
111    fn spawn_blocking<F, T>(&self, f: F) -> SpawnR::Handle<T>
112    where
113        F: FnOnce() -> T + Send + 'static,
114        T: Send + 'static,
115    {
116        self.inner.spawn.spawn_blocking(f)
117    }
118}
119
120impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> BlockOn
121    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
122where
123    SpawnR: BlockOn,
124    SleepR: Clone + Send + Sync + 'static,
125    CoarseTimeR: Clone + Send + Sync + 'static,
126    TcpR: Clone + Send + Sync + 'static,
127    UnixR: Clone + Send + Sync + 'static,
128    TlsR: Clone + Send + Sync + 'static,
129    UdpR: Clone + Send + Sync + 'static,
130{
131    #[inline]
132    #[track_caller]
133    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
134        self.inner.spawn.block_on(future)
135    }
136}
137
138impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
139    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
140where
141    SleepR: SleepProvider,
142    SpawnR: Clone + Send + Sync + 'static,
143    CoarseTimeR: Clone + Send + Sync + 'static,
144    TcpR: Clone + Send + Sync + 'static,
145    UnixR: Clone + Send + Sync + 'static,
146    TlsR: Clone + Send + Sync + 'static,
147    UdpR: Clone + Send + Sync + 'static,
148{
149    type SleepFuture = SleepR::SleepFuture;
150
151    #[inline]
152    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
153        self.inner.sleep.sleep(duration)
154    }
155
156    #[inline]
157    fn now(&self) -> Instant {
158        self.inner.sleep.now()
159    }
160
161    #[inline]
162    fn wallclock(&self) -> SystemTime {
163        self.inner.sleep.wallclock()
164    }
165}
166
167impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
168    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
169where
170    CoarseTimeR: CoarseTimeProvider,
171    SleepR: Clone + Send + Sync + 'static,
172    SpawnR: Clone + Send + Sync + 'static,
173    CoarseTimeR: Clone + Send + Sync + 'static,
174    TcpR: Clone + Send + Sync + 'static,
175    UnixR: Clone + Send + Sync + 'static,
176    TlsR: Clone + Send + Sync + 'static,
177    UdpR: Clone + Send + Sync + 'static,
178{
179    #[inline]
180    fn now_coarse(&self) -> CoarseInstant {
181        self.inner.coarse_time.now_coarse()
182    }
183}
184
185#[async_trait]
186impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
187    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
188where
189    TcpR: NetStreamProvider<net::SocketAddr>,
190    SpawnR: Send + Sync + 'static,
191    SleepR: Send + Sync + 'static,
192    CoarseTimeR: Send + Sync + 'static,
193    TcpR: Send + Sync + 'static,
194    UnixR: Clone + Send + Sync + 'static,
195    TlsR: Send + Sync + 'static,
196    UdpR: Send + Sync + 'static,
197{
198    type Stream = TcpR::Stream;
199
200    type Listener = TcpR::Listener;
201
202    #[inline]
203    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
204        self.inner.tcp.connect(addr).await
205    }
206
207    #[inline]
208    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
209        self.inner.tcp.listen(addr).await
210    }
211}
212
213#[async_trait]
214impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
215    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
216where
217    UnixR: NetStreamProvider<unix::SocketAddr>,
218    SpawnR: Send + Sync + 'static,
219    SleepR: Send + Sync + 'static,
220    CoarseTimeR: Send + Sync + 'static,
221    TcpR: Send + Sync + 'static,
222    UnixR: Clone + Send + Sync + 'static,
223    TlsR: Send + Sync + 'static,
224    UdpR: Send + Sync + 'static,
225{
226    type Stream = UnixR::Stream;
227
228    type Listener = UnixR::Listener;
229
230    #[inline]
231    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
232        self.inner.unix.connect(addr).await
233    }
234
235    #[inline]
236    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
237        self.inner.unix.listen(addr).await
238    }
239}
240
241impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
242    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
243where
244    TcpR: NetStreamProvider,
245    TlsR: TlsProvider<S>,
246    UnixR: Clone + Send + Sync + 'static,
247    SleepR: Clone + Send + Sync + 'static,
248    CoarseTimeR: Clone + Send + Sync + 'static,
249    SpawnR: Clone + Send + Sync + 'static,
250    UdpR: Clone + Send + Sync + 'static,
251    S: StreamOps,
252{
253    type Connector = TlsR::Connector;
254    type TlsStream = TlsR::TlsStream;
255
256    #[inline]
257    fn tls_connector(&self) -> Self::Connector {
258        self.inner.tls.tls_connector()
259    }
260
261    #[inline]
262    fn supports_keying_material_export(&self) -> bool {
263        self.inner.tls.supports_keying_material_export()
264    }
265}
266
267impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
268    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
269{
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
272    }
273}
274
275#[async_trait]
276impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
277    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
278where
279    UdpR: UdpProvider,
280    SpawnR: Send + Sync + 'static,
281    SleepR: Send + Sync + 'static,
282    CoarseTimeR: Send + Sync + 'static,
283    TcpR: Send + Sync + 'static,
284    UnixR: Clone + Send + Sync + 'static,
285    TlsR: Send + Sync + 'static,
286    UdpR: Send + Sync + 'static,
287{
288    type UdpSocket = UdpR::UdpSocket;
289
290    #[inline]
291    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
292        self.inner.udp.bind(addr).await
293    }
294}
295
/// Module to seal `RuntimeSubstExt`.
///
/// Because this module is private, code outside this module's parent cannot
/// name `sealed::Sealed`, and therefore cannot implement `RuntimeSubstExt`,
/// which has `Sealed` as a supertrait.
mod sealed {
    /// Helper for sealing `RuntimeSubstExt`.
    ///
    /// In this file it is implemented by the blanket
    /// `impl<R: Runtime> sealed::Sealed for R`.
    // This must be `pub` because it is a supertrait of the public
    // `RuntimeSubstExt`, yet it is deliberately unreachable from outside;
    // silence the `unreachable_pub` lint for that reason.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
302/// Extension trait on Runtime:
303/// Construct new Runtimes that replace part of an original runtime.
304///
305/// (If you need to do more complicated versions of this, you should likely construct
306/// CompoundRuntime directly.)
307pub trait RuntimeSubstExt: sealed::Sealed + Sized {
308    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
309    fn with_tcp_provider<T>(
310        &self,
311        new_tcp: T,
312    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
313    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
314    fn with_sleep_provider<T>(
315        &self,
316        new_sleep: T,
317    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
318    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
319    fn with_coarse_time_provider<T>(
320        &self,
321        new_coarse_time: T,
322    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
323}
324impl<R: Runtime> sealed::Sealed for R {}
325impl<R: Runtime + Sized> RuntimeSubstExt for R {
326    fn with_tcp_provider<T>(
327        &self,
328        new_tcp: T,
329    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
330        CompoundRuntime::new(
331            self.clone(),
332            self.clone(),
333            self.clone(),
334            new_tcp,
335            self.clone(),
336            self.clone(),
337            self.clone(),
338        )
339    }
340
341    fn with_sleep_provider<T>(
342        &self,
343        new_sleep: T,
344    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
345        CompoundRuntime::new(
346            self.clone(),
347            new_sleep,
348            self.clone(),
349            self.clone(),
350            self.clone(),
351            self.clone(),
352            self.clone(),
353        )
354    }
355
356    fn with_coarse_time_provider<T>(
357        &self,
358        new_coarse_time: T,
359    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
360        CompoundRuntime::new(
361            self.clone(),
362            self.clone(),
363            new_coarse_time,
364            self.clone(),
365            self.clone(),
366            self.clone(),
367            self.clone(),
368        )
369    }
370}