uni_addr/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(clippy::must_use_candidate)]
3
4use std::borrow::Cow;
5use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
6use std::str::FromStr;
7use std::sync::Arc;
8use std::{fmt, io};
9
10#[cfg(unix)]
11pub mod unix;
12
/// URI scheme prefix that marks a Unix domain socket (UDS) address.
///
/// Examples of addresses using this prefix:
///
/// - `unix:///path/to/socket` denotes a pathname socket address.
/// - `unix://@abstract.unix.socket` denotes an abstract socket address.
pub const UNIX_URI_PREFIX: &str = "unix://";
18
19wrapper_lite::wrapper!(
20    #[wrapper_impl(Debug)]
21    #[wrapper_impl(Display)]
22    #[wrapper_impl(AsRef)]
23    #[wrapper_impl(Deref)]
24    #[repr(align(cache))]
25    #[derive(Clone, PartialEq, Eq, Hash)]
26    /// A unified address type that can represent:
27    ///
28    /// - [`std::net::SocketAddr`]
29    /// - [`unix::SocketAddr`] (a wrapper over
30    ///   [`std::os::unix::net::SocketAddr`])
31    /// - A host name with port. See [`ToSocketAddrs`].
32    ///
33    /// # Parsing Behaviour
34    ///
35    /// - Checks if the address started with [`UNIX_URI_PREFIX`]: parse as a UDS
36    ///   address.
37    /// - Checks if the address is started with a alphabetic character (a-z,
38    ///   A-Z): treat as a host name. Notes that we will not validate if the
39    ///   host name is valid.
40    /// - Tries to parse as a network socket address.
41    /// - Otherwise, treats the input as a host name.
42    pub struct UniAddr(UniAddrInner);
43);
44
45impl From<SocketAddr> for UniAddr {
46    fn from(addr: SocketAddr) -> Self {
47        UniAddr::from_inner(UniAddrInner::Inet(addr))
48    }
49}
50
51#[cfg(unix)]
52impl From<std::os::unix::net::SocketAddr> for UniAddr {
53    fn from(addr: std::os::unix::net::SocketAddr) -> Self {
54        UniAddr::from_inner(UniAddrInner::Unix(addr.into()))
55    }
56}
57
#[cfg(all(unix, feature = "feat-tokio"))]
impl From<tokio::net::unix::SocketAddr> for UniAddr {
    /// Wraps a Tokio Unix domain socket address by converting it into this
    /// crate's [`unix::SocketAddr`].
    fn from(addr: tokio::net::unix::SocketAddr) -> Self {
        let unix_addr = unix::SocketAddr::from(addr.into());
        Self::from_inner(UniAddrInner::Unix(unix_addr))
    }
}
64
#[cfg(feature = "feat-socket2")]
impl TryFrom<socket2::SockAddr> for UniAddr {
    type Error = io::Error;

    /// Converts an owned [`socket2::SockAddr`] by delegating to the
    /// by-reference conversion.
    fn try_from(sock_addr: socket2::SockAddr) -> Result<Self, Self::Error> {
        <Self as TryFrom<&socket2::SockAddr>>::try_from(&sock_addr)
    }
}
73
#[cfg(feature = "feat-socket2")]
impl TryFrom<&socket2::SockAddr> for UniAddr {
    type Error = io::Error;

    /// Converts a borrowed [`socket2::SockAddr`] into a [`UniAddr`].
    ///
    /// The address is tried, in order, as:
    ///
    /// 1. an IP socket address (`as_socket`);
    /// 2. (Unix) a Unix socket address returned by `as_unix`;
    /// 3. (Unix) an unnamed Unix socket address;
    /// 4. (Linux, Android, Cygwin) an abstract-namespace Unix socket address.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::Other`] error when none of the above
    /// matches. Errors from building an abstract address are propagated.
    fn try_from(sock_addr: &socket2::SockAddr) -> Result<Self, Self::Error> {
        if let Some(inet) = sock_addr.as_socket() {
            return Ok(inet.into());
        }

        #[cfg(unix)]
        if let Some(std_unix) = sock_addr.as_unix() {
            return Ok(std_unix.into());
        }

        #[cfg(unix)]
        if sock_addr.is_unnamed() {
            return Ok(crate::unix::SocketAddr::new_unnamed().into());
        }

        #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
        if let Some(name) = sock_addr.as_abstract_namespace() {
            return crate::unix::SocketAddr::new_abstract(name).map(Self::from);
        }

        Err(io::Error::new(io::ErrorKind::Other, "unsupported address type"))
    }
}
104
#[cfg(feature = "feat-socket2")]
impl TryFrom<UniAddr> for socket2::SockAddr {
    type Error = io::Error;

    /// Converts an owned [`UniAddr`] by delegating to the by-reference
    /// conversion.
    fn try_from(uni_addr: UniAddr) -> Result<Self, Self::Error> {
        <Self as TryFrom<&UniAddr>>::try_from(&uni_addr)
    }
}
113
#[cfg(feature = "feat-socket2")]
impl TryFrom<&UniAddr> for socket2::SockAddr {
    type Error = io::Error;

    /// Converts a borrowed [`UniAddr`] into a [`socket2::SockAddr`].
    ///
    /// # Errors
    ///
    /// - Host-name addresses are rejected with an [`io::ErrorKind::Other`]
    ///   error; they must be resolved first.
    /// - Errors from `socket2::SockAddr::unix` are propagated for Unix
    ///   addresses.
    fn try_from(uni_addr: &UniAddr) -> Result<Self, Self::Error> {
        match uni_addr.as_inner() {
            UniAddrInner::Host(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "The host name address must be resolved before converting to SockAddr",
            )),
            UniAddrInner::Inet(inet) => Ok(Self::from(*inet)),
            #[cfg(unix)]
            UniAddrInner::Unix(unix_addr) => Self::unix(unix_addr.to_os_string()),
        }
    }
}
130
131#[cfg(unix)]
132impl From<crate::unix::SocketAddr> for UniAddr {
133    fn from(addr: crate::unix::SocketAddr) -> Self {
134        UniAddr::from_inner(UniAddrInner::Unix(addr))
135    }
136}
137
138impl FromStr for UniAddr {
139    type Err = ParseError;
140
141    fn from_str(addr: &str) -> Result<Self, Self::Err> {
142        Self::new(addr)
143    }
144}
145
#[cfg(feature = "feat-serde")]
impl serde::Serialize for UniAddr {
    /// Serializes the address as the string returned by [`UniAddr::to_str`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_str();
        serializer.serialize_str(&text)
    }
}
155
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for UniAddr {
    /// Deserializes a string and parses it with [`UniAddr::new`], reporting
    /// parse failures as custom deserialization errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let text = <String as serde::Deserialize>::deserialize(deserializer)?;
        Self::new(&text).map_err(<D::Error as serde::de::Error>::custom)
    }
}
165
166impl UniAddr {
167    #[inline]
168    /// Creates a new [`UniAddr`] from its string representation.
169    ///
170    /// # Errors
171    ///
172    /// Not a valid address string.
173    pub fn new(addr: &str) -> Result<Self, ParseError> {
174        if addr.is_empty() {
175            return Err(ParseError::Empty);
176        }
177
178        #[cfg(unix)]
179        if let Some(addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
180            return unix::SocketAddr::new(addr)
181                .map(UniAddrInner::Unix)
182                .map(Self::from_inner)
183                .map_err(ParseError::InvalidUDSAddress);
184        }
185
186        #[cfg(not(unix))]
187        if let Some(_addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
188            return Err(ParseError::Unsupported);
189        }
190
191        let Some((host, port)) = addr.rsplit_once(':') else {
192            return Err(ParseError::InvalidPort);
193        };
194
195        let Ok(port) = port.parse::<u16>() else {
196            return Err(ParseError::InvalidPort);
197        };
198
199        // Short-circuit: IPv4 address starts with a digit.
200        if host.chars().next().is_some_and(|c| c.is_ascii_digit()) {
201            return Ipv4Addr::from_str(host)
202                .map(|ip| SocketAddr::V4(SocketAddrV4::new(ip, port)))
203                .map(UniAddrInner::Inet)
204                .map(Self::from_inner)
205                .map_err(|_| ParseError::InvalidHost)
206                .or_else(|_| {
207                    // A host name may also start with a digit.
208                    Self::new_host(addr, Some((host, port)))
209                });
210        }
211
212        // Short-circuit: if starts with '[' and ends with ']', may be an IPv6 address
213        // and can never be a host.
214        if let Some(ipv6_addr) = host.strip_prefix('[').and_then(|s| s.strip_suffix(']')) {
215            return Ipv6Addr::from_str(ipv6_addr)
216                .map(|ip| SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))
217                .map(UniAddrInner::Inet)
218                .map(Self::from_inner)
219                .map_err(|_| ParseError::InvalidHost);
220        }
221
222        // Fallback: check if is a valid host name.
223        Self::new_host(addr, Some((host, port)))
224    }
225
226    /// Creates a new [`UniAddr`] from a string containing a host name and port,
227    /// like `example.com:8080`.
228    ///
229    /// # Errors
230    ///
231    /// - [`ParseError::InvalidHost`] if the host name is invalid.
232    /// - [`ParseError::InvalidPort`] if the port is invalid.
233    pub fn new_host(addr: &str, parsed: Option<(&str, u16)>) -> Result<Self, ParseError> {
234        let (hostname, _port) = match parsed {
235            Some((hostname, port)) => (hostname, port),
236            None => addr
237                .rsplit_once(':')
238                .ok_or(ParseError::InvalidPort)
239                .and_then(|(hostname, port)| {
240                    let Ok(port) = port.parse::<u16>() else {
241                        return Err(ParseError::InvalidPort);
242                    };
243
244                    Ok((hostname, port))
245                })?,
246        };
247
248        Self::validate_host_name(hostname.as_bytes()).map_err(|()| ParseError::InvalidHost)?;
249
250        Ok(Self::from_inner(UniAddrInner::Host(Arc::from(addr))))
251    }
252
253    // https://github.com/rustls/pki-types/blob/b8c04aa6b7a34875e2c4a33edc9b78d31da49523/src/server_name.rs
254    const fn validate_host_name(input: &[u8]) -> Result<(), ()> {
255        enum State {
256            Start,
257            Next,
258            NumericOnly { len: usize },
259            NextAfterNumericOnly,
260            Subsequent { len: usize },
261            Hyphen { len: usize },
262        }
263
264        use State::{Hyphen, Next, NextAfterNumericOnly, NumericOnly, Start, Subsequent};
265
266        /// "Labels must be 63 characters or less."
267        const MAX_LABEL_LENGTH: usize = 63;
268
269        /// <https://devblogs.microsoft.com/oldnewthing/20120412-00/?p=7873>
270        const MAX_NAME_LENGTH: usize = 253;
271
272        let mut state = Start;
273
274        if input.len() > MAX_NAME_LENGTH {
275            return Err(());
276        }
277
278        let mut idx = 0;
279        while idx < input.len() {
280            let ch = input[idx];
281            state = match (state, ch) {
282                (Start | Next | NextAfterNumericOnly | Hyphen { .. }, b'.') => {
283                    return Err(());
284                }
285                (Subsequent { .. }, b'.') => Next,
286                (NumericOnly { .. }, b'.') => NextAfterNumericOnly,
287                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, _)
288                    if len >= MAX_LABEL_LENGTH =>
289                {
290                    return Err(());
291                }
292                (Start | Next | NextAfterNumericOnly, b'0'..=b'9') => NumericOnly { len: 1 },
293                (NumericOnly { len }, b'0'..=b'9') => NumericOnly { len: len + 1 },
294                (Start | Next | NextAfterNumericOnly, b'a'..=b'z' | b'A'..=b'Z' | b'_') => {
295                    Subsequent { len: 1 }
296                }
297                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, b'-') => {
298                    Hyphen { len: len + 1 }
299                }
300                (
301                    Subsequent { len } | NumericOnly { len } | Hyphen { len },
302                    b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'0'..=b'9',
303                ) => Subsequent { len: len + 1 },
304                _ => return Err(()),
305            };
306            idx += 1;
307        }
308
309        if matches!(
310            state,
311            Start | Hyphen { .. } | NumericOnly { .. } | NextAfterNumericOnly | Next
312        ) {
313            return Err(());
314        }
315
316        Ok(())
317    }
318
319    /// Resolves the address if it is a host name.
320    ///
321    /// By default, we utilize the method [`ToSocketAddrs::to_socket_addrs`]
322    /// provided by the standard library to perform DNS resolution, which is a
323    /// **blocking** operation and may take an arbitrary amount of time to
324    /// complete, use with caution when called in asynchronous contexts.
325    ///
326    /// # Errors
327    ///
328    /// Resolution failure, or if no socket address resolved.
329    pub fn blocking_resolve_socket_addrs(&mut self) -> io::Result<()> {
330        self.blocking_resolve_socket_addrs_with(ToSocketAddrs::to_socket_addrs)
331    }
332
333    /// Resolves the address if it is a host name using a custom resolver
334    /// function.
335    ///
336    /// # Errors
337    ///
338    /// Resolution failure, or if no socket address resolved.
339    pub fn blocking_resolve_socket_addrs_with<F, A>(&mut self, f: F) -> io::Result<()>
340    where
341        F: FnOnce(&str) -> io::Result<A>,
342        A: Iterator<Item = SocketAddr>,
343    {
344        if let UniAddrInner::Host(addr) = self.as_inner() {
345            let resolved = f(addr)?.next().ok_or_else(|| {
346                io::Error::new(
347                    io::ErrorKind::Other,
348                    "Host resolution failed, no available address",
349                )
350            })?;
351
352            *self = Self::from_inner(UniAddrInner::Inet(resolved));
353        }
354
355        Ok(())
356    }
357
358    #[cfg(feature = "feat-tokio")]
359    /// Asynchronously resolves the address if it is a host name.
360    ///
361    /// This method will spawn a blocking Tokio task to perform the resolution
362    /// using [`ToSocketAddrs::to_socket_addrs`] provided by the standard
363    /// library.
364    ///
365    /// # Errors
366    ///
367    /// Resolution failure, or if no socket address resolved.
368    pub async fn resolve_socket_addrs(&mut self) -> io::Result<()> {
369        if let UniAddrInner::Host(addr) = self.as_inner() {
370            let addr = addr.clone();
371            let resolved = tokio::task::spawn_blocking(move || addr.to_socket_addrs())
372                .await??
373                .next()
374                .ok_or_else(|| {
375                    io::Error::new(
376                        io::ErrorKind::Other,
377                        "Host resolution failed, no available address",
378                    )
379                })?;
380
381            *self = Self::from_inner(UniAddrInner::Inet(resolved));
382        }
383
384        Ok(())
385    }
386
387    #[inline]
388    /// Serializes the address to a string.
389    pub fn to_str(&self) -> Cow<'_, str> {
390        self.as_inner().to_str()
391    }
392}
393
394#[non_exhaustive]
395#[derive(Debug, Clone, PartialEq, Eq, Hash)]
396/// See [`UniAddr`].
397///
398/// Generally, you should use [`UniAddr`] instead of this type directly, as
399/// we expose this type only for easier pattern matching. A valid [`UniAddr`]
400/// can be constructed only through [`FromStr`] implementation.
401pub enum UniAddrInner {
402    /// See [`SocketAddr`].
403    Inet(SocketAddr),
404
405    #[cfg(unix)]
406    /// See [`SocketAddr`](crate::unix::SocketAddr).
407    Unix(crate::unix::SocketAddr),
408
409    /// A host name with port.
410    ///
411    /// Please refer to [`ToSocketAddrs`], and
412    /// [`UniAddr::blocking_resolve_socket_addrs`], etc to resolve the
413    /// address when needed.
414    Host(Arc<str>),
415}
416
417impl fmt::Display for UniAddrInner {
418    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419        self.to_str().fmt(f)
420    }
421}
422
423impl UniAddrInner {
424    #[inline]
425    /// Serializes the address to a string.
426    pub fn to_str(&self) -> Cow<'_, str> {
427        match self {
428            Self::Inet(addr) => addr.to_string().into(),
429            #[cfg(unix)]
430            Self::Unix(addr) => addr
431                .to_os_string_impl(UNIX_URI_PREFIX, "@")
432                .to_string_lossy()
433                .to_string()
434                .into(),
435            Self::Host(host) => Cow::Borrowed(host),
436        }
437    }
438}
439
/// Errors that can occur when parsing a [`UniAddr`] from a string.
#[derive(Debug)]
pub enum ParseError {
    /// The input string was empty.
    Empty,

    /// The host part is missing or invalid: neither a valid IPv4 / IPv6
    /// address nor a valid host name.
    InvalidHost,

    /// The `:<port>` suffix is missing, or the port is not a valid number.
    InvalidPort,

    /// The `unix://` address could not be parsed as a Unix domain socket
    /// address; carries the underlying I/O error.
    InvalidUDSAddress(io::Error),

    /// The address type is not supported on this platform.
    Unsupported,
}
458
459impl fmt::Display for ParseError {
460    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
461        match self {
462            Self::Empty => write!(f, "empty address string"),
463            Self::InvalidHost => write!(f, "invalid or missing host address"),
464            Self::InvalidPort => write!(f, "invalid or missing port"),
465            Self::InvalidUDSAddress(err) => write!(f, "invalid UDS address: {err}"),
466            Self::Unsupported => write!(f, "unsupported address type on this platform"),
467        }
468    }
469}
470
471impl std::error::Error for ParseError {
472    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
473        match self {
474            Self::InvalidUDSAddress(err) => Some(err),
475            _ => None,
476        }
477    }
478}
479
480impl From<ParseError> for io::Error {
481    fn from(value: ParseError) -> Self {
482        io::Error::new(io::ErrorKind::Other, value)
483    }
484}
485
#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]

    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("0.0.0.0:0")]
    #[case("0.0.0.0:8080")]
    #[case("127.0.0.1:0")]
    #[case("127.0.0.1:8080")]
    #[case("[::]:0")]
    #[case("[::]:8080")]
    #[case("[::1]:0")]
    #[case("[::1]:8080")]
    #[case("example.com:8080")]
    #[case("1example.com:8080")]
    #[cfg_attr(unix, case("unix://"))]
    #[cfg_attr(
        any(target_os = "android", target_os = "linux", target_os = "cygwin"),
        case("unix://@")
    )]
    #[cfg_attr(unix, case("unix:///tmp/test_UniAddr_new_Display.socket"))]
    #[cfg_attr(
        any(target_os = "android", target_os = "linux", target_os = "cygwin"),
        case("unix://@test_UniAddr_new_Display.socket")
    )]
    fn test_UniAddr_new_Display(#[case] addr: &str) {
        let addr_displayed = UniAddr::new(addr).unwrap().to_string();

        assert_eq!(
            addr_displayed, addr,
            "addr_displayed {addr_displayed:?} != {addr:?}"
        );
    }

    #[rstest]
    #[case("example.com:8080")]
    #[case("1example.com:8080")]
    #[should_panic]
    #[case::panic("1example.com")]
    #[should_panic]
    #[case::panic("1example.com.")]
    #[should_panic]
    #[case::panic("1example.com.:14514")]
    #[should_panic]
    #[case::panic("1example.com:1919810")]
    #[should_panic]
    #[case::panic("this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name-this-is-a-long-host-name:19810")]
    fn test_UniAddr_new_host(#[case] addr: &str) {
        let addr_displayed = UniAddr::new_host(addr, None).unwrap().to_string();

        assert_eq!(
            addr_displayed, addr,
            "addr_displayed {addr_displayed:?} != {addr:?}"
        );
    }

    // The `*:8080` variants use a valid port so that the failure comes from
    // host validation rather than from the out-of-range port.
    #[rstest]
    #[should_panic]
    #[case::panic("")]
    #[should_panic]
    #[case::panic("not-an-address")]
    #[should_panic]
    #[case::panic("127.0.0.1")]
    #[should_panic]
    #[case::panic("127.0.0.1:99999")]
    #[should_panic]
    #[case::panic("127.0.0.256:99999")]
    #[should_panic]
    #[case::panic("127.0.0.256:8080")]
    #[should_panic]
    #[case::panic("::1")]
    #[should_panic]
    #[case::panic("[::1]")]
    #[should_panic]
    #[case::panic("[::1]:99999")]
    #[should_panic]
    #[case::panic("[::gg]:99999")]
    #[should_panic]
    #[case::panic("[::gg]:8080")]
    #[should_panic]
    #[case::panic("example.com")]
    #[should_panic]
    #[case::panic("example.com:99999")]
    #[should_panic]
    #[case::panic("examp😀le.com:99999")]
    #[should_panic]
    #[case::panic("examp😀le.com:8080")]
    #[should_panic]
    #[case::panic(":8080")]
    #[should_panic]
    #[case::panic("-example.com:8080")]
    #[should_panic]
    #[case::panic("example-.com:8080")]
    fn test_UniAddr_new_invalid(#[case] addr: &str) {
        let _ = UniAddr::new(addr).unwrap();
    }

    #[cfg(not(unix))]
    #[test]
    fn test_UniAddr_new_unsupported() {
        // Unix sockets should be unsupported on non-Unix platforms
        let result = UniAddr::new("unix:///tmp/test.sock");

        assert!(matches!(result.unwrap_err(), ParseError::Unsupported));
    }

    // The `SockAddr` conversions only exist with the `feat-socket2` feature.
    #[cfg(feature = "feat-socket2")]
    #[rstest]
    #[case("0.0.0.0:0")]
    #[case("0.0.0.0:8080")]
    #[case("127.0.0.1:0")]
    #[case("127.0.0.1:8080")]
    #[case("[::]:0")]
    #[case("[::]:8080")]
    #[case("[::1]:0")]
    #[case("[::1]:8080")]
    #[cfg_attr(unix, case("unix:///tmp/test_socket2_sock_addr_conversion.socket"))]
    #[cfg_attr(unix, case("unix://"))]
    #[cfg_attr(
        any(target_os = "android", target_os = "linux", target_os = "cygwin"),
        case("unix://@test_socket2_sock_addr_conversion.socket")
    )]
    fn test_socket2_SockAddr_conversion(#[case] addr: &str) {
        let uni_addr = UniAddr::new(addr).unwrap();
        let sock_addr = socket2::SockAddr::try_from(&uni_addr).unwrap();
        let uni_addr_converted = UniAddr::try_from(sock_addr).unwrap();

        assert_eq!(
            uni_addr, uni_addr_converted,
            "{uni_addr} != {uni_addr_converted}"
        );
    }
}