tor_rtcompat/
compound.rs

//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use std::time::{Instant, SystemTime};
14use tor_general_addr::unix;
15use tracing::instrument;
16
17/// A runtime made of several parts, each of which implements one trait-group.
18///
19/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20/// the `SleepR` component should implement [`SleepProvider`];
21/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24/// and
25/// the `TlsR` component should implement [`TlsProvider`].
26///
27/// You can use this structure to create new runtimes in two ways: either by
28/// overriding a single part of an existing runtime, or by building an entirely
29/// new runtime from pieces.
30#[derive(Educe)]
31#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33    /// The actual collection of Runtime objects.
34    ///
35    /// We wrap this in an Arc rather than requiring that each item implement
36    /// Clone, though we could change our minds later on.
37    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Each field holds the component that a [`CompoundRuntime`] delegates one
/// trait-group to.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn` and `Blocking` (and possibly `ToplevelBlockOn`) implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for some stream type `S`.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61    /// Construct a new CompoundRuntime from its components.
62    pub fn new(
63        spawn: TaskR,
64        sleep: SleepR,
65        coarse_time: CoarseTimeR,
66        tcp: TcpR,
67        unix: UnixR,
68        tls: TlsR,
69        udp: UdpR,
70    ) -> Self {
71        #[allow(clippy::arc_with_non_send_sync)]
72        CompoundRuntime {
73            inner: Arc::new(Inner {
74                spawn,
75                sleep,
76                coarse_time,
77                tcp,
78                unix,
79                tls,
80                udp,
81            }),
82        }
83    }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89    TaskR: Spawn,
90{
91    #[inline]
92    #[track_caller]
93    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94        self.inner.spawn.spawn_obj(future)
95    }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101    TaskR: Blocking,
102    SleepR: Clone + Send + Sync + 'static,
103    CoarseTimeR: Clone + Send + Sync + 'static,
104    TcpR: Clone + Send + Sync + 'static,
105    UnixR: Clone + Send + Sync + 'static,
106    TlsR: Clone + Send + Sync + 'static,
107    UdpR: Clone + Send + Sync + 'static,
108{
109    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111    #[inline]
112    #[track_caller]
113    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114    where
115        F: FnOnce() -> T + Send + 'static,
116        T: Send + 'static,
117    {
118        self.inner.spawn.spawn_blocking(f)
119    }
120
121    #[inline]
122    #[track_caller]
123    fn reenter_block_on<F>(&self, future: F) -> F::Output
124    where
125        F: Future,
126        F::Output: Send + 'static,
127    {
128        self.inner.spawn.reenter_block_on(future)
129    }
130
131    #[track_caller]
132    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133    where
134        F: FnOnce() -> T + Send + 'static,
135        T: Send + 'static,
136    {
137        self.inner.spawn.blocking_io(f)
138    }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144    TaskR: ToplevelBlockOn,
145    SleepR: Clone + Send + Sync + 'static,
146    CoarseTimeR: Clone + Send + Sync + 'static,
147    TcpR: Clone + Send + Sync + 'static,
148    UnixR: Clone + Send + Sync + 'static,
149    TlsR: Clone + Send + Sync + 'static,
150    UdpR: Clone + Send + Sync + 'static,
151{
152    #[inline]
153    #[track_caller]
154    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155        self.inner.spawn.block_on(future)
156    }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162    SleepR: SleepProvider,
163    TaskR: Clone + Send + Sync + 'static,
164    CoarseTimeR: Clone + Send + Sync + 'static,
165    TcpR: Clone + Send + Sync + 'static,
166    UnixR: Clone + Send + Sync + 'static,
167    TlsR: Clone + Send + Sync + 'static,
168    UdpR: Clone + Send + Sync + 'static,
169{
170    type SleepFuture = SleepR::SleepFuture;
171
172    #[inline]
173    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174        self.inner.sleep.sleep(duration)
175    }
176
177    #[inline]
178    fn now(&self) -> Instant {
179        self.inner.sleep.now()
180    }
181
182    #[inline]
183    fn wallclock(&self) -> SystemTime {
184        self.inner.sleep.wallclock()
185    }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191    CoarseTimeR: CoarseTimeProvider,
192    SleepR: Clone + Send + Sync + 'static,
193    TaskR: Clone + Send + Sync + 'static,
194    CoarseTimeR: Clone + Send + Sync + 'static,
195    TcpR: Clone + Send + Sync + 'static,
196    UnixR: Clone + Send + Sync + 'static,
197    TlsR: Clone + Send + Sync + 'static,
198    UdpR: Clone + Send + Sync + 'static,
199{
200    #[inline]
201    fn now_coarse(&self) -> CoarseInstant {
202        self.inner.coarse_time.now_coarse()
203    }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210    TcpR: NetStreamProvider<net::SocketAddr>,
211    TaskR: Send + Sync + 'static,
212    SleepR: Send + Sync + 'static,
213    CoarseTimeR: Send + Sync + 'static,
214    TcpR: Send + Sync + 'static,
215    UnixR: Clone + Send + Sync + 'static,
216    TlsR: Send + Sync + 'static,
217    UdpR: Send + Sync + 'static,
218{
219    type Stream = TcpR::Stream;
220
221    type Listener = TcpR::Listener;
222
223    #[inline]
224    #[instrument(skip_all, level = "trace")]
225    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
226        self.inner.tcp.connect(addr).await
227    }
228
229    #[inline]
230    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
231        self.inner.tcp.listen(addr).await
232    }
233}
234
235#[async_trait]
236impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
237    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
238where
239    UnixR: NetStreamProvider<unix::SocketAddr>,
240    TaskR: Send + Sync + 'static,
241    SleepR: Send + Sync + 'static,
242    CoarseTimeR: Send + Sync + 'static,
243    TcpR: Send + Sync + 'static,
244    UnixR: Clone + Send + Sync + 'static,
245    TlsR: Send + Sync + 'static,
246    UdpR: Send + Sync + 'static,
247{
248    type Stream = UnixR::Stream;
249
250    type Listener = UnixR::Listener;
251
252    #[inline]
253    #[instrument(skip_all, level = "trace")]
254    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
255        self.inner.unix.connect(addr).await
256    }
257
258    #[inline]
259    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
260        self.inner.unix.listen(addr).await
261    }
262}
263
264impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
265    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
266where
267    TcpR: NetStreamProvider,
268    TlsR: TlsProvider<S>,
269    UnixR: Clone + Send + Sync + 'static,
270    SleepR: Clone + Send + Sync + 'static,
271    CoarseTimeR: Clone + Send + Sync + 'static,
272    TaskR: Clone + Send + Sync + 'static,
273    UdpR: Clone + Send + Sync + 'static,
274    S: StreamOps,
275{
276    type Connector = TlsR::Connector;
277    type TlsStream = TlsR::TlsStream;
278
279    #[inline]
280    fn tls_connector(&self) -> Self::Connector {
281        self.inner.tls.tls_connector()
282    }
283
284    #[inline]
285    fn supports_keying_material_export(&self) -> bool {
286        self.inner.tls.supports_keying_material_export()
287    }
288}
289
290impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
291    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
292{
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
295    }
296}
297
298#[async_trait]
299impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
300    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
301where
302    UdpR: UdpProvider,
303    TaskR: Send + Sync + 'static,
304    SleepR: Send + Sync + 'static,
305    CoarseTimeR: Send + Sync + 'static,
306    TcpR: Send + Sync + 'static,
307    UnixR: Clone + Send + Sync + 'static,
308    TlsR: Send + Sync + 'static,
309    UdpR: Send + Sync + 'static,
310{
311    type UdpSocket = UdpR::UdpSocket;
312
313    #[inline]
314    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
315        self.inner.udp.bind(addr).await
316    }
317}
318
/// Module to seal `RuntimeSubstExt`.
///
/// This module is private, so code outside the enclosing module cannot name
/// `sealed::Sealed`. That code therefore cannot implement `RuntimeSubstExt`.
mod sealed {
    /// Supertrait used to seal `RuntimeSubstExt`.
    //
    // This must be `pub` to appear in the bounds of the public
    // `RuntimeSubstExt` trait, even though it is unreachable from outside the
    // crate. That is what `unreachable_pub` would complain about.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
325/// Extension trait on Runtime:
326/// Construct new Runtimes that replace part of an original runtime.
327///
328/// (If you need to do more complicated versions of this, you should likely construct
329/// CompoundRuntime directly.)
330pub trait RuntimeSubstExt: sealed::Sealed + Sized {
331    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
332    fn with_tcp_provider<T>(
333        &self,
334        new_tcp: T,
335    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
336    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
337    fn with_sleep_provider<T>(
338        &self,
339        new_sleep: T,
340    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
341    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
342    fn with_coarse_time_provider<T>(
343        &self,
344        new_coarse_time: T,
345    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
346}
347impl<R: Runtime> sealed::Sealed for R {}
348impl<R: Runtime + Sized> RuntimeSubstExt for R {
349    fn with_tcp_provider<T>(
350        &self,
351        new_tcp: T,
352    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
353        CompoundRuntime::new(
354            self.clone(),
355            self.clone(),
356            self.clone(),
357            new_tcp,
358            self.clone(),
359            self.clone(),
360            self.clone(),
361        )
362    }
363
364    fn with_sleep_provider<T>(
365        &self,
366        new_sleep: T,
367    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
368        CompoundRuntime::new(
369            self.clone(),
370            new_sleep,
371            self.clone(),
372            self.clone(),
373            self.clone(),
374            self.clone(),
375            self.clone(),
376        )
377    }
378
379    fn with_coarse_time_provider<T>(
380        &self,
381        new_coarse_time: T,
382    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
383        CompoundRuntime::new(
384            self.clone(),
385            self.clone(),
386            new_coarse_time,
387            self.clone(),
388            self.clone(),
389            self.clone(),
390            self.clone(),
391        )
392    }
393}