Skip to main content

tor_rtcompat/
compound.rs

//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use tor_general_addr::unix;
14use tracing::instrument;
15use web_time_compat::{Instant, SystemTime};
16
17/// A runtime made of several parts, each of which implements one trait-group.
18///
19/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
20/// the `SleepR` component should implement [`SleepProvider`];
21/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
22/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
23/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
24/// and
25/// the `TlsR` component should implement [`TlsProvider`].
26///
27/// You can use this structure to create new runtimes in two ways: either by
28/// overriding a single part of an existing runtime, or by building an entirely
29/// new runtime from pieces.
30#[derive(Educe)]
31#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
32pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33    /// The actual collection of Runtime objects.
34    ///
35    /// We wrap this in an Arc rather than requiring that each item implement
36    /// Clone, though we could change our minds later on.
37    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Each field holds one component of a [`CompoundRuntime`]. The trait impls on
/// `CompoundRuntime` forward to the matching field.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn` and `Blocking` (and possibly `ToplevelBlockOn`) implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for whichever stream types `S` it supports.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61    /// Construct a new CompoundRuntime from its components.
62    pub fn new(
63        spawn: TaskR,
64        sleep: SleepR,
65        coarse_time: CoarseTimeR,
66        tcp: TcpR,
67        unix: UnixR,
68        tls: TlsR,
69        udp: UdpR,
70    ) -> Self {
71        #[allow(clippy::arc_with_non_send_sync)]
72        CompoundRuntime {
73            inner: Arc::new(Inner {
74                spawn,
75                sleep,
76                coarse_time,
77                tcp,
78                unix,
79                tls,
80                udp,
81            }),
82        }
83    }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89    TaskR: Spawn,
90{
91    #[inline]
92    #[track_caller]
93    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94        self.inner.spawn.spawn_obj(future)
95    }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101    TaskR: Blocking,
102    SleepR: Clone + Send + Sync + 'static,
103    CoarseTimeR: Clone + Send + Sync + 'static,
104    TcpR: Clone + Send + Sync + 'static,
105    UnixR: Clone + Send + Sync + 'static,
106    TlsR: Clone + Send + Sync + 'static,
107    UdpR: Clone + Send + Sync + 'static,
108{
109    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111    #[inline]
112    #[track_caller]
113    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114    where
115        F: FnOnce() -> T + Send + 'static,
116        T: Send + 'static,
117    {
118        self.inner.spawn.spawn_blocking(f)
119    }
120
121    #[inline]
122    #[track_caller]
123    fn reenter_block_on<F>(&self, future: F) -> F::Output
124    where
125        F: Future,
126        F::Output: Send + 'static,
127    {
128        self.inner.spawn.reenter_block_on(future)
129    }
130
131    #[track_caller]
132    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133    where
134        F: FnOnce() -> T + Send + 'static,
135        T: Send + 'static,
136    {
137        self.inner.spawn.blocking_io(f)
138    }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144    TaskR: ToplevelBlockOn,
145    SleepR: Clone + Send + Sync + 'static,
146    CoarseTimeR: Clone + Send + Sync + 'static,
147    TcpR: Clone + Send + Sync + 'static,
148    UnixR: Clone + Send + Sync + 'static,
149    TlsR: Clone + Send + Sync + 'static,
150    UdpR: Clone + Send + Sync + 'static,
151{
152    #[inline]
153    #[track_caller]
154    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155        self.inner.spawn.block_on(future)
156    }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162    SleepR: SleepProvider,
163    TaskR: Clone + Send + Sync + 'static,
164    CoarseTimeR: Clone + Send + Sync + 'static,
165    TcpR: Clone + Send + Sync + 'static,
166    UnixR: Clone + Send + Sync + 'static,
167    TlsR: Clone + Send + Sync + 'static,
168    UdpR: Clone + Send + Sync + 'static,
169{
170    type SleepFuture = SleepR::SleepFuture;
171
172    #[inline]
173    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174        self.inner.sleep.sleep(duration)
175    }
176
177    #[inline]
178    fn now(&self) -> Instant {
179        self.inner.sleep.now()
180    }
181
182    #[inline]
183    fn wallclock(&self) -> SystemTime {
184        self.inner.sleep.wallclock()
185    }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191    CoarseTimeR: CoarseTimeProvider,
192    SleepR: Clone + Send + Sync + 'static,
193    TaskR: Clone + Send + Sync + 'static,
194    CoarseTimeR: Clone + Send + Sync + 'static,
195    TcpR: Clone + Send + Sync + 'static,
196    UnixR: Clone + Send + Sync + 'static,
197    TlsR: Clone + Send + Sync + 'static,
198    UdpR: Clone + Send + Sync + 'static,
199{
200    #[inline]
201    fn now_coarse(&self) -> CoarseInstant {
202        self.inner.coarse_time.now_coarse()
203    }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210    TcpR: NetStreamProvider<net::SocketAddr>,
211    TaskR: Send + Sync + 'static,
212    SleepR: Send + Sync + 'static,
213    CoarseTimeR: Send + Sync + 'static,
214    TcpR: Send + Sync + 'static,
215    UnixR: Clone + Send + Sync + 'static,
216    TlsR: Send + Sync + 'static,
217    UdpR: Send + Sync + 'static,
218{
219    type Stream = TcpR::Stream;
220
221    type Listener = TcpR::Listener;
222
223    type ListenOptions = TcpR::ListenOptions;
224
225    #[inline]
226    #[instrument(skip_all, level = "trace")]
227    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
228        self.inner.tcp.connect(addr).await
229    }
230
231    #[inline]
232    async fn listen(
233        &self,
234        addr: &net::SocketAddr,
235        options: &Self::ListenOptions,
236    ) -> IoResult<Self::Listener> {
237        self.inner.tcp.listen(addr, options).await
238    }
239}
240
241#[async_trait]
242impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
243    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
244where
245    UnixR: NetStreamProvider<unix::SocketAddr>,
246    TaskR: Send + Sync + 'static,
247    SleepR: Send + Sync + 'static,
248    CoarseTimeR: Send + Sync + 'static,
249    TcpR: Send + Sync + 'static,
250    UnixR: Clone + Send + Sync + 'static,
251    TlsR: Send + Sync + 'static,
252    UdpR: Send + Sync + 'static,
253{
254    type Stream = UnixR::Stream;
255
256    type Listener = UnixR::Listener;
257
258    type ListenOptions = UnixR::ListenOptions;
259
260    #[inline]
261    #[instrument(skip_all, level = "trace")]
262    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
263        self.inner.unix.connect(addr).await
264    }
265
266    #[inline]
267    async fn listen(
268        &self,
269        addr: &unix::SocketAddr,
270        options: &Self::ListenOptions,
271    ) -> IoResult<Self::Listener> {
272        self.inner.unix.listen(addr, options).await
273    }
274}
275
276impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
277    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
278where
279    TcpR: NetStreamProvider,
280    TlsR: TlsProvider<S>,
281    UnixR: Clone + Send + Sync + 'static,
282    SleepR: Clone + Send + Sync + 'static,
283    CoarseTimeR: Clone + Send + Sync + 'static,
284    TaskR: Clone + Send + Sync + 'static,
285    UdpR: Clone + Send + Sync + 'static,
286    S: StreamOps,
287{
288    type Connector = TlsR::Connector;
289    type TlsStream = TlsR::TlsStream;
290    type Acceptor = TlsR::Acceptor;
291    type TlsServerStream = TlsR::TlsServerStream;
292
293    #[inline]
294    fn tls_connector(&self) -> Self::Connector {
295        self.inner.tls.tls_connector()
296    }
297
298    #[inline]
299    fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
300        self.inner.tls.tls_acceptor(settings)
301    }
302
303    #[inline]
304    fn supports_keying_material_export(&self) -> bool {
305        self.inner.tls.supports_keying_material_export()
306    }
307}
308
309impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
310    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
311{
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
314    }
315}
316
317#[async_trait]
318impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
319    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
320where
321    UdpR: UdpProvider,
322    TaskR: Send + Sync + 'static,
323    SleepR: Send + Sync + 'static,
324    CoarseTimeR: Send + Sync + 'static,
325    TcpR: Send + Sync + 'static,
326    UnixR: Clone + Send + Sync + 'static,
327    TlsR: Send + Sync + 'static,
328    UdpR: Send + Sync + 'static,
329{
330    type UdpSocket = UdpR::UdpSocket;
331
332    #[inline]
333    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
334        self.inner.udp.bind(addr).await
335    }
336}
337
/// Private module holding the supertrait that seals [`RuntimeSubstExt`].
mod sealed {
    /// Supertrait of `RuntimeSubstExt`.
    ///
    /// It lives in a private module, so code outside this crate cannot name it.
    /// As a result, code outside this crate cannot implement `RuntimeSubstExt`.
    // `pub` inside a private module is deliberate: the trait must be usable as a
    // supertrait of a public trait while staying unreachable from outside.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
344/// Extension trait on Runtime:
345/// Construct new Runtimes that replace part of an original runtime.
346///
347/// (If you need to do more complicated versions of this, you should likely construct
348/// CompoundRuntime directly.)
349pub trait RuntimeSubstExt: sealed::Sealed + Sized {
350    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
351    fn with_tcp_provider<T>(
352        &self,
353        new_tcp: T,
354    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
355    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
356    fn with_sleep_provider<T>(
357        &self,
358        new_sleep: T,
359    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
360    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
361    fn with_coarse_time_provider<T>(
362        &self,
363        new_coarse_time: T,
364    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
365}
366impl<R: Runtime> sealed::Sealed for R {}
367impl<R: Runtime + Sized> RuntimeSubstExt for R {
368    fn with_tcp_provider<T>(
369        &self,
370        new_tcp: T,
371    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
372        CompoundRuntime::new(
373            self.clone(),
374            self.clone(),
375            self.clone(),
376            new_tcp,
377            self.clone(),
378            self.clone(),
379            self.clone(),
380        )
381    }
382
383    fn with_sleep_provider<T>(
384        &self,
385        new_sleep: T,
386    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
387        CompoundRuntime::new(
388            self.clone(),
389            new_sleep,
390            self.clone(),
391            self.clone(),
392            self.clone(),
393            self.clone(),
394            self.clone(),
395        )
396    }
397
398    fn with_coarse_time_provider<T>(
399        &self,
400        new_coarse_time: T,
401    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
402        CompoundRuntime::new(
403            self.clone(),
404            self.clone(),
405            new_coarse_time,
406            self.clone(),
407            self.clone(),
408            self.clone(),
409            self.clone(),
410        )
411    }
412}