Skip to main content

tor_rtcompat/
lib.rs

// Enable the unstable `doc_cfg` feature only when the `docsrs` cfg is set.
#![cfg_attr(docsrs, feature(doc_cfg))]
// The crate-level documentation is the contents of the crate's README.
#![doc = include_str!("../README.md")]
// @@ begin lint list maintained by maint/add_warning @@
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
#![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_time_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::mod_module_files)]
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
#![allow(clippy::needless_lifetimes)] // See arti#1765
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
#![deny(clippy::unused_async)]
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->

// TODO #1645 (either remove this, or decide to have it everywhere)
//
// Unless the `full` feature is enabled, allow `unused` items crate-wide.
// NOTE(review): presumably because many items are only reachable under some
// feature combinations. `all(feature = "full")` is equivalent to just
// `feature = "full"`.
#![cfg_attr(not(all(feature = "full")), allow(unused))]
50
51#[cfg(all(
52    any(feature = "native-tls", feature = "rustls"),
53    any(feature = "async-std", feature = "tokio", feature = "smol")
54))]
55pub(crate) mod impls;
56pub mod task;
57
58mod coarse_time;
59mod compound;
60mod dyn_time;
61pub mod general;
62mod opaque;
63pub mod scheduler;
64mod timer;
65mod traits;
66pub mod unimpl;
67pub mod unix;
68
69#[cfg(any(feature = "async-std", feature = "tokio", feature = "smol"))]
70use std::io;
71pub use traits::{
72    Blocking, CertifiedConn, CoarseTimeProvider, NetStreamListener, NetStreamProvider,
73    NoOpStreamOpsHandle, Runtime, SleepProvider, SpawnExt, StreamOps, TlsProvider, ToplevelBlockOn,
74    ToplevelRuntime, UdpProvider, UdpSocket, UnsupportedStreamOp,
75};
76
77pub use coarse_time::{CoarseDuration, CoarseInstant, RealCoarseTimeProvider};
78pub use dyn_time::DynTimeProvider;
79pub use timer::{SleepProviderExt, Timeout, TimeoutError};
80
81/// Traits used to describe TLS connections and objects that can
82/// create them.
83pub mod tls {
84    #[cfg(all(
85        any(feature = "native-tls", feature = "rustls"),
86        any(feature = "async-std", feature = "tokio", feature = "smol")
87    ))]
88    pub use crate::impls::unimpl_tls::UnimplementedTls;
89    pub use crate::traits::{
90        CertifiedConn, TlsAcceptorSettings, TlsConnector, TlsServerUnsupported,
91    };
92
93    #[cfg(all(
94        feature = "native-tls",
95        any(feature = "tokio", feature = "async-std", feature = "smol")
96    ))]
97    pub use crate::impls::native_tls::NativeTlsProvider;
98    #[cfg(all(
99        feature = "rustls",
100        any(feature = "tokio", feature = "async-std", feature = "smol")
101    ))]
102    pub use crate::impls::rustls::RustlsProvider;
103    #[cfg(all(
104        feature = "rustls",
105        feature = "tls-server",
106        any(feature = "tokio", feature = "async-std", feature = "smol")
107    ))]
108    pub use crate::impls::rustls::rustls_server::{RustlsAcceptor, RustlsServerStream};
109}
110
111#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))]
112pub mod tokio;
113
114#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "async-std"))]
115pub mod async_std;
116
117#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "smol"))]
118pub mod smol;
119
120pub use compound::{CompoundRuntime, RuntimeSubstExt};
121
122#[cfg(all(
123    any(feature = "native-tls", feature = "rustls"),
124    feature = "async-std",
125    not(feature = "tokio")
126))]
127use async_std as preferred_backend_mod;
128#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))]
129use tokio as preferred_backend_mod;
130
/// The runtime we prefer to use, chosen among all the runtimes compiled into
/// the tor-rtcompat crate.
///
/// When both `tokio` and `async-std` are available, `tokio` is chosen for its
/// performance.
/// When both `native_tls` and `rustls` are available, `native_tls` is chosen,
/// since Arti has been using it for longer.
///
/// After creating this or any other `Runtime`, the process
/// [**may not fork**](crate#do-not-fork)
/// (except, very carefully, before exec).
#[derive(Clone)]
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio")
))]
pub struct PreferredRuntime {
    /// The backend-specific runtime object that this type wraps.
    inner: preferred_backend_mod::PreferredRuntime,
}

// Generate the runtime trait implementations for `PreferredRuntime` in terms of
// its `inner` field (the macro lives in `crate::opaque`).
#[cfg(all(
    any(feature = "native-tls", feature = "rustls"),
    any(feature = "async-std", feature = "tokio")
))]
crate::opaque::implement_opaque_runtime! {
    PreferredRuntime { inner: preferred_backend_mod::PreferredRuntime }
}
154    any(feature = "async-std", feature = "tokio")
155))]
156crate::opaque::implement_opaque_runtime! {
157    PreferredRuntime { inner : preferred_backend_mod::PreferredRuntime }
158}
159
160#[cfg(all(
161    any(feature = "native-tls", feature = "rustls"),
162    any(feature = "async-std", feature = "tokio")
163))]
164impl PreferredRuntime {
165    /// Obtain a [`PreferredRuntime`] from the currently running asynchronous runtime.
166    /// Generally, this is what you want.
167    ///
168    /// This tries to get a handle to a currently running asynchronous runtime, and
169    /// wraps it; the returned [`PreferredRuntime`] isn't the same thing as the
170    /// asynchronous runtime object itself (e.g. `tokio::runtime::Runtime`).
171    ///
172    /// # Panics
173    ///
174    /// When `tor-rtcompat` is compiled with the `tokio` feature enabled
175    /// (regardless of whether the `async-std` feature is also enabled),
176    /// panics if called outside of Tokio runtime context.
177    /// See `tokio::runtime::Handle::current`.
178    ///
179    /// # Usage notes
180    ///
181    /// Once you have a runtime returned by this function, you should
182    /// just create more handles to it via [`Clone`].
183    ///
184    /// # Limitations
185    ///
186    /// If the `tor-rtcompat` crate was compiled with `tokio` support,
187    /// this function will never return a runtime based on `async_std`.
188    ///
189    /// The process [**may not fork**](crate#do-not-fork)
190    /// (except, very carefully, before exec)
191    /// after creating this or any other `Runtime`.
192    //
193    // ## Note to Arti developers
194    //
195    // We should never call this from inside other Arti crates, or from
196    // library crates that want to support multiple runtimes!  This
197    // function is for Arti _users_ who want to wrap some existing Tokio
198    // or Async_std runtime as a [`Runtime`].  It is not for library
199    // crates that want to work with multiple runtimes.
200    pub fn current() -> io::Result<Self> {
201        let rt = preferred_backend_mod::PreferredRuntime::current()?;
202
203        Ok(Self { inner: rt })
204    }
205
206    /// Create and return a new instance of the default [`Runtime`].
207    ///
208    /// Generally you should call this function at most once, and then use
209    /// [`Clone::clone()`] to create additional references to that runtime.
210    ///
211    /// Tokio users may want to avoid this function and instead obtain a runtime using
212    /// [`PreferredRuntime::current`]: this function always _builds_ a runtime,
213    /// and if you already have a runtime, that isn't what you want with Tokio.
214    ///
215    /// If you need more fine-grained control over a runtime, you can create it
216    /// using an appropriate builder type or function.
217    ///
218    /// The process [**may not fork**](crate#do-not-fork)
219    /// (except, very carefully, before exec)
220    /// after creating this or any other `Runtime`.
221    //
222    // ## Note to Arti developers
223    //
224    // We should never call this from inside other Arti crates, or from
225    // library crates that want to support multiple runtimes!  This
226    // function is for Arti _users_ who want to wrap some existing Tokio
227    // or Async_std runtime as a [`Runtime`].  It is not for library
228    // crates that want to work with multiple runtimes.
229    pub fn create() -> io::Result<Self> {
230        let rt = preferred_backend_mod::PreferredRuntime::create()?;
231
232        Ok(Self { inner: rt })
233    }
234
235    /// Helper to run a single test function in a freshly created runtime.
236    ///
237    /// # Panics
238    ///
239    /// Panics if we can't create this runtime.
240    ///
241    /// # Warning
242    ///
243    /// This API is **NOT** for consumption outside Arti. Semver guarantees are not provided.
244    #[doc(hidden)]
245    pub fn run_test<P, F, O>(func: P) -> O
246    where
247        P: FnOnce(Self) -> F,
248        F: futures::Future<Output = O>,
249    {
250        let runtime = Self::create().expect("Failed to create runtime");
251        runtime.clone().block_on(func(runtime))
252    }
253}
254
255/// Helpers for test_with_all_runtimes
256///
257/// # Warning
258///
259/// This API is **NOT** for consumption outside Arti. Semver guarantees are not provided.
260#[doc(hidden)]
261pub mod testing__ {
262    /// A trait for an object that might represent a test failure, or which
263    /// might just be `()`.
264    pub trait TestOutcome {
265        /// Abort if the test has failed.
266        fn check_ok(&self);
267    }
268    impl TestOutcome for () {
269        fn check_ok(&self) {}
270    }
271    impl<E: std::fmt::Debug> TestOutcome for Result<(), E> {
272        fn check_ok(&self) {
273            self.as_ref().expect("Test failure");
274        }
275    }
276}
277
278/// Helper: define a macro that expands a token tree iff a pair of features are
279/// both present.
280macro_rules! declare_conditional_macro {
281    ( $(#[$meta:meta])* macro $name:ident = ($f1:expr, $f2:expr) ) => {
282        $( #[$meta] )*
283        #[cfg(all(feature=$f1, feature=$f2))]
284        #[macro_export]
285        macro_rules! $name {
286            ($tt:tt) => {
287                $tt
288            };
289        }
290
291        $( #[$meta] )*
292        #[cfg(not(all(feature=$f1, feature=$f2)))]
293        #[macro_export]
294        macro_rules! $name {
295            ($tt:tt) => {};
296        }
297
298        // Needed so that we can access this macro at this path, both within the
299        // crate and without.
300        pub use $name;
301    };
302}
303
304/// Defines macros that will expand when certain runtimes are available.
305#[doc(hidden)]
306pub mod cond {
307    declare_conditional_macro! {
308        /// Expand a token tree if the TokioNativeTlsRuntime is available.
309        #[doc(hidden)]
310        macro if_tokio_native_tls_present = ("tokio", "native-tls")
311    }
312    declare_conditional_macro! {
313        /// Expand a token tree if the TokioRustlsRuntime is available.
314        #[doc(hidden)]
315        macro if_tokio_rustls_present = ("tokio", "rustls")
316    }
317    declare_conditional_macro! {
318        /// Expand a token tree if the TokioNativeTlsRuntime is available.
319        #[doc(hidden)]
320        macro if_async_std_native_tls_present = ("async-std", "native-tls")
321    }
322    declare_conditional_macro! {
323        /// Expand a token tree if the TokioNativeTlsRuntime is available.
324        #[doc(hidden)]
325        macro if_async_std_rustls_present = ("async-std", "rustls")
326    }
327    declare_conditional_macro! {
328        /// Expand a token tree if the SmolNativeTlsRuntime is available.
329        #[doc(hidden)]
330        macro if_smol_native_tls_present = ("smol", "native-tls")
331    }
332    declare_conditional_macro! {
333        /// Expand a token tree if the SmolRustlsRuntime is available.
334        #[doc(hidden)]
335        macro if_smol_rustls_present = ("smol", "rustls")
336    }
337}
338
339/// Run a test closure, passing as argument every supported runtime.
340///
341/// Usually, prefer `tor_rtmock::MockRuntime::test_with_various` to this.
342/// Use this macro only when you need to interact with things
343/// that `MockRuntime` can't handle,
344///
345/// If everything in your test case is supported by `MockRuntime`,
346/// you should use that instead:
347/// that will give superior test coverage *and* a (more) deterministic test.
348///
349/// (This is a macro so that it can repeat the closure as multiple separate
350/// expressions, so it can take on two different types, if needed.)
351//
352// NOTE(eta): changing this #[cfg] can affect tests inside this crate that use
353//            this macro, like in scheduler.rs
354#[macro_export]
355#[cfg(all(
356    any(feature = "native-tls", feature = "rustls"),
357    any(feature = "tokio", feature = "async-std", feature = "smol"),
358))]
359macro_rules! test_with_all_runtimes {
360    ( $fn:expr ) => {{
361        use $crate::cond::*;
362        use $crate::testing__::TestOutcome;
363        // We have to do this outcome-checking business rather than just using
364        // the ? operator or calling expect() because some of the closures that
365        // we use this macro with return (), and some return Result.
366
367        if_tokio_native_tls_present! {{
368           $crate::tokio::TokioNativeTlsRuntime::run_test($fn).check_ok();
369        }}
370        if_tokio_rustls_present! {{
371            $crate::tokio::TokioRustlsRuntime::run_test($fn).check_ok();
372        }}
373        if_async_std_native_tls_present! {{
374            $crate::async_std::AsyncStdNativeTlsRuntime::run_test($fn).check_ok();
375        }}
376        if_async_std_rustls_present! {{
377            $crate::async_std::AsyncStdRustlsRuntime::run_test($fn).check_ok();
378        }}
379        if_smol_native_tls_present! {{
380            $crate::smol::SmolNativeTlsRuntime::run_test($fn).check_ok();
381        }}
382        if_smol_rustls_present! {{
383            $crate::smol::SmolRustlsRuntime::run_test($fn).check_ok();
384        }}
385    }};
386}
387
388/// Run a test closure, passing as argument one supported runtime.
389///
390/// Usually, prefer `tor_rtmock::MockRuntime::test_with_various` to this.
391/// Use this macro only when you need to interact with things
392/// that `MockRuntime` can't handle.
393///
394/// If everything in your test case is supported by `MockRuntime`,
395/// you should use that instead:
396/// that will give superior test coverage *and* a (more) deterministic test.
397///
398/// (Always prefers tokio if present.)
399#[macro_export]
400#[cfg(all(
401    any(feature = "native-tls", feature = "rustls"),
402    any(feature = "tokio", feature = "async-std"),
403))]
404macro_rules! test_with_one_runtime {
405    ( $fn:expr ) => {{ $crate::PreferredRuntime::run_test($fn) }};
406}
407
408#[cfg(all(
409    test,
410    any(feature = "native-tls", feature = "rustls"),
411    any(feature = "async-std", feature = "tokio", feature = "smol"),
412    not(miri), // Many of these tests use real sockets or SystemTime.
413))]
414mod test {
415    // @@ begin test lint list maintained by maint/add_warning @@
416    #![allow(clippy::bool_assert_comparison)]
417    #![allow(clippy::clone_on_copy)]
418    #![allow(clippy::dbg_macro)]
419    #![allow(clippy::mixed_attributes_style)]
420    #![allow(clippy::print_stderr)]
421    #![allow(clippy::print_stdout)]
422    #![allow(clippy::single_char_pattern)]
423    #![allow(clippy::unwrap_used)]
424    #![allow(clippy::unchecked_time_subtraction)]
425    #![allow(clippy::useless_vec)]
426    #![allow(clippy::needless_pass_by_value)]
427    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
428    #![allow(clippy::unnecessary_wraps)]
429    use crate::SleepProviderExt;
430    use crate::ToplevelRuntime;
431
432    use crate::traits::*;
433
434    use futures::io::{AsyncReadExt, AsyncWriteExt};
435    use futures::stream::StreamExt;
436    use native_tls_crate as native_tls;
437    use std::io::Result as IoResult;
438    use std::net::SocketAddr;
439    use std::net::{Ipv4Addr, SocketAddrV4};
440    use std::time::{Duration, Instant};
441
442    // Test "sleep" with a tiny delay, and make sure that at least that
443    // much delay happens.
444    fn small_delay<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
445        let rt = runtime.clone();
446        runtime.block_on(async {
447            let i1 = Instant::now();
448            let one_msec = Duration::from_millis(1);
449            rt.sleep(one_msec).await;
450            let i2 = Instant::now();
451            assert!(i2 >= i1 + one_msec);
452        });
453        Ok(())
454    }
455
456    // Try a timeout operation that will succeed.
457    fn small_timeout_ok<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
458        let rt = runtime.clone();
459        runtime.block_on(async {
460            let one_day = Duration::from_secs(86400);
461            let outcome = rt.timeout(one_day, async { 413_u32 }).await;
462            assert_eq!(outcome, Ok(413));
463        });
464        Ok(())
465    }
466
467    // Try a timeout operation that will time out.
468    fn small_timeout_expire<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
469        use futures::future::pending;
470
471        let rt = runtime.clone();
472        runtime.block_on(async {
473            let one_micros = Duration::from_micros(1);
474            let outcome = rt.timeout(one_micros, pending::<()>()).await;
475            assert_eq!(outcome, Err(crate::TimeoutError));
476            assert_eq!(
477                outcome.err().unwrap().to_string(),
478                "Timeout expired".to_string()
479            );
480        });
481        Ok(())
482    }
483    // Try a little wallclock delay.
484    //
485    // NOTE: This test will fail if the clock jumps a lot while it's
486    // running.  We should use simulated time instead.
487    fn tiny_wallclock<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
488        let rt = runtime.clone();
489        runtime.block_on(async {
490            let i1 = Instant::now();
491            let now = runtime.wallclock();
492            let one_millis = Duration::from_millis(1);
493            let one_millis_later = now + one_millis;
494
495            rt.sleep_until_wallclock(one_millis_later).await;
496
497            let i2 = Instant::now();
498            let newtime = runtime.wallclock();
499            assert!(newtime >= one_millis_later);
500            assert!(i2 - i1 >= one_millis);
501        });
502        Ok(())
503    }
504
505    // Try connecting to ourself and sending a little data.
506    //
507    // NOTE: requires Ipv4 localhost.
508    fn self_connect_tcp<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
509        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
510        let rt1 = runtime.clone();
511
512        let listener = runtime.block_on(rt1.listen(&(SocketAddr::from(localhost))))?;
513        let addr = listener.local_addr()?;
514
515        runtime.block_on(async {
516            let task1 = async {
517                let mut buf = vec![0_u8; 11];
518                let (mut con, _addr) = listener.incoming().next().await.expect("closed?")?;
519                con.read_exact(&mut buf[..]).await?;
520                IoResult::Ok(buf)
521            };
522            let task2 = async {
523                let mut con = rt1.connect(&addr).await?;
524                con.write_all(b"Hello world").await?;
525                con.flush().await?;
526                IoResult::Ok(())
527            };
528
529            let (data, send_r) = futures::join!(task1, task2);
530            send_r?;
531
532            assert_eq!(&data?[..], b"Hello world");
533
534            Ok(())
535        })
536    }
537
538    // Try connecting to ourself and sending a little data.
539    //
540    // NOTE: requires Ipv4 localhost.
541    fn self_connect_udp<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
542        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
543        let rt1 = runtime.clone();
544
545        let socket1 = runtime.block_on(rt1.bind(&(localhost.into())))?;
546        let addr1 = socket1.local_addr()?;
547
548        let socket2 = runtime.block_on(rt1.bind(&(localhost.into())))?;
549        let addr2 = socket2.local_addr()?;
550
551        runtime.block_on(async {
552            let task1 = async {
553                let mut buf = [0_u8; 16];
554                let (len, addr) = socket1.recv(&mut buf[..]).await?;
555                IoResult::Ok((buf[..len].to_vec(), addr))
556            };
557            let task2 = async {
558                socket2.send(b"Hello world", &addr1).await?;
559                IoResult::Ok(())
560            };
561
562            let (recv_r, send_r) = futures::join!(task1, task2);
563            send_r?;
564            let (buff, addr) = recv_r?;
565            assert_eq!(addr2, addr);
566            assert_eq!(&buff, b"Hello world");
567
568            Ok(())
569        })
570    }
571
572    // Try out our incoming connection stream code.
573    //
574    // We launch a few connections and make sure that we can read data on
575    // them.
576    fn listener_stream<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
577        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
578        let rt1 = runtime.clone();
579
580        let listener = runtime
581            .block_on(rt1.listen(&SocketAddr::from(localhost)))
582            .unwrap();
583        let addr = listener.local_addr().unwrap();
584        let mut stream = listener.incoming();
585
586        runtime.block_on(async {
587            let task1 = async {
588                let mut n = 0_u32;
589                loop {
590                    let (mut con, _addr) = stream.next().await.unwrap()?;
591                    let mut buf = [0_u8; 11];
592                    con.read_exact(&mut buf[..]).await?;
593                    n += 1;
594                    if &buf[..] == b"world done!" {
595                        break IoResult::Ok(n);
596                    }
597                }
598            };
599            let task2 = async {
600                for _ in 0_u8..5 {
601                    let mut con = rt1.connect(&addr).await?;
602                    con.write_all(b"Hello world").await?;
603                    con.flush().await?;
604                }
605                let mut con = rt1.connect(&addr).await?;
606                con.write_all(b"world done!").await?;
607                con.flush().await?;
608                con.close().await?;
609                IoResult::Ok(())
610            };
611
612            let (n, send_r) = futures::join!(task1, task2);
613            send_r?;
614
615            assert_eq!(n?, 6);
616
617            Ok(())
618        })
619    }
620
621    // Try listening on an address and connecting there, except using TLS.
622    //
623    // Note that since we didn't have TLS server support when this test was first written,
624    // we're going to use a thread.
625    fn simple_tls<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
626        /*
627         A simple expired self-signed rsa-2048 certificate.
628
629         Generated by running the make-cert.c program in tor-rtcompat/test-data-helper,
630         and then making a PFX file using
631
632         openssl pkcs12 -export -certpbe PBE-SHA1-3DES -out test.pfx -inkey test.key -in test.crt
633
634         The password is "abc".
635        */
636        static PFX_ID: &[u8] = include_bytes!("test.pfx");
637        // Note that we need to set a password on the pkcs12 file, since apparently
638        // OSX doesn't support pkcs12 with empty passwords. (That was arti#111).
639        static PFX_PASSWORD: &str = "abc";
640
641        let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
642        let listener = std::net::TcpListener::bind(localhost)?;
643        let addr = listener.local_addr()?;
644
645        let identity = native_tls::Identity::from_pkcs12(PFX_ID, PFX_PASSWORD).unwrap();
646
647        // See note on function for why we're using a thread here.
648        let th = std::thread::spawn(move || {
649            // Accept a single TLS connection and run an echo server
650            use std::io::{Read, Write};
651            let acceptor = native_tls::TlsAcceptor::new(identity).unwrap();
652            let (con, _addr) = listener.accept()?;
653            let mut con = acceptor.accept(con).unwrap();
654            let mut buf = [0_u8; 16];
655            loop {
656                let n = con.read(&mut buf)?;
657                if n == 0 {
658                    break;
659                }
660                con.write_all(&buf[..n])?;
661            }
662            IoResult::Ok(())
663        });
664
665        let connector = runtime.tls_connector();
666
667        runtime.block_on(async {
668            let text = b"I Suddenly Dont Understand Anything";
669            let mut buf = vec![0_u8; text.len()];
670            let conn = runtime.connect(&addr).await?;
671            let mut conn = connector.negotiate_unvalidated(conn, "Kan.Aya").await?;
672            assert!(conn.peer_certificate()?.is_some());
673            conn.write_all(text).await?;
674            conn.flush().await?;
675            conn.read_exact(&mut buf[..]).await?;
676            assert_eq!(&buf[..], text);
677            conn.close().await?;
678            IoResult::Ok(())
679        })?;
680
681        th.join().unwrap()?;
682        IoResult::Ok(())
683    }
684
685    fn simple_tls_server<R: ToplevelRuntime>(runtime: &R) -> IoResult<()> {
686        let mut rng = tor_basic_utils::test_rng::testing_rng();
687        let tls_cert = tor_cert_x509::TlsKeyAndCert::create(
688            &mut rng,
689            std::time::SystemTime::now(),
690            "prospit.example.org",
691            "derse.example.org",
692        )
693        .unwrap();
694        let cert = tls_cert.certificates_der()[0].to_vec();
695        let settings = TlsAcceptorSettings::new(tls_cert).unwrap();
696
697        let Ok(tls_acceptor) = runtime.tls_acceptor(settings) else {
698            println!("Skipping tls-server test for runtime {:?}", &runtime);
699            return IoResult::Ok(());
700        };
701        println!("Running tls-server test for runtime {:?}", &runtime);
702
703        let tls_connector = runtime.tls_connector();
704
705        let localhost: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0).into();
706        let rt1 = runtime.clone();
707
708        let msg = b"Derse Reviles Him And Outlaws Frogs Wherever They Can";
709        runtime.block_on(async move {
710            let listener = runtime.listen(&localhost).await.unwrap();
711            let address = listener.local_addr().unwrap();
712
713            let h1 = runtime
714                .spawn_with_handle(async move {
715                    let conn = listener.incoming().next().await.unwrap().unwrap().0;
716                    let mut conn = tls_acceptor.negotiate_unvalidated(conn, "").await.unwrap();
717
718                    let mut buf = vec![];
719                    conn.read_to_end(&mut buf).await.unwrap();
720                    (buf, conn.own_certificate())
721                })
722                .unwrap();
723
724            let h2 = runtime
725                .spawn_with_handle(async move {
726                    let conn = rt1.connect(&address).await.unwrap();
727                    let mut conn = tls_connector
728                        .negotiate_unvalidated(conn, "prospit.example.org")
729                        .await
730                        .unwrap();
731                    conn.write_all(msg).await.unwrap();
732                    conn.close().await.unwrap();
733                    conn.peer_certificate()
734                })
735                .unwrap();
736
737            let (received, server_own_cert) = h1.await;
738            let client_peer_cert = h2.await;
739            assert_eq!(received, msg);
740            assert_eq!(&server_own_cert.unwrap().unwrap(), &cert);
741            assert_eq!(&client_peer_cert.unwrap().unwrap(), &cert);
742        });
743        IoResult::Ok(())
744    }
745
746    macro_rules! tests_with_runtime {
747        { $runtime:expr  => $($id:ident),* $(,)? } => {
748            $(
749                #[test]
750                fn $id() -> std::io::Result<()> {
751                    super::$id($runtime)
752                }
753            )*
754        }
755    }
756
757    macro_rules! runtime_tests {
758        { $($id:ident),* $(,)? } =>
759        {
760           #[cfg(feature="tokio")]
761            mod tokio_runtime_tests {
762                tests_with_runtime! { &crate::tokio::PreferredRuntime::create()? => $($id),* }
763            }
764            #[cfg(feature="async-std")]
765            mod async_std_runtime_tests {
766                tests_with_runtime! { &crate::async_std::PreferredRuntime::create()? => $($id),* }
767            }
768            #[cfg(feature="smol")]
769            mod smol_runtime_tests {
770                tests_with_runtime! { &crate::smol::PreferredRuntime::create()? => $($id),* }
771            }
772            mod default_runtime_tests {
773                tests_with_runtime! { &crate::PreferredRuntime::create()? => $($id),* }
774            }
775        }
776    }
777
778    macro_rules! tls_runtime_tests {
779        { $($id:ident),* $(,)? } =>
780        {
781            #[cfg(all(feature="tokio", feature = "native-tls"))]
782            mod tokio_native_tls_tests {
783                tests_with_runtime! { &crate::tokio::TokioNativeTlsRuntime::create()? => $($id),* }
784            }
785            #[cfg(all(feature="async-std", feature = "native-tls"))]
786            mod async_std_native_tls_tests {
787                tests_with_runtime! { &crate::async_std::AsyncStdNativeTlsRuntime::create()? => $($id),* }
788            }
789            #[cfg(all(feature="smol", feature = "native-tls"))]
790            mod smol_native_tls_tests {
791                tests_with_runtime! { &crate::smol::SmolNativeTlsRuntime::create()? => $($id),* }
792            }
793            #[cfg(all(feature="tokio", feature="rustls"))]
794            mod tokio_rustls_tests {
795                tests_with_runtime! {  &crate::tokio::TokioRustlsRuntime::create()? => $($id),* }
796            }
797            #[cfg(all(feature="async-std", feature="rustls"))]
798            mod async_std_rustls_tests {
799                tests_with_runtime! {  &crate::async_std::AsyncStdRustlsRuntime::create()? => $($id),* }
800            }
801            #[cfg(all(feature="smol", feature="rustls"))]
802            mod smol_rustls_tests {
803                tests_with_runtime! {  &crate::smol::SmolRustlsRuntime::create()? => $($id),* }
804            }
805            mod default_runtime_tls_tests {
806                tests_with_runtime! { &crate::PreferredRuntime::create()? => $($id),* }
807            }
808        }
809    }
810
811    runtime_tests! {
812        small_delay,
813        small_timeout_ok,
814        small_timeout_expire,
815        tiny_wallclock,
816        self_connect_tcp,
817        self_connect_udp,
818        listener_stream,
819    }
820
821    tls_runtime_tests! {
822        simple_tls,
823        simple_tls_server,
824    }
825}