Skip to main content

tokio_dual_stack/
lib.rs

1//! [![git]](https://git.philomathiclife.com/tokio_dual_stack/log.html) [![crates-io]](https://crates.io/crates/tokio_dual_stack) [![docs-rs]](crate)
2//!
3//! [git]: https://git.philomathiclife.com/git_badge.svg
4//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
5//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
6//!
7//! `tokio_dual_stack` is a library that adds a "dual-stack" [`TcpListener`].
8//!
9//! ## Why is this useful?
10//!
11//! Only certain platforms offer the ability for one socket to handle both IPv6 and IPv4 requests
12//! (e.g., OpenBSD does not). For the platforms that do, it is often dependent on runtime configuration
//! (e.g., [`IPV6_V6ONLY`](https://www.man7.org/linux/man-pages/man7/ipv6.7.html)). Additionally, those platforms
14//! that support it often require the "wildcard" IPv6 address to be used (i.e., `::`) which has the unfortunate
15//! consequence of preventing other services from using the same protocol port.
16//!
17//! There are a few ways to work around this issue. One is to deploy the same service twice: one that uses
18//! an IPv6 socket and the other that uses an IPv4 socket. This can complicate deployments (e.g., the application
19//! may not have been written with the expectation that multiple deployments could be running at the same time) in
20//! addition to using more resources. Another is for the application to manually handle each socket (e.g.,
21//! [`select`](https://docs.rs/tokio/latest/tokio/macro.select.html)/[`join`](https://docs.rs/tokio/latest/tokio/macro.join.html)
22//! each [`TcpListener::accept`]).
23//!
//! [`DualStackTcpListener`] chooses an implementation similar to what the equivalent `select` would do while
//! also preventing one socket from "starving" the other: each accept alternates which socket is given the first
//! opportunity to `TcpListener::accept` a connection. This has the nice benefit of having a similar API to what a
//! single `TcpListener` would have as well as having similar performance to a socket that does handle both IPv6
//! and IPv4 requests.
29#![expect(
30    clippy::doc_paragraphs_missing_punctuation,
31    reason = "false positive for crate documentation having image links"
32)]
33use core::{
34    net::{SocketAddr, SocketAddrV4, SocketAddrV6},
35    pin::Pin,
36    sync::atomic::{AtomicBool, Ordering},
37    task::{Context, Poll},
38};
39use pin_project_lite::pin_project;
40use std::io::{Error, ErrorKind, Result};
41pub use tokio;
42use tokio::net::{self, TcpListener, TcpSocket, TcpStream, ToSocketAddrs};
use private::Sealed;
/// Module whose only job is to keep [`Sealed`] unnameable outside this crate.
mod private {
    /// Supertrait of [`super::Tcp`].
    ///
    /// Downstream crates cannot name this trait, so they cannot implement it, and therefore
    /// cannot implement `Tcp` either.
    #[expect(unnameable_types, reason = "want Tcp to be 'sealed'")]
    pub trait Sealed {}
}
50/// TCP "listener".
51///
52/// This `trait` is sealed and cannot be implemented for types outside of `tokio_dual_stack`.
53///
54/// This exists primarily as a way to define type constructors or polymorphic functions
55/// that can user either a [`TcpListener`] or [`DualStackTcpListener`].
56///
57/// # Examples
58///
59/// ```no_run
60/// # use core::convert::Infallible;
61/// # use tokio_dual_stack::Tcp;
62/// async fn main_loop<T: Tcp>(listener: T) -> Infallible {
63///     loop {
64///         match listener.accept().await {
65///             Ok((_, socket)) => println!("Client socket: {socket}"),
66///             Err(e) => println!("TCP connection failure: {e}"),
67///         }
68///     }
69/// }
70/// ```
71pub trait Tcp: Sealed + Sized {
72    /// Creates a new TCP listener, which will be bound to the specified address(es).
73    ///
74    /// The returned listener is ready for accepting connections.
75    ///
76    /// Binding with a port number of 0 will request that the OS assigns a port to this listener.
77    /// The port allocated can be queried via the `local_addr` method.
78    ///
79    /// The address type can be any implementor of the [`ToSocketAddrs`] trait. If `addr` yields
80    /// multiple addresses, bind will be attempted with each of the addresses until one succeeds
81    /// and returns the listener. If none of the addresses succeed in creating a listener, the
82    /// error returned from the last attempt (the last address) is returned.
83    ///
84    /// This function sets the `SO_REUSEADDR` option on the socket.
85    ///
86    /// # Examples
87    ///
88    /// ```no_run
89    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
90    /// # use std::io::Result;
91    /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
92    /// #[tokio::main(flavor = "current_thread")]
93    /// async fn main() -> Result<()> {
94    ///     let listener = DualStackTcpListener::bind(
95    ///         [
96    ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
97    ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
98    ///         ]
99    ///         .as_slice(),
100    ///     )
101    ///     .await?;
102    ///     Ok(())
103    /// }
104    /// ```
105    fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>>;
106    /// Accepts a new incoming connection from this listener.
107    ///
108    /// This function will yield once a new TCP connection is established. When established,
109    /// the corresponding `TcpStream` and the remote peer’s address will be returned.
110    ///
111    /// # Cancel safety
112    ///
113    /// This method is cancel safe. If the method is used as the event in a
114    /// [`tokio::select!`](https://docs.rs/tokio/latest/tokio/macro.select.html)
115    /// statement and some other branch completes first, then it is guaranteed that no new
116    /// connections were accepted by this method.
117    ///
118    /// # Examples
119    ///
120    /// ```no_run
121    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
122    /// # use std::io::Result;
123    /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
124    /// #[tokio::main(flavor = "current_thread")]
125    /// async fn main() -> Result<()> {
126    ///     match DualStackTcpListener::bind(
127    ///         [
128    ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
129    ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
130    ///         ]
131    ///         .as_slice(),
132    ///     )
133    ///     .await?.accept().await {
134    ///         Ok((_, addr)) => println!("new client: {addr}"),
135    ///         Err(e) => println!("couldn't get client: {e}"),
136    ///     }
137    ///     Ok(())
138    /// }
139    /// ```
140    fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync;
141    /// Polls to accept a new incoming connection to this listener.
142    ///
143    /// If there is no connection to accept, `Poll::Pending` is returned and the current task will be notified by
144    /// a waker. Note that on multiple calls to `poll_accept`, only the `Waker` from the `Context` passed to the
145    /// most recent call is scheduled to receive a wakeup.
146    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
147}
148impl Sealed for TcpListener {}
149impl Tcp for TcpListener {
150    #[inline]
151    fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>> {
152        Self::bind(addr)
153    }
154    #[inline]
155    fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
156        self.accept()
157    }
158    #[inline]
159    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
160        self.poll_accept(cx)
161    }
162}
163/// "Dual-stack" TCP listener.
164///
165/// IPv6 and IPv4 TCP listener.
166#[derive(Debug)]
167pub struct DualStackTcpListener {
168    /// IPv6 TCP listener.
169    ip6: TcpListener,
170    /// IPv4 TCP listener.
171    ip4: TcpListener,
172    /// `true` iff [`Self::ip6::accept`] should be `poll`ed first; otherwise [`Self::ip4::accept`] is `poll`ed
173    /// first.
174    ///
175    /// This exists to prevent one IP version from "starving" another. Each time [`Self::accept`] or
176    /// [`Self::poll_accept`] is called, it's overwritten with the opposite `bool`.
177    ///
178    /// Note we could make this a `core::cell::Cell`; but for maximal flexibility and consistency with `TcpListener`,
179    /// we use an `AtomicBool`. This among other things means `DualStackTcpListener` will implement `Sync`.
180    ip6_first: AtomicBool,
181}
182impl DualStackTcpListener {
183    /// Creates `Self` using the [`TcpListener`]s returned from [`TcpSocket::listen`].
184    ///
185    /// [`Self::bind`] is useful when the behavior of [`TcpListener::bind`] is sufficient; however if the underlying
186    /// `TcpSocket`s need to be configured differently, then one must call this function instead.
187    ///
188    /// # Errors
189    ///
190    /// Errors iff [`TcpSocket::local_addr`] does for either socket, the underlying sockets use the same IP version,
191    /// or [`TcpSocket::listen`] errors for either socket.
192    ///
193    /// Note on Windows-based platforms `TcpSocket::local_addr` will error if [`TcpSocket::bind`] was not called.
194    ///
195    /// # Examples
196    ///
197    /// ```no_run
198    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
199    /// # use std::io::Result;
200    /// # use tokio_dual_stack::DualStackTcpListener;
201    /// # use tokio::net::TcpSocket;
202    /// #[tokio::main(flavor = "current_thread")]
203    /// async fn main() -> Result<()> {
204    ///     let ip6 = TcpSocket::new_v6()?;
205    ///     ip6.bind(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)))?;
206    ///     let ip4 = TcpSocket::new_v4()?;
207    ///     ip4.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)))?;
208    ///     let listener = DualStackTcpListener::from_sockets((ip6, 1024), (ip4, 1024))?;
209    ///     Ok(())
210    /// }
211    /// ```
212    #[inline]
213    pub fn from_sockets(
214        (socket_1, backlog_1): (TcpSocket, u32),
215        (socket_2, backlog_2): (TcpSocket, u32),
216    ) -> Result<Self> {
217        socket_1.local_addr().and_then(|sock| {
218            socket_2.local_addr().and_then(|sock_2| {
219                if sock.is_ipv6() {
220                    if sock_2.is_ipv4() {
221                        socket_1.listen(backlog_1).and_then(|ip6| {
222                            socket_2.listen(backlog_2).map(|ip4| Self {
223                                ip6,
224                                ip4,
225                                ip6_first: AtomicBool::new(true),
226                            })
227                        })
228                    } else {
229                        Err(Error::new(
230                            ErrorKind::InvalidData,
231                            "TcpSockets are the same IP version",
232                        ))
233                    }
234                } else if sock_2.is_ipv6() {
235                    socket_1.listen(backlog_1).and_then(|ip4| {
236                        socket_2.listen(backlog_2).map(|ip6| Self {
237                            ip6,
238                            ip4,
239                            ip6_first: AtomicBool::new(true),
240                        })
241                    })
242                } else {
243                    Err(Error::new(
244                        ErrorKind::InvalidData,
245                        "TcpSockets are the same IP version",
246                    ))
247                }
248            })
249        })
250    }
251    /// Returns the local address of each socket that the listeners are bound to.
252    ///
253    /// This can be useful, for example, when binding to port 0 to figure out which port was actually bound.
254    ///
255    /// # Errors
256    ///
257    /// Errors iff [`TcpListener::local_addr`] does for either listener.
258    ///
259    /// # Examples
260    ///
261    /// ```no_run
262    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
263    /// # use std::io::Result;
264    /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
265    /// #[tokio::main(flavor = "current_thread")]
266    /// async fn main() -> Result<()> {
267    ///     let ip6 = SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0);
268    ///     let ip4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080);
269    ///     assert_eq!(
270    ///         DualStackTcpListener::bind([SocketAddr::V6(ip6), SocketAddr::V4(ip4)].as_slice())
271    ///             .await?
272    ///             .local_addr()?,
273    ///         (ip6, ip4)
274    ///     );
275    ///     Ok(())
276    /// }
277    /// ```
278    #[expect(clippy::unreachable, reason = "we want to crash when there is a bug")]
279    #[inline]
280    pub fn local_addr(&self) -> Result<(SocketAddrV6, SocketAddrV4)> {
281        self.ip6.local_addr().and_then(|ip6| {
282            self.ip4.local_addr().map(|ip4| {
283                (
284                    if let SocketAddr::V6(sock6) = ip6 {
285                        sock6
286                    } else {
287                        unreachable!("there is a bug in DualStackTcpListener::bind")
288                    },
289                    if let SocketAddr::V4(sock4) = ip4 {
290                        sock4
291                    } else {
292                        unreachable!("there is a bug in DualStackTcpListener::bind")
293                    },
294                )
295            })
296        })
297    }
298    /// Sets the value for the `IP_TTL` option on both sockets.
299    ///
300    /// This value sets the time-to-live field that is used in every packet sent from each socket.
301    /// `ttl_ip6` is the `IP_TTL` value for the IPv6 socket and `ttl_ip4` is the `IP_TTL` value for the
302    /// IPv4 socket.
303    ///
304    /// # Errors
305    ///
306    /// Errors iff [`TcpListener::set_ttl`] does for either listener.
307    ///
308    /// # Examples
309    ///
310    /// ```no_run
311    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
312    /// # use std::io::Result;
313    /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
314    /// #[tokio::main(flavor = "current_thread")]
315    /// async fn main() -> Result<()> {
316    ///     DualStackTcpListener::bind(
317    ///         [
318    ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
319    ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
320    ///         ]
321    ///         .as_slice(),
322    ///     )
323    ///     .await?.set_ttl(100, 100).expect("could not set TTL");
324    ///     Ok(())
325    /// }
326    /// ```
327    #[inline]
328    pub fn set_ttl(&self, ttl_ip6: u32, ttl_ip4: u32) -> Result<()> {
329        self.ip6
330            .set_ttl(ttl_ip6)
331            .and_then(|()| self.ip4.set_ttl(ttl_ip4))
332    }
333    /// Gets the values of the `IP_TTL` option for both sockets.
334    ///
335    /// The first `u32` represents the `IP_TTL` value for the IPv6 socket and the second `u32` is the
336    /// `IP_TTL` value for the IPv4 socket. For more information about this option, see [`Self::set_ttl`].
337    ///
338    /// # Errors
339    ///
340    /// Errors iff [`TcpListener::ttl`] does for either listener.
341    ///
342    /// # Examples
343    ///
344    /// ```no_run
345    /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
346    /// # use std::io::Result;
347    /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
348    /// #[tokio::main(flavor = "current_thread")]
349    /// async fn main() -> Result<()> {
350    ///     let listener = DualStackTcpListener::bind(
351    ///         [
352    ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
353    ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
354    ///         ]
355    ///         .as_slice(),
356    ///     )
357    ///     .await?;
358    ///     listener.set_ttl(100, 100).expect("could not set TTL");
359    ///     assert_eq!(listener.ttl()?, (100, 100));
360    ///     Ok(())
361    /// }
362    /// ```
363    #[inline]
364    pub fn ttl(&self) -> Result<(u32, u32)> {
365        self.ip6
366            .ttl()
367            .and_then(|ip6| self.ip4.ttl().map(|ip4| (ip6, ip4)))
368    }
369}
370pin_project! {
371    /// `Future` returned by [`DualStackTcpListener::accept]`.
372    struct AcceptFut<
373        F: Future<Output = Result<(TcpStream, SocketAddr)>>,
374        F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
375    > {
376        // Accept future for one `TcpListener`.
377        #[pin]
378        fut_1: F,
379        // Accept future for the other `TcpListener`.
380        #[pin]
381        fut_2: F2,
382    }
383}
384impl<
385    F: Future<Output = Result<(TcpStream, SocketAddr)>>,
386    F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
387> Future for AcceptFut<F, F2>
388{
389    type Output = Result<(TcpStream, SocketAddr)>;
390    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
391        let this = self.project();
392        // Note we defer errors caused from polling a completed `Future` to the contained `tokio` `Future`s.
393        // The only time this `Future` can be polled after completion without an error (due to `tokio`) is
394        // if `fut_2` completes first, `self` is polled, then `fut_1` completes. We don't actually care
395        // that this happens since the correctness of the code is still fine.
396        // This means any bugs that could occur from polling this `Future` after completion are dependency-based
397        // bugs where the correct solution is to fix the bugs in `tokio`.
398        match this.fut_1.poll(cx) {
399            Poll::Ready(res) => Poll::Ready(res),
400            Poll::Pending => this.fut_2.poll(cx),
401        }
402    }
403}
404impl Sealed for DualStackTcpListener {}
405impl Tcp for DualStackTcpListener {
406    #[inline]
407    async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
408        match net::lookup_host(addr).await {
409            Ok(socks) => {
410                let mut last_err = None;
411                let mut ip6_opt = None;
412                let mut ip4_opt = None;
413                for sock in socks {
414                    match ip6_opt {
415                        None => match ip4_opt {
416                            None => {
417                                let is_ip6 = sock.is_ipv6();
418                                match TcpListener::bind(sock).await {
419                                    Ok(ip) => {
420                                        if is_ip6 {
421                                            ip6_opt = Some(ip);
422                                        } else {
423                                            ip4_opt = Some(ip);
424                                        }
425                                    }
426                                    Err(err) => last_err = Some(err),
427                                }
428                            }
429                            Some(ip4) => {
430                                if sock.is_ipv6() {
431                                    match TcpListener::bind(sock).await {
432                                        Ok(ip6) => {
433                                            return Ok(Self {
434                                                ip6,
435                                                ip4,
436                                                ip6_first: AtomicBool::new(true),
437                                            });
438                                        }
439                                        Err(err) => last_err = Some(err),
440                                    }
441                                }
442                                ip4_opt = Some(ip4);
443                            }
444                        },
445                        Some(ip6) => {
446                            if sock.is_ipv4() {
447                                match TcpListener::bind(sock).await {
448                                    Ok(ip4) => {
449                                        return Ok(Self {
450                                            ip6,
451                                            ip4,
452                                            ip6_first: AtomicBool::new(true),
453                                        });
454                                    }
455                                    Err(err) => last_err = Some(err),
456                                }
457                            }
458                            ip6_opt = Some(ip6);
459                        }
460                    }
461                }
462                Err(last_err.unwrap_or_else(|| {
463                    Error::new(
464                        ErrorKind::InvalidInput,
465                        "could not resolve to an IPv6 and IPv4 address",
466                    )
467                }))
468            }
469            Err(err) => Err(err),
470        }
471    }
472    #[inline]
473    fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
474        // The correctness of code does not depend on `self.ip6_first`; therefore
475        // we elect for the most performant `Ordering`.
476        if self.ip6_first.swap(false, Ordering::Relaxed) {
477            AcceptFut {
478                fut_1: self.ip6.accept(),
479                fut_2: self.ip4.accept(),
480            }
481        } else {
482            // The correctness of code does not depend on `self.ip6_first`; therefore
483            // we elect for the most performant `Ordering`.
484            self.ip6_first.store(true, Ordering::Relaxed);
485            AcceptFut {
486                fut_1: self.ip4.accept(),
487                fut_2: self.ip6.accept(),
488            }
489        }
490    }
491    #[inline]
492    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
493        // The correctness of code does not depend on `self.ip6_first`; therefore
494        // we elect for the most performant `Ordering`.
495        if self.ip6_first.swap(false, Ordering::Relaxed) {
496            self.ip6.poll_accept(cx)
497        } else {
498            // The correctness of code does not depend on `self.ip6_first`; therefore
499            // we elect for the most performant `Ordering`.
500            self.ip6_first.store(true, Ordering::Relaxed);
501            self.ip4.poll_accept(cx)
502        }
503    }
504}