Skip to main content

tor_rtcompat/
compound.rs

//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use std::time::{Instant, SystemTime};
14use tor_general_addr::unix;
15use tracing::instrument;
16
17/// A runtime made of several parts, each of which implements one trait-group.
18///
19/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20/// the `SleepR` component should implement [`SleepProvider`];
21/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24/// and
25/// the `TlsR` component should implement [`TlsProvider`].
26///
27/// You can use this structure to create new runtimes in two ways: either by
28/// overriding a single part of an existing runtime, or by building an entirely
29/// new runtime from pieces.
30#[derive(Educe)]
31#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33    /// The actual collection of Runtime objects.
34    ///
35    /// We wrap this in an Arc rather than requiring that each item implement
36    /// Clone, though we could change our minds later on.
37    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// A collection of objects implementing the traits that make up a [`Runtime`]
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn`, `Blocking` and (maybe) `ToplevelBlockOn` implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider` implementation.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61    /// Construct a new CompoundRuntime from its components.
62    pub fn new(
63        spawn: TaskR,
64        sleep: SleepR,
65        coarse_time: CoarseTimeR,
66        tcp: TcpR,
67        unix: UnixR,
68        tls: TlsR,
69        udp: UdpR,
70    ) -> Self {
71        #[allow(clippy::arc_with_non_send_sync)]
72        CompoundRuntime {
73            inner: Arc::new(Inner {
74                spawn,
75                sleep,
76                coarse_time,
77                tcp,
78                unix,
79                tls,
80                udp,
81            }),
82        }
83    }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89    TaskR: Spawn,
90{
91    #[inline]
92    #[track_caller]
93    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94        self.inner.spawn.spawn_obj(future)
95    }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101    TaskR: Blocking,
102    SleepR: Clone + Send + Sync + 'static,
103    CoarseTimeR: Clone + Send + Sync + 'static,
104    TcpR: Clone + Send + Sync + 'static,
105    UnixR: Clone + Send + Sync + 'static,
106    TlsR: Clone + Send + Sync + 'static,
107    UdpR: Clone + Send + Sync + 'static,
108{
109    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111    #[inline]
112    #[track_caller]
113    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114    where
115        F: FnOnce() -> T + Send + 'static,
116        T: Send + 'static,
117    {
118        self.inner.spawn.spawn_blocking(f)
119    }
120
121    #[inline]
122    #[track_caller]
123    fn reenter_block_on<F>(&self, future: F) -> F::Output
124    where
125        F: Future,
126        F::Output: Send + 'static,
127    {
128        self.inner.spawn.reenter_block_on(future)
129    }
130
131    #[track_caller]
132    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133    where
134        F: FnOnce() -> T + Send + 'static,
135        T: Send + 'static,
136    {
137        self.inner.spawn.blocking_io(f)
138    }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144    TaskR: ToplevelBlockOn,
145    SleepR: Clone + Send + Sync + 'static,
146    CoarseTimeR: Clone + Send + Sync + 'static,
147    TcpR: Clone + Send + Sync + 'static,
148    UnixR: Clone + Send + Sync + 'static,
149    TlsR: Clone + Send + Sync + 'static,
150    UdpR: Clone + Send + Sync + 'static,
151{
152    #[inline]
153    #[track_caller]
154    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155        self.inner.spawn.block_on(future)
156    }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162    SleepR: SleepProvider,
163    TaskR: Clone + Send + Sync + 'static,
164    CoarseTimeR: Clone + Send + Sync + 'static,
165    TcpR: Clone + Send + Sync + 'static,
166    UnixR: Clone + Send + Sync + 'static,
167    TlsR: Clone + Send + Sync + 'static,
168    UdpR: Clone + Send + Sync + 'static,
169{
170    type SleepFuture = SleepR::SleepFuture;
171
172    #[inline]
173    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174        self.inner.sleep.sleep(duration)
175    }
176
177    #[inline]
178    fn now(&self) -> Instant {
179        self.inner.sleep.now()
180    }
181
182    #[inline]
183    fn wallclock(&self) -> SystemTime {
184        self.inner.sleep.wallclock()
185    }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191    CoarseTimeR: CoarseTimeProvider,
192    SleepR: Clone + Send + Sync + 'static,
193    TaskR: Clone + Send + Sync + 'static,
194    CoarseTimeR: Clone + Send + Sync + 'static,
195    TcpR: Clone + Send + Sync + 'static,
196    UnixR: Clone + Send + Sync + 'static,
197    TlsR: Clone + Send + Sync + 'static,
198    UdpR: Clone + Send + Sync + 'static,
199{
200    #[inline]
201    fn now_coarse(&self) -> CoarseInstant {
202        self.inner.coarse_time.now_coarse()
203    }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210    TcpR: NetStreamProvider<net::SocketAddr>,
211    TaskR: Send + Sync + 'static,
212    SleepR: Send + Sync + 'static,
213    CoarseTimeR: Send + Sync + 'static,
214    TcpR: Send + Sync + 'static,
215    UnixR: Clone + Send + Sync + 'static,
216    TlsR: Send + Sync + 'static,
217    UdpR: Send + Sync + 'static,
218{
219    type Stream = TcpR::Stream;
220
221    type Listener = TcpR::Listener;
222
223    #[inline]
224    #[instrument(skip_all, level = "trace")]
225    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
226        self.inner.tcp.connect(addr).await
227    }
228
229    #[inline]
230    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
231        self.inner.tcp.listen(addr).await
232    }
233}
234
235#[async_trait]
236impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
237    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
238where
239    UnixR: NetStreamProvider<unix::SocketAddr>,
240    TaskR: Send + Sync + 'static,
241    SleepR: Send + Sync + 'static,
242    CoarseTimeR: Send + Sync + 'static,
243    TcpR: Send + Sync + 'static,
244    UnixR: Clone + Send + Sync + 'static,
245    TlsR: Send + Sync + 'static,
246    UdpR: Send + Sync + 'static,
247{
248    type Stream = UnixR::Stream;
249
250    type Listener = UnixR::Listener;
251
252    #[inline]
253    #[instrument(skip_all, level = "trace")]
254    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
255        self.inner.unix.connect(addr).await
256    }
257
258    #[inline]
259    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
260        self.inner.unix.listen(addr).await
261    }
262}
263
264impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
265    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
266where
267    TcpR: NetStreamProvider,
268    TlsR: TlsProvider<S>,
269    UnixR: Clone + Send + Sync + 'static,
270    SleepR: Clone + Send + Sync + 'static,
271    CoarseTimeR: Clone + Send + Sync + 'static,
272    TaskR: Clone + Send + Sync + 'static,
273    UdpR: Clone + Send + Sync + 'static,
274    S: StreamOps,
275{
276    type Connector = TlsR::Connector;
277    type TlsStream = TlsR::TlsStream;
278    type Acceptor = TlsR::Acceptor;
279    type TlsServerStream = TlsR::TlsServerStream;
280
281    #[inline]
282    fn tls_connector(&self) -> Self::Connector {
283        self.inner.tls.tls_connector()
284    }
285
286    #[inline]
287    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
288        self.inner.tls.tls_acceptor(settings)
289    }
290
291    #[inline]
292    fn supports_keying_material_export(&self) -> bool {
293        self.inner.tls.supports_keying_material_export()
294    }
295}
296
297impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
298    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
299{
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
302    }
303}
304
305#[async_trait]
306impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
307    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
308where
309    UdpR: UdpProvider,
310    TaskR: Send + Sync + 'static,
311    SleepR: Send + Sync + 'static,
312    CoarseTimeR: Send + Sync + 'static,
313    TcpR: Send + Sync + 'static,
314    UnixR: Clone + Send + Sync + 'static,
315    TlsR: Send + Sync + 'static,
316    UdpR: Send + Sync + 'static,
317{
318    type UdpSocket = UdpR::UdpSocket;
319
320    #[inline]
321    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
322        self.inner.udp.bind(addr).await
323    }
324}
325
/// Module to seal RuntimeSubstExt
///
/// The module itself is private, so code outside this file's parent module
/// cannot name `Sealed` and therefore cannot implement any trait that lists
/// it as a supertrait.
mod sealed {
    /// Helper for sealing RuntimeSubstExt
    // `pub` is needed so the trait can appear as a supertrait of a public
    // trait, even though the enclosing module keeps it unreachable.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
332/// Extension trait on Runtime:
333/// Construct new Runtimes that replace part of an original runtime.
334///
335/// (If you need to do more complicated versions of this, you should likely construct
336/// CompoundRuntime directly.)
337pub trait RuntimeSubstExt: sealed::Sealed + Sized {
338    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
339    fn with_tcp_provider<T>(
340        &self,
341        new_tcp: T,
342    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
343    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
344    fn with_sleep_provider<T>(
345        &self,
346        new_sleep: T,
347    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
348    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
349    fn with_coarse_time_provider<T>(
350        &self,
351        new_coarse_time: T,
352    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
353}
354impl<R: Runtime> sealed::Sealed for R {}
355impl<R: Runtime + Sized> RuntimeSubstExt for R {
356    fn with_tcp_provider<T>(
357        &self,
358        new_tcp: T,
359    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
360        CompoundRuntime::new(
361            self.clone(),
362            self.clone(),
363            self.clone(),
364            new_tcp,
365            self.clone(),
366            self.clone(),
367            self.clone(),
368        )
369    }
370
371    fn with_sleep_provider<T>(
372        &self,
373        new_sleep: T,
374    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
375        CompoundRuntime::new(
376            self.clone(),
377            new_sleep,
378            self.clone(),
379            self.clone(),
380            self.clone(),
381            self.clone(),
382            self.clone(),
383        )
384    }
385
386    fn with_coarse_time_provider<T>(
387        &self,
388        new_coarse_time: T,
389    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
390        CompoundRuntime::new(
391            self.clone(),
392            self.clone(),
393            new_coarse_time,
394            self.clone(),
395            self.clone(),
396            self.clone(),
397            self.clone(),
398        )
399    }
400}