1use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use tor_general_addr::unix;
14use tracing::instrument;
15use web_time_compat::{Instant, SystemTime};
16
17#[derive(Educe)]
31#[educe(Clone)] pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
33 inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
38}
39
/// The set of component providers held by a `CompoundRuntime`.
///
/// Grouping them in one struct lets the whole set sit behind a single `Arc`.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// Component used for spawning tasks, blocking work and `block_on`.
    spawn: TaskR,
    /// Component used for sleeping and for reading the current time.
    sleep: SleepR,
    /// Component used for coarse time readings.
    coarse_time: CoarseTimeR,
    /// Component used for `net::SocketAddr` stream connections and listeners.
    tcp: TcpR,
    /// Component used for `unix::SocketAddr` stream connections and listeners.
    unix: UnixR,
    /// Component used for TLS connectors and acceptors.
    tls: TlsR,
    /// Component used for binding UDP sockets.
    udp: UdpR,
}
57
58impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59 CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
60{
61 pub fn new(
63 spawn: TaskR,
64 sleep: SleepR,
65 coarse_time: CoarseTimeR,
66 tcp: TcpR,
67 unix: UnixR,
68 tls: TlsR,
69 udp: UdpR,
70 ) -> Self {
71 #[allow(clippy::arc_with_non_send_sync)]
72 CompoundRuntime {
73 inner: Arc::new(Inner {
74 spawn,
75 sleep,
76 coarse_time,
77 tcp,
78 unix,
79 tls,
80 udp,
81 }),
82 }
83 }
84}
85
86impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
87 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
88where
89 TaskR: Spawn,
90{
91 #[inline]
92 #[track_caller]
93 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
94 self.inner.spawn.spawn_obj(future)
95 }
96}
97
98impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
99 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
100where
101 TaskR: Blocking,
102 SleepR: Clone + Send + Sync + 'static,
103 CoarseTimeR: Clone + Send + Sync + 'static,
104 TcpR: Clone + Send + Sync + 'static,
105 UnixR: Clone + Send + Sync + 'static,
106 TlsR: Clone + Send + Sync + 'static,
107 UdpR: Clone + Send + Sync + 'static,
108{
109 type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
110
111 #[inline]
112 #[track_caller]
113 fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
114 where
115 F: FnOnce() -> T + Send + 'static,
116 T: Send + 'static,
117 {
118 self.inner.spawn.spawn_blocking(f)
119 }
120
121 #[inline]
122 #[track_caller]
123 fn reenter_block_on<F>(&self, future: F) -> F::Output
124 where
125 F: Future,
126 F::Output: Send + 'static,
127 {
128 self.inner.spawn.reenter_block_on(future)
129 }
130
131 #[track_caller]
132 fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
133 where
134 F: FnOnce() -> T + Send + 'static,
135 T: Send + 'static,
136 {
137 self.inner.spawn.blocking_io(f)
138 }
139}
140
141impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
142 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
143where
144 TaskR: ToplevelBlockOn,
145 SleepR: Clone + Send + Sync + 'static,
146 CoarseTimeR: Clone + Send + Sync + 'static,
147 TcpR: Clone + Send + Sync + 'static,
148 UnixR: Clone + Send + Sync + 'static,
149 TlsR: Clone + Send + Sync + 'static,
150 UdpR: Clone + Send + Sync + 'static,
151{
152 #[inline]
153 #[track_caller]
154 fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
155 self.inner.spawn.block_on(future)
156 }
157}
158
159impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
160 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
161where
162 SleepR: SleepProvider,
163 TaskR: Clone + Send + Sync + 'static,
164 CoarseTimeR: Clone + Send + Sync + 'static,
165 TcpR: Clone + Send + Sync + 'static,
166 UnixR: Clone + Send + Sync + 'static,
167 TlsR: Clone + Send + Sync + 'static,
168 UdpR: Clone + Send + Sync + 'static,
169{
170 type SleepFuture = SleepR::SleepFuture;
171
172 #[inline]
173 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
174 self.inner.sleep.sleep(duration)
175 }
176
177 #[inline]
178 fn now(&self) -> Instant {
179 self.inner.sleep.now()
180 }
181
182 #[inline]
183 fn wallclock(&self) -> SystemTime {
184 self.inner.sleep.wallclock()
185 }
186}
187
188impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
189 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
190where
191 CoarseTimeR: CoarseTimeProvider,
192 SleepR: Clone + Send + Sync + 'static,
193 TaskR: Clone + Send + Sync + 'static,
194 CoarseTimeR: Clone + Send + Sync + 'static,
195 TcpR: Clone + Send + Sync + 'static,
196 UnixR: Clone + Send + Sync + 'static,
197 TlsR: Clone + Send + Sync + 'static,
198 UdpR: Clone + Send + Sync + 'static,
199{
200 #[inline]
201 fn now_coarse(&self) -> CoarseInstant {
202 self.inner.coarse_time.now_coarse()
203 }
204}
205
206#[async_trait]
207impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
208 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
209where
210 TcpR: NetStreamProvider<net::SocketAddr>,
211 TaskR: Send + Sync + 'static,
212 SleepR: Send + Sync + 'static,
213 CoarseTimeR: Send + Sync + 'static,
214 TcpR: Send + Sync + 'static,
215 UnixR: Clone + Send + Sync + 'static,
216 TlsR: Send + Sync + 'static,
217 UdpR: Send + Sync + 'static,
218{
219 type Stream = TcpR::Stream;
220
221 type Listener = TcpR::Listener;
222
223 type ListenOptions = TcpR::ListenOptions;
224
225 #[inline]
226 #[instrument(skip_all, level = "trace")]
227 async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
228 self.inner.tcp.connect(addr).await
229 }
230
231 #[inline]
232 async fn listen(
233 &self,
234 addr: &net::SocketAddr,
235 options: &Self::ListenOptions,
236 ) -> IoResult<Self::Listener> {
237 self.inner.tcp.listen(addr, options).await
238 }
239}
240
241#[async_trait]
242impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
243 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
244where
245 UnixR: NetStreamProvider<unix::SocketAddr>,
246 TaskR: Send + Sync + 'static,
247 SleepR: Send + Sync + 'static,
248 CoarseTimeR: Send + Sync + 'static,
249 TcpR: Send + Sync + 'static,
250 UnixR: Clone + Send + Sync + 'static,
251 TlsR: Send + Sync + 'static,
252 UdpR: Send + Sync + 'static,
253{
254 type Stream = UnixR::Stream;
255
256 type Listener = UnixR::Listener;
257
258 type ListenOptions = UnixR::ListenOptions;
259
260 #[inline]
261 #[instrument(skip_all, level = "trace")]
262 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
263 self.inner.unix.connect(addr).await
264 }
265
266 #[inline]
267 async fn listen(
268 &self,
269 addr: &unix::SocketAddr,
270 options: &Self::ListenOptions,
271 ) -> IoResult<Self::Listener> {
272 self.inner.unix.listen(addr, options).await
273 }
274}
275
276impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
277 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
278where
279 TcpR: NetStreamProvider,
280 TlsR: TlsProvider<S>,
281 UnixR: Clone + Send + Sync + 'static,
282 SleepR: Clone + Send + Sync + 'static,
283 CoarseTimeR: Clone + Send + Sync + 'static,
284 TaskR: Clone + Send + Sync + 'static,
285 UdpR: Clone + Send + Sync + 'static,
286 S: StreamOps,
287{
288 type Connector = TlsR::Connector;
289 type TlsStream = TlsR::TlsStream;
290 type Acceptor = TlsR::Acceptor;
291 type TlsServerStream = TlsR::TlsServerStream;
292
293 #[inline]
294 fn tls_connector(&self) -> Self::Connector {
295 self.inner.tls.tls_connector()
296 }
297
298 #[inline]
299 fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
300 self.inner.tls.tls_acceptor(settings)
301 }
302
303 #[inline]
304 fn supports_keying_material_export(&self) -> bool {
305 self.inner.tls.supports_keying_material_export()
306 }
307}
308
309impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
310 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
311{
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.debug_struct("CompoundRuntime").finish_non_exhaustive()
314 }
315}
316
317#[async_trait]
318impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
319 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
320where
321 UdpR: UdpProvider,
322 TaskR: Send + Sync + 'static,
323 SleepR: Send + Sync + 'static,
324 CoarseTimeR: Send + Sync + 'static,
325 TcpR: Send + Sync + 'static,
326 UnixR: Clone + Send + Sync + 'static,
327 TlsR: Send + Sync + 'static,
328 UdpR: Send + Sync + 'static,
329{
330 type UdpSocket = UdpR::UdpSocket;
331
332 #[inline]
333 async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
334 self.inner.udp.bind(addr).await
335 }
336}
337
/// Private module holding the marker trait that seals `RuntimeSubstExt`.
mod sealed {
    /// Marker trait that cannot be named outside the parent module's crate,
    /// so traits requiring it cannot be implemented elsewhere.
    // `pub` is needed so the trait can appear as a supertrait of a public
    // trait, even though the enclosing module is private.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
344pub trait RuntimeSubstExt: sealed::Sealed + Sized {
350 fn with_tcp_provider<T>(
352 &self,
353 new_tcp: T,
354 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
355 fn with_sleep_provider<T>(
357 &self,
358 new_sleep: T,
359 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
360 fn with_coarse_time_provider<T>(
362 &self,
363 new_coarse_time: T,
364 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
365}
366impl<R: Runtime> sealed::Sealed for R {}
367impl<R: Runtime + Sized> RuntimeSubstExt for R {
368 fn with_tcp_provider<T>(
369 &self,
370 new_tcp: T,
371 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
372 CompoundRuntime::new(
373 self.clone(),
374 self.clone(),
375 self.clone(),
376 new_tcp,
377 self.clone(),
378 self.clone(),
379 self.clone(),
380 )
381 }
382
383 fn with_sleep_provider<T>(
384 &self,
385 new_sleep: T,
386 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
387 CompoundRuntime::new(
388 self.clone(),
389 new_sleep,
390 self.clone(),
391 self.clone(),
392 self.clone(),
393 self.clone(),
394 self.clone(),
395 )
396 }
397
398 fn with_coarse_time_provider<T>(
399 &self,
400 new_coarse_time: T,
401 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
402 CompoundRuntime::new(
403 self.clone(),
404 self.clone(),
405 new_coarse_time,
406 self.clone(),
407 self.clone(),
408 self.clone(),
409 self.clone(),
410 )
411 }
412}