uni_addr/
lib.rs

1//! # uni-addr
2
3use std::borrow::Cow;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::{fmt, io};
7
8pub mod listener;
9#[cfg(unix)]
10pub mod unix;
11
/// URI scheme prefix that marks a string as a Unix domain socket address.
///
/// - Pathname socket address: `unix:///path/to/socket`
/// - Abstract socket address: `unix://@abstract.unix.socket`
pub const UNIX_URI_PREFIX: &str = "unix://";
17
18wrapper_lite::wrapper!(
19    #[wrapper_impl(Debug)]
20    #[wrapper_impl(Display)]
21    #[wrapper_impl(AsRef)]
22    #[wrapper_impl(Deref)]
23    #[derive(Clone, PartialEq, Eq, Hash)]
24    /// A unified address type that can represent:
25    ///
26    /// - [`std::net::SocketAddr`]
27    /// - [`unix::SocketAddr`] (a wrapper over
28    ///   [`std::os::unix::net::SocketAddr`])
29    /// - A host name with port. See [`ToSocketAddrs`].
30    ///
31    /// # Parsing Behaviour
32    ///
33    /// - Checks if the address started with [`UNIX_URI_PREFIX`]: parse as a UDS
34    ///   address.
35    /// - Checks if the address is started with a alphabetic character (a-z,
36    ///   A-Z): treat as a host name. Notes that we will not validate if the
37    ///   host name is valid.
38    /// - Tries to parse as a network socket address.
39    /// - Otherwise, treats the input as a host name.
40    pub struct UniAddr(UniAddrInner);
41);
42
43impl From<std::net::SocketAddr> for UniAddr {
44    fn from(addr: std::net::SocketAddr) -> Self {
45        UniAddr::const_from(UniAddrInner::Inet(addr))
46    }
47}
48
49#[cfg(unix)]
50impl From<unix::SocketAddr> for UniAddr {
51    fn from(addr: unix::SocketAddr) -> Self {
52        UniAddr::const_from(UniAddrInner::Unix(addr))
53    }
54}
55
#[cfg(all(unix, feature = "feat-tokio"))]
impl From<tokio::net::unix::SocketAddr> for UniAddr {
    /// Converts a Tokio Unix socket address into the crate's
    /// [`unix::SocketAddr`] and wraps it.
    fn from(addr: tokio::net::unix::SocketAddr) -> Self {
        let uds = unix::SocketAddr::from(addr.into());
        Self::const_from(UniAddrInner::Unix(uds))
    }
}
62
63impl FromStr for UniAddr {
64    type Err = ParseError;
65
66    fn from_str(addr: &str) -> Result<Self, Self::Err> {
67        if addr.is_empty() {
68            return Err(ParseError::Empty);
69        }
70
71        if let Some(addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
72            #[cfg(unix)]
73            {
74                return unix::SocketAddr::new(addr)
75                    .map(UniAddrInner::Unix)
76                    .map(Self::const_from)
77                    .map_err(ParseError::InvalidUDSAddress);
78            }
79
80            #[cfg(not(unix))]
81            {
82                return Err(ParseError::Unsupported);
83            }
84        }
85
86        let Some((host, port)) = addr.rsplit_once(':') else {
87            return Err(ParseError::InvalidPort);
88        };
89
90        {
91            let Some(char) = host.chars().next() else {
92                return Err(ParseError::InvalidHost);
93            };
94
95            if char.is_ascii_alphabetic() {
96                if port.parse::<u16>().is_err() {
97                    return Err(ParseError::InvalidPort);
98                }
99
100                return Ok(Self::const_from(UniAddrInner::Host(Arc::from(addr))));
101            }
102        }
103
104        if let Ok(addr) = addr.parse::<std::net::SocketAddr>() {
105            return Ok(Self::const_from(UniAddrInner::Inet(addr)));
106        }
107
108        if port.parse::<u16>().is_err() {
109            return Err(ParseError::InvalidPort);
110        }
111
112        Ok(Self::const_from(UniAddrInner::Host(Arc::from(addr))))
113    }
114}
115
/// Errors that can occur when parsing a [`UniAddr`] from a string.
#[derive(Debug)]
pub enum ParseError {
    /// The input string was empty.
    Empty,

    /// The host part of the address is missing.
    InvalidHost,

    /// The port part of the address is missing or is not a valid port.
    InvalidPort,

    /// The Unix domain socket address was rejected; carries the underlying
    /// I/O error.
    InvalidUDSAddress(io::Error),

    /// The address type is not supported on this platform.
    Unsupported,
}
134
135impl fmt::Display for ParseError {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        match self {
138            Self::Empty => write!(f, "empty address string"),
139            Self::InvalidHost => write!(f, "invalid or missing host address"),
140            Self::InvalidPort => write!(f, "invalid or missing port"),
141            Self::InvalidUDSAddress(err) => write!(f, "invalid UDS address: {}", err),
142            Self::Unsupported => write!(f, "unsupported address type on this platform"),
143        }
144    }
145}
146
147impl std::error::Error for ParseError {
148    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
149        match self {
150            Self::InvalidUDSAddress(err) => Some(err),
151            _ => None,
152        }
153    }
154}
155
#[cfg(feature = "feat-serde")]
impl serde::Serialize for UniAddr {
    /// Serializes the address as the string produced by [`UniAddr::to_str`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let repr = self.to_str();
        serializer.serialize_str(repr.as_ref())
    }
}
165
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for UniAddr {
    /// Deserializes a [`UniAddr`] from its string representation.
    ///
    /// # Errors
    ///
    /// Fails when the input is not a string, or when [`UniAddr::new`] rejects
    /// it; the [`ParseError`] is reported through the deserializer's custom
    /// error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Deserialize into an owned `String` rather than `&'de str`: a borrowed
        // string is only available when the deserializer can hand out a slice
        // of its input, which is not the case for reader-based input or for
        // strings that need unescaping.
        let raw = <String as serde::Deserialize>::deserialize(deserializer)?;

        Self::new(&raw).map_err(serde::de::Error::custom)
    }
}
175
176impl UniAddr {
177    #[inline]
178    /// Creates a new [`UniAddr`] from its string representation.
179    pub fn new(addr: &str) -> Result<Self, ParseError> {
180        addr.parse()
181    }
182
183    #[inline]
184    /// Serializes the address to a string.
185    pub fn to_str(&self) -> Cow<'_, str> {
186        match self.as_inner() {
187            UniAddrInner::Inet(addr) => addr.to_string().into(),
188            UniAddrInner::Unix(addr) => addr
189                ._to_os_string(UNIX_URI_PREFIX, "@")
190                .to_string_lossy()
191                .to_string()
192                .into(),
193            UniAddrInner::Host(host) => (&**host).into(),
194        }
195    }
196}
197
198#[non_exhaustive]
199#[derive(Debug, Clone, PartialEq, Eq, Hash)]
200/// See [`UniAddr`].
201///
202/// Generally, you should use [`UniAddr`] instead of this type directly, as
203/// we expose this type only for easier pattern matching. A valid [`UniAddr`]
204/// can be constructed only through [`FromStr`] implementation.
205pub enum UniAddrInner {
206    /// See [`std::net::SocketAddr`].
207    Inet(std::net::SocketAddr),
208
209    #[cfg(unix)]
210    /// See [`unix::SocketAddr`].
211    Unix(unix::SocketAddr),
212
213    /// A host name with port. See [`ToSocketAddrs`](std::net::ToSocketAddrs).
214    Host(Arc<str>),
215}
216
217impl fmt::Display for UniAddrInner {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        match self {
220            Self::Inet(addr) => addr.fmt(f),
221            #[cfg(unix)]
222            Self::Unix(addr) => write!(f, "{}", addr._to_os_string(UNIX_URI_PREFIX, "@").to_string_lossy()),
223            Self::Host(host) => host.fmt(f),
224        }
225    }
226}
227
228#[deprecated(since = "0.2.4", note = "Please use `UniAddr` instead")]
229#[derive(Clone, PartialEq, Eq, Hash)]
230/// A unified address type that can represent both
231/// [`std::net::SocketAddr`] and [`unix::SocketAddr`] (a wrapper over
232/// [`std::os::unix::net::SocketAddr`]).
233///
234/// ## Notes
235///
236/// For Unix domain sockets addresses, serialization/deserialization will be
237/// performed in URI format (see [`UNIX_URI_PREFIX`]), which is different from
238/// [`unix::SocketAddr`]'s serialization/deserialization behaviour.
239pub enum SocketAddr {
240    /// See [`std::net::SocketAddr`].
241    Inet(std::net::SocketAddr),
242
243    #[cfg(unix)]
244    /// See [`unix::SocketAddr`].
245    Unix(unix::SocketAddr),
246}
247
248#[allow(deprecated)]
249impl fmt::Debug for SocketAddr {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        match self {
252            SocketAddr::Inet(addr) => addr.fmt(f),
253            #[cfg(unix)]
254            SocketAddr::Unix(addr) => addr.fmt(f),
255        }
256    }
257}
258
259#[allow(deprecated)]
260impl From<std::net::SocketAddr> for SocketAddr {
261    fn from(addr: std::net::SocketAddr) -> Self {
262        SocketAddr::Inet(addr)
263    }
264}
265
266#[allow(deprecated)]
267#[cfg(unix)]
268impl From<unix::SocketAddr> for SocketAddr {
269    fn from(addr: unix::SocketAddr) -> Self {
270        SocketAddr::Unix(addr)
271    }
272}
273
#[allow(deprecated)]
#[cfg(all(unix, feature = "feat-tokio"))]
impl From<tokio::net::unix::SocketAddr> for SocketAddr {
    /// Converts a Tokio Unix socket address into the crate's
    /// [`unix::SocketAddr`] and wraps it.
    fn from(addr: tokio::net::unix::SocketAddr) -> Self {
        let uds = unix::SocketAddr::from(addr.into());
        Self::Unix(uds)
    }
}
281
282#[allow(deprecated)]
283impl FromStr for SocketAddr {
284    type Err = io::Error;
285
286    fn from_str(s: &str) -> Result<Self, Self::Err> {
287        SocketAddr::new(s)
288    }
289}
290
291#[allow(deprecated)]
292impl SocketAddr {
293    #[inline]
294    /// Creates a new [`SocketAddr`] from its string representation.
295    ///
296    /// The string can be in one of the following formats:
297    ///
298    /// - Network socket address: `"127.0.0.1:8080"`, `"[::1]:8080"`
299    /// - Unix domain socket (pathname): `"unix:///run/listen.sock"`
300    /// - Unix domain socket (abstract): `"unix://@abstract.unix.socket"`
301    ///
302    /// # Examples
303    ///
304    /// ```rust
305    /// # use uni_addr::SocketAddr;
306    /// // Network addresses
307    /// let addr_v4 = SocketAddr::new("127.0.0.1:8080").unwrap();
308    /// let addr_v6 = SocketAddr::new("[::1]:8080").unwrap();
309    ///
310    /// // Unix domain sockets
311    /// let addr_unix_filename = SocketAddr::new("unix:///run/listen.sock").unwrap();
312    /// let addr_unix_abstract = SocketAddr::new("unix://@abstract.unix.socket").unwrap();
313    /// ```
314    ///
315    /// See [`unix::SocketAddr::new`] for more details on Unix socket address
316    /// formats.
317    pub fn new(addr: &str) -> io::Result<Self> {
318        if let Some(addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
319            #[cfg(unix)]
320            return unix::SocketAddr::new(addr).map(SocketAddr::Unix);
321
322            #[cfg(not(unix))]
323            return Err(io::Error::new(
324                io::ErrorKind::Unsupported,
325                "Unix socket addresses are not supported on this platform",
326            ));
327        }
328
329        addr.parse()
330            .map(SocketAddr::Inet)
331            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "unknown format"))
332    }
333
334    #[inline]
335    /// Binds a standard (TCP) listener to the address.
336    pub fn bind_std(&self) -> io::Result<listener::StdListener> {
337        match self {
338            SocketAddr::Inet(addr) => std::net::TcpListener::bind(addr).map(listener::StdListener::Tcp),
339            #[cfg(unix)]
340            SocketAddr::Unix(addr) => addr.bind_std().map(listener::StdListener::Unix),
341        }
342    }
343
344    #[cfg(feature = "feat-tokio")]
345    #[inline]
346    /// Binds a Tokio (TCP) listener to the address.
347    pub async fn bind(&self) -> io::Result<listener::Listener> {
348        match self {
349            SocketAddr::Inet(addr) => tokio::net::TcpListener::bind(addr).await.map(listener::Listener::Tcp),
350            #[cfg(unix)]
351            SocketAddr::Unix(addr) => addr.bind().map(listener::Listener::Unix),
352        }
353    }
354
355    /// Serializes the address to a `String`.
356    pub fn to_string_ext(&self) -> Option<String> {
357        match self {
358            Self::Inet(addr) => Some(addr.to_string()),
359            Self::Unix(addr) => addr._to_os_string(UNIX_URI_PREFIX, "@").into_string().ok(),
360        }
361    }
362}
363
#[allow(deprecated)]
#[cfg(feature = "feat-serde")]
impl serde::Serialize for SocketAddr {
    /// Serializes the address as the string produced by
    /// [`SocketAddr::to_string_ext`], failing when that string is unavailable.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self.to_string_ext() {
            Some(repr) => serializer.serialize_str(&repr),
            None => Err(serde::ser::Error::custom("invalid UTF-8")),
        }
    }
}
378
#[allow(deprecated)]
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for SocketAddr {
    /// Deserializes a [`SocketAddr`] from its string representation.
    ///
    /// # Errors
    ///
    /// Fails when the input is not a string, or when [`SocketAddr::new`]
    /// rejects it; the I/O error is reported through the deserializer's
    /// custom error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Deserialize into an owned `String` rather than `&'de str`: a borrowed
        // string is only available when the deserializer can hand out a slice
        // of its input, which is not the case for reader-based input or for
        // strings that need unescaping.
        let raw = <String as serde::Deserialize>::deserialize(deserializer)?;

        Self::new(&raw).map_err(serde::de::Error::custom)
    }
}
389
#[cfg(test)]
mod tests {
    // `SocketAddrExt` (abstract-namespace helpers) is only provided for Linux
    // and Android, each under its own `std::os` module; importing the Linux
    // path on every Unix target breaks the build elsewhere (e.g. macOS).
    #[cfg(target_os = "android")]
    use std::os::android::net::SocketAddrExt;
    #[cfg(target_os = "linux")]
    use std::os::linux::net::SocketAddrExt;

    use super::*;

    /// An IPv4 `ip:port` string parses to an `Inet` address.
    #[test]
    fn test_socket_addr_new_ipv4() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "127.0.0.1");
                assert_eq!(std_addr.port(), 8080);
            }
            _ => panic!("Expected Inet address, got {:?}", addr),
        }
    }

    /// A bracketed IPv6 `[ip]:port` string parses to an `Inet` address.
    #[test]
    fn test_socket_addr_new_ipv6() {
        let addr = UniAddr::new("[::1]:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "::1");
                assert_eq!(std_addr.port(), 8080);
            }
            // The catch-all must exist on every platform: `Host` is always a
            // variant, so gating this arm on `unix` leaves the match
            // non-exhaustive elsewhere.
            _ => panic!("Expected Inet address, got {:?}", addr),
        }
    }

    /// A `unix://` URI with a path parses to a pathname Unix address.
    #[cfg(unix)]
    #[test]
    fn test_socket_addr_new_unix_pathname() {
        let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_pathname().is_some());
            }
            _ => unreachable!(),
        }
    }

    /// A `unix://@name` URI parses to an abstract Unix address.
    #[cfg(any(target_os = "android", target_os = "linux"))]
    #[test]
    fn test_socket_addr_new_unix_abstract() {
        let addr = UniAddr::new("unix://@test.abstract").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_abstract_name().is_some());
            }
            _ => unreachable!(),
        }
    }

    /// A `name:port` string is kept verbatim as a host name.
    #[test]
    fn test_socket_addr_new_host() {
        let addr = UniAddr::new("example.com:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Host(host) => {
                assert_eq!(&**host, "example.com:8080");
            }
            _ => unreachable!(),
        }
    }

    /// Inputs without a port, or with a non-numeric port, are rejected.
    #[test]
    fn test_socket_addr_new_invalid() {
        // Invalid format
        assert!(UniAddr::new("invalid").is_err());
        assert!(UniAddr::new("127.0.0.1").is_err()); // Missing port
        assert!(UniAddr::new("example.com:invalid").is_err()); // Invalid port
        assert!(UniAddr::new("127.0.0.1:invalid").is_err()); // Invalid port
    }

    /// `unix://` addresses are reported as unsupported off Unix.
    #[cfg(not(unix))]
    #[test]
    fn test_socket_addr_new_unix_unsupported() {
        // Unix sockets should be unsupported on non-Unix platforms
        let result = UniAddr::new("unix:///tmp/test.sock");

        assert!(matches!(result.unwrap_err(), ParseError::Unsupported));
    }

    /// `to_str` round-trips every supported address form.
    #[test]
    fn test_socket_addr_display() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        assert_eq!(&addr.to_str(), "127.0.0.1:8080");

        let addr = UniAddr::new("[::1]:8080").unwrap();
        assert_eq!(&addr.to_str(), "[::1]:8080");

        #[cfg(unix)]
        {
            let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();
            assert_eq!(&addr.to_str(), "unix:///tmp/test.sock");
        }

        // Abstract socket addresses are exercised only on Linux and Android,
        // matching the gating of the empty abstract case in `test_edge_cases`.
        #[cfg(any(target_os = "android", target_os = "linux"))]
        {
            let addr = UniAddr::new("unix://@test.abstract").unwrap();
            assert_eq!(&addr.to_str(), "unix://@test.abstract");
        }

        let addr = UniAddr::new("example.com:8080").unwrap();
        assert_eq!(&addr.to_str(), "example.com:8080");
    }

    /// The `Debug` output contains the textual address.
    #[test]
    fn test_socket_addr_debug() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        let debug_str = format!("{:?}", addr);

        assert!(debug_str.contains("127.0.0.1:8080"));
    }

    /// Empty input, missing separators and out-of-range ports are rejected;
    /// empty Unix addresses are accepted.
    #[test]
    fn test_edge_cases() {
        assert!(UniAddr::new("").is_err());
        assert!(UniAddr::new("not-an-address").is_err());
        assert!(UniAddr::new("127.0.0.1:99999").is_err()); // Port too high

        #[cfg(unix)]
        {
            assert!(UniAddr::new("unix://").is_ok()); // Empty path -> unnamed one
            #[cfg(any(target_os = "android", target_os = "linux"))]
            assert!(UniAddr::new("unix://@").is_ok()); // Empty abstract one
        }
    }
}
520        }
521    }
522}