Skip to main content

tor_rtcompat/
compound.rs

//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use tor_general_addr::unix;
14use tracing::instrument;
15use web_time_compat::{Instant, SystemTime};
16
17/// A runtime made of several parts, each of which implements one trait-group.
18///
19/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20/// the `SleepR` component should implement [`SleepProvider`];
21/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24/// and
25/// the `TlsR` component should implement [`TlsProvider`].
26///
27/// You can use this structure to create new runtimes in two ways: either by
28/// overriding a single part of an existing runtime, or by building an entirely
29/// new runtime from pieces.
30#[derive(Educe)]
31#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33    /// The actual collection of Runtime objects.
34    ///
35    /// We wrap this in an Arc rather than requiring that each item implement
36    /// Clone, though we could change our minds later on.
37    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// `CompoundRuntime` forwards each trait to the field responsible for it.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn` and `Blocking` (and possibly `ToplevelBlockOn`) implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for whatever stream types `S` it supports.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61    /// Construct a new CompoundRuntime from its components.
62    pub fn new(
63        spawn: TaskR,
64        sleep: SleepR,
65        coarse_time: CoarseTimeR,
66        tcp: TcpR,
67        unix: UnixR,
68        tls: TlsR,
69        udp: UdpR,
70    ) -> Self {
71        #[allow(clippy::arc_with_non_send_sync)]
72        CompoundRuntime {
73            inner: Arc::new(Inner {
74                spawn,
75                sleep,
76                coarse_time,
77                tcp,
78                unix,
79                tls,
80                udp,
81            }),
82        }
83    }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89    TaskR: Spawn,
90{
91    #[inline]
92    #[track_caller]
93    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94        self.inner.spawn.spawn_obj(future)
95    }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101    TaskR: Blocking,
102    SleepR: Clone + Send + Sync + 'static,
103    CoarseTimeR: Clone + Send + Sync + 'static,
104    TcpR: Clone + Send + Sync + 'static,
105    UnixR: Clone + Send + Sync + 'static,
106    TlsR: Clone + Send + Sync + 'static,
107    UdpR: Clone + Send + Sync + 'static,
108{
109    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111    #[inline]
112    #[track_caller]
113    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114    where
115        F: FnOnce() -> T + Send + 'static,
116        T: Send + 'static,
117    {
118        self.inner.spawn.spawn_blocking(f)
119    }
120
121    #[inline]
122    #[track_caller]
123    fn reenter_block_on<F>(&self, future: F) -> F::Output
124    where
125        F: Future,
126        F::Output: Send + 'static,
127    {
128        self.inner.spawn.reenter_block_on(future)
129    }
130
131    #[track_caller]
132    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133    where
134        F: FnOnce() -> T + Send + 'static,
135        T: Send + 'static,
136    {
137        self.inner.spawn.blocking_io(f)
138    }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144    TaskR: ToplevelBlockOn,
145    SleepR: Clone + Send + Sync + 'static,
146    CoarseTimeR: Clone + Send + Sync + 'static,
147    TcpR: Clone + Send + Sync + 'static,
148    UnixR: Clone + Send + Sync + 'static,
149    TlsR: Clone + Send + Sync + 'static,
150    UdpR: Clone + Send + Sync + 'static,
151{
152    #[inline]
153    #[track_caller]
154    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155        self.inner.spawn.block_on(future)
156    }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162    SleepR: SleepProvider,
163    TaskR: Clone + Send + Sync + 'static,
164    CoarseTimeR: Clone + Send + Sync + 'static,
165    TcpR: Clone + Send + Sync + 'static,
166    UnixR: Clone + Send + Sync + 'static,
167    TlsR: Clone + Send + Sync + 'static,
168    UdpR: Clone + Send + Sync + 'static,
169{
170    type SleepFuture = SleepR::SleepFuture;
171
172    #[inline]
173    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174        self.inner.sleep.sleep(duration)
175    }
176
177    #[inline]
178    fn now(&self) -> Instant {
179        self.inner.sleep.now()
180    }
181
182    #[inline]
183    fn wallclock(&self) -> SystemTime {
184        self.inner.sleep.wallclock()
185    }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191    CoarseTimeR: CoarseTimeProvider,
192    SleepR: Clone + Send + Sync + 'static,
193    TaskR: Clone + Send + Sync + 'static,
194    CoarseTimeR: Clone + Send + Sync + 'static,
195    TcpR: Clone + Send + Sync + 'static,
196    UnixR: Clone + Send + Sync + 'static,
197    TlsR: Clone + Send + Sync + 'static,
198    UdpR: Clone + Send + Sync + 'static,
199{
200    #[inline]
201    fn now_coarse(&self) -> CoarseInstant {
202        self.inner.coarse_time.now_coarse()
203    }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210    TcpR: NetStreamProvider<net::SocketAddr>,
211    TaskR: Send + Sync + 'static,
212    SleepR: Send + Sync + 'static,
213    CoarseTimeR: Send + Sync + 'static,
214    TcpR: Send + Sync + 'static,
215    UnixR: Clone + Send + Sync + 'static,
216    TlsR: Send + Sync + 'static,
217    UdpR: Send + Sync + 'static,
218{
219    type Stream = TcpR::Stream;
220
221    type Listener = TcpR::Listener;
222
223    type ConnectOptions = TcpR::ConnectOptions;
224    type ListenOptions = TcpR::ListenOptions;
225
226    #[inline]
227    #[instrument(skip_all, level = "trace")]
228    async fn connect(
229        &self,
230        addr: &net::SocketAddr,
231        options: &Self::ConnectOptions,
232    ) -> IoResult<Self::Stream> {
233        self.inner.tcp.connect(addr, options).await
234    }
235
236    #[inline]
237    async fn listen(
238        &self,
239        addr: &net::SocketAddr,
240        options: &Self::ListenOptions,
241    ) -> IoResult<Self::Listener> {
242        self.inner.tcp.listen(addr, options).await
243    }
244}
245
246#[async_trait]
247impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
248    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
249where
250    UnixR: NetStreamProvider<unix::SocketAddr>,
251    TaskR: Send + Sync + 'static,
252    SleepR: Send + Sync + 'static,
253    CoarseTimeR: Send + Sync + 'static,
254    TcpR: Send + Sync + 'static,
255    UnixR: Clone + Send + Sync + 'static,
256    TlsR: Send + Sync + 'static,
257    UdpR: Send + Sync + 'static,
258{
259    type Stream = UnixR::Stream;
260
261    type Listener = UnixR::Listener;
262
263    type ConnectOptions = UnixR::ConnectOptions;
264    type ListenOptions = UnixR::ListenOptions;
265
266    #[inline]
267    #[instrument(skip_all, level = "trace")]
268    async fn connect(
269        &self,
270        addr: &unix::SocketAddr,
271        options: &Self::ConnectOptions,
272    ) -> IoResult<Self::Stream> {
273        self.inner.unix.connect(addr, options).await
274    }
275
276    #[inline]
277    async fn listen(
278        &self,
279        addr: &unix::SocketAddr,
280        options: &Self::ListenOptions,
281    ) -> IoResult<Self::Listener> {
282        self.inner.unix.listen(addr, options).await
283    }
284}
285
286impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
287    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
288where
289    TcpR: NetStreamProvider,
290    TlsR: TlsProvider<S>,
291    UnixR: Clone + Send + Sync + 'static,
292    SleepR: Clone + Send + Sync + 'static,
293    CoarseTimeR: Clone + Send + Sync + 'static,
294    TaskR: Clone + Send + Sync + 'static,
295    UdpR: Clone + Send + Sync + 'static,
296    S: StreamOps,
297{
298    type Connector = TlsR::Connector;
299    type TlsStream = TlsR::TlsStream;
300    type Acceptor = TlsR::Acceptor;
301    type TlsServerStream = TlsR::TlsServerStream;
302
303    #[inline]
304    fn tls_connector(&self) -> Self::Connector {
305        self.inner.tls.tls_connector()
306    }
307
308    #[inline]
309    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
310        self.inner.tls.tls_acceptor(settings)
311    }
312
313    #[inline]
314    fn supports_keying_material_export(&self) -> bool {
315        self.inner.tls.supports_keying_material_export()
316    }
317}
318
319impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
320    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
321{
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
324    }
325}
326
327#[async_trait]
328impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
329    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
330where
331    UdpR: UdpProvider,
332    TaskR: Send + Sync + 'static,
333    SleepR: Send + Sync + 'static,
334    CoarseTimeR: Send + Sync + 'static,
335    TcpR: Send + Sync + 'static,
336    UnixR: Clone + Send + Sync + 'static,
337    TlsR: Send + Sync + 'static,
338    UdpR: Send + Sync + 'static,
339{
340    type UdpSocket = UdpR::UdpSocket;
341
342    #[inline]
343    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
344        self.inner.udp.bind(addr).await
345    }
346}
347
/// Private module holding the supertrait that seals `RuntimeSubstExt`.
///
/// Code that cannot name `sealed::Sealed` cannot satisfy `RuntimeSubstExt`'s
/// supertrait bound, so it cannot implement that trait directly.
mod sealed {
    /// Marker supertrait used to seal `RuntimeSubstExt`.
    //
    // It must be `pub` so it can appear in the bounds of the public
    // `RuntimeSubstExt` trait, even though its enclosing module is private.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
354/// Extension trait on Runtime:
355/// Construct new Runtimes that replace part of an original runtime.
356///
357/// (If you need to do more complicated versions of this, you should likely construct
358/// CompoundRuntime directly.)
359pub trait RuntimeSubstExt: sealed::Sealed + Sized {
360    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
361    fn with_tcp_provider<T>(
362        &self,
363        new_tcp: T,
364    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
365    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
366    fn with_sleep_provider<T>(
367        &self,
368        new_sleep: T,
369    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
370    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
371    fn with_coarse_time_provider<T>(
372        &self,
373        new_coarse_time: T,
374    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
375}
376impl<R: Runtime> sealed::Sealed for R {}
377impl<R: Runtime + Sized> RuntimeSubstExt for R {
378    fn with_tcp_provider<T>(
379        &self,
380        new_tcp: T,
381    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
382        CompoundRuntime::new(
383            self.clone(),
384            self.clone(),
385            self.clone(),
386            new_tcp,
387            self.clone(),
388            self.clone(),
389            self.clone(),
390        )
391    }
392
393    fn with_sleep_provider<T>(
394        &self,
395        new_sleep: T,
396    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
397        CompoundRuntime::new(
398            self.clone(),
399            new_sleep,
400            self.clone(),
401            self.clone(),
402            self.clone(),
403            self.clone(),
404            self.clone(),
405        )
406    }
407
408    fn with_coarse_time_provider<T>(
409        &self,
410        new_coarse_time: T,
411    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
412        CompoundRuntime::new(
413            self.clone(),
414            self.clone(),
415            new_coarse_time,
416            self.clone(),
417            self.clone(),
418            self.clone(),
419            self.clone(),
420        )
421    }
422}