1use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use std::time::{Instant, SystemTime};
14use tor_general_addr::unix;
15use tracing::instrument;
16
17#[derive(Educe)]
31#[educe(Clone)] pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33 inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// The components of a [`CompoundRuntime`], stored together behind one `Arc`.
///
/// The runtime trait impls on `CompoundRuntime` each delegate to one of these
/// fields.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// Backs `Spawn`, `Blocking` and `ToplevelBlockOn`.
    spawn: TaskR,
    /// Backs `SleepProvider` (`sleep`, `now`, `wallclock`).
    sleep: SleepR,
    /// Backs `CoarseTimeProvider`.
    coarse_time: CoarseTimeR,
    /// Backs `NetStreamProvider<net::SocketAddr>`.
    tcp: TcpR,
    /// Backs `NetStreamProvider<unix::SocketAddr>`.
    unix: UnixR,
    /// Backs `TlsProvider`.
    tls: TlsR,
    /// Backs `UdpProvider`.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59 CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61 pub fn new(
63 spawn: TaskR,
64 sleep: SleepR,
65 coarse_time: CoarseTimeR,
66 tcp: TcpR,
67 unix: UnixR,
68 tls: TlsR,
69 udp: UdpR,
70 ) -> Self {
71 #[allow(clippy::arc_with_non_send_sync)]
72 CompoundRuntime {
73 inner: Arc::new(Inner {
74 spawn,
75 sleep,
76 coarse_time,
77 tcp,
78 unix,
79 tls,
80 udp,
81 }),
82 }
83 }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89 TaskR: Spawn,
90{
91 #[inline]
92 #[track_caller]
93 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94 self.inner.spawn.spawn_obj(future)
95 }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101 TaskR: Blocking,
102 SleepR: Clone + Send + Sync + 'static,
103 CoarseTimeR: Clone + Send + Sync + 'static,
104 TcpR: Clone + Send + Sync + 'static,
105 UnixR: Clone + Send + Sync + 'static,
106 TlsR: Clone + Send + Sync + 'static,
107 UdpR: Clone + Send + Sync + 'static,
108{
109 type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111 #[inline]
112 #[track_caller]
113 fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114 where
115 F: FnOnce() -> T + Send + 'static,
116 T: Send + 'static,
117 {
118 self.inner.spawn.spawn_blocking(f)
119 }
120
121 #[inline]
122 #[track_caller]
123 fn reenter_block_on<F>(&self, future: F) -> F::Output
124 where
125 F: Future,
126 F::Output: Send + 'static,
127 {
128 self.inner.spawn.reenter_block_on(future)
129 }
130
131 #[track_caller]
132 fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133 where
134 F: FnOnce() -> T + Send + 'static,
135 T: Send + 'static,
136 {
137 self.inner.spawn.blocking_io(f)
138 }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144 TaskR: ToplevelBlockOn,
145 SleepR: Clone + Send + Sync + 'static,
146 CoarseTimeR: Clone + Send + Sync + 'static,
147 TcpR: Clone + Send + Sync + 'static,
148 UnixR: Clone + Send + Sync + 'static,
149 TlsR: Clone + Send + Sync + 'static,
150 UdpR: Clone + Send + Sync + 'static,
151{
152 #[inline]
153 #[track_caller]
154 fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155 self.inner.spawn.block_on(future)
156 }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162 SleepR: SleepProvider,
163 TaskR: Clone + Send + Sync + 'static,
164 CoarseTimeR: Clone + Send + Sync + 'static,
165 TcpR: Clone + Send + Sync + 'static,
166 UnixR: Clone + Send + Sync + 'static,
167 TlsR: Clone + Send + Sync + 'static,
168 UdpR: Clone + Send + Sync + 'static,
169{
170 type SleepFuture = SleepR::SleepFuture;
171
172 #[inline]
173 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174 self.inner.sleep.sleep(duration)
175 }
176
177 #[inline]
178 fn now(&self) -> Instant {
179 self.inner.sleep.now()
180 }
181
182 #[inline]
183 fn wallclock(&self) -> SystemTime {
184 self.inner.sleep.wallclock()
185 }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191 CoarseTimeR: CoarseTimeProvider,
192 SleepR: Clone + Send + Sync + 'static,
193 TaskR: Clone + Send + Sync + 'static,
194 CoarseTimeR: Clone + Send + Sync + 'static,
195 TcpR: Clone + Send + Sync + 'static,
196 UnixR: Clone + Send + Sync + 'static,
197 TlsR: Clone + Send + Sync + 'static,
198 UdpR: Clone + Send + Sync + 'static,
199{
200 #[inline]
201 fn now_coarse(&self) -> CoarseInstant {
202 self.inner.coarse_time.now_coarse()
203 }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210 TcpR: NetStreamProvider<net::SocketAddr>,
211 TaskR: Send + Sync + 'static,
212 SleepR: Send + Sync + 'static,
213 CoarseTimeR: Send + Sync + 'static,
214 TcpR: Send + Sync + 'static,
215 UnixR: Clone + Send + Sync + 'static,
216 TlsR: Send + Sync + 'static,
217 UdpR: Send + Sync + 'static,
218{
219 type Stream = TcpR::Stream;
220
221 type Listener = TcpR::Listener;
222
223 #[inline]
224 #[instrument(skip_all, level = "trace")]
225 async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
226 self.inner.tcp.connect(addr).await
227 }
228
229 #[inline]
230 async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
231 self.inner.tcp.listen(addr).await
232 }
233}
234
235#[async_trait]
236impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
237 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
238where
239 UnixR: NetStreamProvider<unix::SocketAddr>,
240 TaskR: Send + Sync + 'static,
241 SleepR: Send + Sync + 'static,
242 CoarseTimeR: Send + Sync + 'static,
243 TcpR: Send + Sync + 'static,
244 UnixR: Clone + Send + Sync + 'static,
245 TlsR: Send + Sync + 'static,
246 UdpR: Send + Sync + 'static,
247{
248 type Stream = UnixR::Stream;
249
250 type Listener = UnixR::Listener;
251
252 #[inline]
253 #[instrument(skip_all, level = "trace")]
254 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
255 self.inner.unix.connect(addr).await
256 }
257
258 #[inline]
259 async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
260 self.inner.unix.listen(addr).await
261 }
262}
263
264impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
265 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
266where
267 TcpR: NetStreamProvider,
268 TlsR: TlsProvider<S>,
269 UnixR: Clone + Send + Sync + 'static,
270 SleepR: Clone + Send + Sync + 'static,
271 CoarseTimeR: Clone + Send + Sync + 'static,
272 TaskR: Clone + Send + Sync + 'static,
273 UdpR: Clone + Send + Sync + 'static,
274 S: StreamOps,
275{
276 type Connector = TlsR::Connector;
277 type TlsStream = TlsR::TlsStream;
278
279 #[inline]
280 fn tls_connector(&self) -> Self::Connector {
281 self.inner.tls.tls_connector()
282 }
283
284 #[inline]
285 fn supports_keying_material_export(&self) -> bool {
286 self.inner.tls.supports_keying_material_export()
287 }
288}
289
290impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
291 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
292{
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 f.debug_struct("CompoundRuntime").finish_non_exhaustive()
295 }
296}
297
298#[async_trait]
299impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
300 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
301where
302 UdpR: UdpProvider,
303 TaskR: Send + Sync + 'static,
304 SleepR: Send + Sync + 'static,
305 CoarseTimeR: Send + Sync + 'static,
306 TcpR: Send + Sync + 'static,
307 UnixR: Clone + Send + Sync + 'static,
308 TlsR: Send + Sync + 'static,
309 UdpR: Send + Sync + 'static,
310{
311 type UdpSocket = UdpR::UdpSocket;
312
313 #[inline]
314 async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
315 self.inner.udp.bind(addr).await
316 }
317}
318
mod sealed {
    //! Private supertrait used to seal `RuntimeSubstExt`.
    //!
    //! Because this module is private, code outside its parent module cannot
    //! name `Sealed` and therefore cannot implement it.

    /// Marker trait with no methods.
    // `pub` inside a private module is never reachable from outside the crate,
    // which is the point; silence the lint that would flag it.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
325pub trait RuntimeSubstExt: sealed::Sealed + Sized {
331 fn with_tcp_provider<T>(
333 &self,
334 new_tcp: T,
335 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
336 fn with_sleep_provider<T>(
338 &self,
339 new_sleep: T,
340 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
341 fn with_coarse_time_provider<T>(
343 &self,
344 new_coarse_time: T,
345 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
346}
347impl<R: Runtime> sealed::Sealed for R {}
348impl<R: Runtime + Sized> RuntimeSubstExt for R {
349 fn with_tcp_provider<T>(
350 &self,
351 new_tcp: T,
352 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
353 CompoundRuntime::new(
354 self.clone(),
355 self.clone(),
356 self.clone(),
357 new_tcp,
358 self.clone(),
359 self.clone(),
360 self.clone(),
361 )
362 }
363
364 fn with_sleep_provider<T>(
365 &self,
366 new_sleep: T,
367 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
368 CompoundRuntime::new(
369 self.clone(),
370 new_sleep,
371 self.clone(),
372 self.clone(),
373 self.clone(),
374 self.clone(),
375 self.clone(),
376 )
377 }
378
379 fn with_coarse_time_provider<T>(
380 &self,
381 new_coarse_time: T,
382 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
383 CompoundRuntime::new(
384 self.clone(),
385 self.clone(),
386 new_coarse_time,
387 self.clone(),
388 self.clone(),
389 self.clone(),
390 self.clone(),
391 )
392 }
393}