1use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::io::Result as IoResult;
12use std::time::{Instant, SystemTime};
13use tor_general_addr::unix;
14
15#[derive(Educe)]
29#[educe(Clone)] pub struct CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
31 inner: Arc<Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
36}
37
/// The provider values owned by a `CompoundRuntime`, one field per role.
///
/// This is a separate struct so that the whole set can be placed behind one
/// `Arc`.
struct Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// Provider that receives `Spawn`, `SpawnBlocking` and `BlockOn` calls.
    spawn: SpawnR,
    /// Provider that receives `SleepProvider` calls.
    sleep: SleepR,
    /// Provider that receives `CoarseTimeProvider` calls.
    coarse_time: CoarseTimeR,
    /// Provider that receives `NetStreamProvider<net::SocketAddr>` calls.
    tcp: TcpR,
    /// Provider that receives `NetStreamProvider<unix::SocketAddr>` calls.
    unix: UnixR,
    /// Provider that receives `TlsProvider` calls.
    tls: TlsR,
    /// Provider that receives `UdpProvider` calls.
    udp: UdpR,
}
55
56impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
57 CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
58{
59 pub fn new(
61 spawn: SpawnR,
62 sleep: SleepR,
63 coarse_time: CoarseTimeR,
64 tcp: TcpR,
65 unix: UnixR,
66 tls: TlsR,
67 udp: UdpR,
68 ) -> Self {
69 #[allow(clippy::arc_with_non_send_sync)]
70 CompoundRuntime {
71 inner: Arc::new(Inner {
72 spawn,
73 sleep,
74 coarse_time,
75 tcp,
76 unix,
77 tls,
78 udp,
79 }),
80 }
81 }
82}
83
84impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
85 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
86where
87 SpawnR: Spawn,
88{
89 #[inline]
90 #[track_caller]
91 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
92 self.inner.spawn.spawn_obj(future)
93 }
94}
95
96impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SpawnBlocking
97 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
98where
99 SpawnR: SpawnBlocking,
100 SleepR: Clone + Send + Sync + 'static,
101 CoarseTimeR: Clone + Send + Sync + 'static,
102 TcpR: Clone + Send + Sync + 'static,
103 UnixR: Clone + Send + Sync + 'static,
104 TlsR: Clone + Send + Sync + 'static,
105 UdpR: Clone + Send + Sync + 'static,
106{
107 type Handle<T: Send + 'static> = SpawnR::Handle<T>;
108
109 #[inline]
110 #[track_caller]
111 fn spawn_blocking<F, T>(&self, f: F) -> SpawnR::Handle<T>
112 where
113 F: FnOnce() -> T + Send + 'static,
114 T: Send + 'static,
115 {
116 self.inner.spawn.spawn_blocking(f)
117 }
118}
119
120impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> BlockOn
121 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
122where
123 SpawnR: BlockOn,
124 SleepR: Clone + Send + Sync + 'static,
125 CoarseTimeR: Clone + Send + Sync + 'static,
126 TcpR: Clone + Send + Sync + 'static,
127 UnixR: Clone + Send + Sync + 'static,
128 TlsR: Clone + Send + Sync + 'static,
129 UdpR: Clone + Send + Sync + 'static,
130{
131 #[inline]
132 #[track_caller]
133 fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
134 self.inner.spawn.block_on(future)
135 }
136}
137
138impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
139 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
140where
141 SleepR: SleepProvider,
142 SpawnR: Clone + Send + Sync + 'static,
143 CoarseTimeR: Clone + Send + Sync + 'static,
144 TcpR: Clone + Send + Sync + 'static,
145 UnixR: Clone + Send + Sync + 'static,
146 TlsR: Clone + Send + Sync + 'static,
147 UdpR: Clone + Send + Sync + 'static,
148{
149 type SleepFuture = SleepR::SleepFuture;
150
151 #[inline]
152 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
153 self.inner.sleep.sleep(duration)
154 }
155
156 #[inline]
157 fn now(&self) -> Instant {
158 self.inner.sleep.now()
159 }
160
161 #[inline]
162 fn wallclock(&self) -> SystemTime {
163 self.inner.sleep.wallclock()
164 }
165}
166
167impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
168 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
169where
170 CoarseTimeR: CoarseTimeProvider,
171 SleepR: Clone + Send + Sync + 'static,
172 SpawnR: Clone + Send + Sync + 'static,
173 CoarseTimeR: Clone + Send + Sync + 'static,
174 TcpR: Clone + Send + Sync + 'static,
175 UnixR: Clone + Send + Sync + 'static,
176 TlsR: Clone + Send + Sync + 'static,
177 UdpR: Clone + Send + Sync + 'static,
178{
179 #[inline]
180 fn now_coarse(&self) -> CoarseInstant {
181 self.inner.coarse_time.now_coarse()
182 }
183}
184
185#[async_trait]
186impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
187 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
188where
189 TcpR: NetStreamProvider<net::SocketAddr>,
190 SpawnR: Send + Sync + 'static,
191 SleepR: Send + Sync + 'static,
192 CoarseTimeR: Send + Sync + 'static,
193 TcpR: Send + Sync + 'static,
194 UnixR: Clone + Send + Sync + 'static,
195 TlsR: Send + Sync + 'static,
196 UdpR: Send + Sync + 'static,
197{
198 type Stream = TcpR::Stream;
199
200 type Listener = TcpR::Listener;
201
202 #[inline]
203 async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
204 self.inner.tcp.connect(addr).await
205 }
206
207 #[inline]
208 async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
209 self.inner.tcp.listen(addr).await
210 }
211}
212
213#[async_trait]
214impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
215 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
216where
217 UnixR: NetStreamProvider<unix::SocketAddr>,
218 SpawnR: Send + Sync + 'static,
219 SleepR: Send + Sync + 'static,
220 CoarseTimeR: Send + Sync + 'static,
221 TcpR: Send + Sync + 'static,
222 UnixR: Clone + Send + Sync + 'static,
223 TlsR: Send + Sync + 'static,
224 UdpR: Send + Sync + 'static,
225{
226 type Stream = UnixR::Stream;
227
228 type Listener = UnixR::Listener;
229
230 #[inline]
231 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
232 self.inner.unix.connect(addr).await
233 }
234
235 #[inline]
236 async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
237 self.inner.unix.listen(addr).await
238 }
239}
240
241impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
242 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
243where
244 TcpR: NetStreamProvider,
245 TlsR: TlsProvider<S>,
246 UnixR: Clone + Send + Sync + 'static,
247 SleepR: Clone + Send + Sync + 'static,
248 CoarseTimeR: Clone + Send + Sync + 'static,
249 SpawnR: Clone + Send + Sync + 'static,
250 UdpR: Clone + Send + Sync + 'static,
251 S: StreamOps,
252{
253 type Connector = TlsR::Connector;
254 type TlsStream = TlsR::TlsStream;
255
256 #[inline]
257 fn tls_connector(&self) -> Self::Connector {
258 self.inner.tls.tls_connector()
259 }
260
261 #[inline]
262 fn supports_keying_material_export(&self) -> bool {
263 self.inner.tls.supports_keying_material_export()
264 }
265}
266
267impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
268 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
269{
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 f.debug_struct("CompoundRuntime").finish_non_exhaustive()
272 }
273}
274
275#[async_trait]
276impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
277 for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
278where
279 UdpR: UdpProvider,
280 SpawnR: Send + Sync + 'static,
281 SleepR: Send + Sync + 'static,
282 CoarseTimeR: Send + Sync + 'static,
283 TcpR: Send + Sync + 'static,
284 UnixR: Clone + Send + Sync + 'static,
285 TlsR: Send + Sync + 'static,
286 UdpR: Send + Sync + 'static,
287{
288 type UdpSocket = UdpR::UdpSocket;
289
290 #[inline]
291 async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
292 self.inner.udp.bind(addr).await
293 }
294}
295
/// Private module holding the marker trait used as a supertrait bound.
///
/// The module itself is not `pub`, so `Sealed` cannot be named from outside
/// the enclosing module tree even though the trait is declared `pub`.
mod sealed {
    /// Marker trait with no methods.
    // `pub` is needed so the trait can appear in a public trait's bounds;
    // the lint is silenced because the module path is not reachable.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
302pub trait RuntimeSubstExt: sealed::Sealed + Sized {
308 fn with_tcp_provider<T>(
310 &self,
311 new_tcp: T,
312 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
313 fn with_sleep_provider<T>(
315 &self,
316 new_sleep: T,
317 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
318 fn with_coarse_time_provider<T>(
320 &self,
321 new_coarse_time: T,
322 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
323}
324impl<R: Runtime> sealed::Sealed for R {}
325impl<R: Runtime + Sized> RuntimeSubstExt for R {
326 fn with_tcp_provider<T>(
327 &self,
328 new_tcp: T,
329 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
330 CompoundRuntime::new(
331 self.clone(),
332 self.clone(),
333 self.clone(),
334 new_tcp,
335 self.clone(),
336 self.clone(),
337 self.clone(),
338 )
339 }
340
341 fn with_sleep_provider<T>(
342 &self,
343 new_sleep: T,
344 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
345 CompoundRuntime::new(
346 self.clone(),
347 new_sleep,
348 self.clone(),
349 self.clone(),
350 self.clone(),
351 self.clone(),
352 self.clone(),
353 )
354 }
355
356 fn with_coarse_time_provider<T>(
357 &self,
358 new_coarse_time: T,
359 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
360 CompoundRuntime::new(
361 self.clone(),
362 self.clone(),
363 new_coarse_time,
364 self.clone(),
365 self.clone(),
366 self.clone(),
367 self.clone(),
368 )
369 }
370}