uni_addr/
unix.rs

1//! Platform-specific code for Unix-like systems
2
3use std::ffi::{CStr, OsStr, OsString};
4use std::os::unix::ffi::OsStrExt;
5use std::os::unix::net::{SocketAddr as StdSocketAddr, UnixDatagram, UnixListener, UnixStream};
6use std::path::Path;
7use std::{fmt, fs, io};
8
9use wrapper_lite::general_wrapper;
10
general_wrapper! {
    #[wrapper_impl(Deref)]
    #[derive(Clone)]
    /// Wrapper over [`std::os::unix::net::SocketAddr`].
    ///
    /// The wrapper dereferences to the wrapped standard library address, so
    /// its accessors (for example `as_pathname` and `is_unnamed`) can be
    /// called directly on this type.
    ///
    /// On top of the standard library type, this module adds parsing from
    /// string / byte representations, conversion back to a string,
    /// `PartialEq` / `Eq` / `Hash` implementations and, behind the
    /// `feat-serde` feature, `serde` support.
    ///
    /// See [`SocketAddr::new`] for more details.
    pub SocketAddr(StdSocketAddr)
}
19
20impl SocketAddr {
21    /// Creates a new unix [`SocketAddr`] from its string representation.
22    ///
23    /// # Address Types
24    ///
25    /// - Strings starting with `@` or `\0` are parsed as abstract unix socket
26    ///   addresses (Linux-specific).
27    /// - All other strings are parsed as pathname unix socket addresses.
28    /// - Empty strings create unnamed unix socket addresses.
29    ///
30    /// # Important
31    ///
32    /// This method accepts an `OsStr` and does not guarantee proper null
33    /// termination. While pathname addresses reject interior null bytes,
34    /// abstract addresses accept them silently, potentially causing unexpected
35    /// behavior (e.g., `\0abstract` differs from `\0abstract\0\0\0\0\0...`).
36    ///
37    /// Use [`SocketAddr::from_bytes_until_nul`] to ensure only the portion
38    /// before the first null byte is used for address parsing.
39    ///
40    /// # Examples
41    ///
42    /// ```rust
43    /// # use uni_addr::unix::SocketAddr;
44    /// #[cfg(any(target_os = "android", target_os = "linux"))]
45    /// // Abstract address (Linux-specific)
46    /// let abstract_addr = SocketAddr::new("@abstract.example.socket").unwrap();
47    ///
48    /// // Pathname address
49    /// let pathname_addr = SocketAddr::new("/run/pathname.example.socket").unwrap();
50    ///
51    /// // Unnamed address
52    /// let unnamed_addr = SocketAddr::new("").unwrap();
53    /// ```
54    pub fn new<S: AsRef<OsStr> + ?Sized>(addr: &S) -> io::Result<Self> {
55        let addr = addr.as_ref();
56
57        match addr.as_bytes() {
58            #[cfg(any(target_os = "android", target_os = "linux"))]
59            [b'@', rest @ ..] | [b'\0', rest @ ..] => {
60                use std::os::linux::net::SocketAddrExt;
61
62                StdSocketAddr::from_abstract_name(rest).map(Self::const_from)
63            }
64            #[cfg(not(any(target_os = "android", target_os = "linux")))]
65            [b'@', ..] | [b'\0', ..] => Err(io::Error::new(
66                io::ErrorKind::Unsupported,
67                "abstract unix socket address is not supported",
68            )),
69            _ => {
70                let _ = fs::remove_file(addr);
71
72                StdSocketAddr::from_pathname(addr).map(Self::const_from)
73            }
74        }
75    }
76
77    #[cfg(any(target_os = "android", target_os = "linux"))]
78    /// Creates a new abstract unix [`SocketAddr`].
79    pub fn new_abstract(bytes: &[u8]) -> io::Result<Self> {
80        use std::os::linux::net::SocketAddrExt;
81
82        StdSocketAddr::from_abstract_name(bytes).map(Self::const_from)
83    }
84
85    /// Creates a new pathname unix [`SocketAddr`].
86    pub fn new_pathname<P: AsRef<Path>>(pathname: P) -> io::Result<Self> {
87        StdSocketAddr::from_pathname(pathname).map(Self::const_from)
88    }
89
90    #[allow(clippy::missing_panics_doc)]
91    /// Creates a new unnamed unix [`SocketAddr`].
92    pub fn new_unnamed() -> Self {
93        // SAFEY: `from_pathname` will not fail at all.
94        StdSocketAddr::from_pathname("").map(Self::const_from).unwrap()
95    }
96
97    #[inline]
98    /// Creates a new unix [`SocketAddr`] from bytes.
99    ///
100    /// # Note
101    ///
102    /// This method does not validate null terminators. Pathname addresses
103    /// will reject paths containing null bytes during parsing, but abstract
104    /// addresses accept null bytes silently, which may lead to unexpected
105    /// behavior.
106    ///
107    /// Consider using [`from_bytes_until_nul`](Self::from_bytes_until_nul)
108    /// for null-terminated parsing.
109    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
110        Self::new(OsStr::from_bytes(bytes))
111    }
112
113    /// Creates a new unix [`SocketAddr`] from bytes until the first null byte.
114    pub fn from_bytes_until_nul(bytes: &[u8]) -> io::Result<Self> {
115        let first_nul = match bytes {
116            [b'\0', rest @ ..] => CStr::from_bytes_until_nul(rest),
117            rest => CStr::from_bytes_until_nul(rest),
118        }
119        .map_err(|_| {
120            io::Error::new(
121                io::ErrorKind::InvalidInput,
122                "bytes must be a valid C string with a null terminator",
123            )
124        })?;
125
126        Self::new(OsStr::from_bytes(first_nul.to_bytes()))
127    }
128
129    #[inline]
130    /// Creates a new [`UnixListener`] bound to the specified socket.
131    pub fn bind_std(&self) -> io::Result<UnixListener> {
132        UnixListener::bind_addr(self)
133    }
134
135    #[cfg(feature = "feat-tokio")]
136    /// Creates a new [`tokio::net::UnixListener`] bound to the specified
137    /// socket.
138    pub fn bind(&self) -> io::Result<tokio::net::UnixListener> {
139        self.bind_std()
140            .and_then(|l| {
141                l.set_nonblocking(true)?;
142                Ok(l)
143            })
144            .and_then(tokio::net::UnixListener::from_std)
145    }
146
147    #[inline]
148    /// Creates a Unix datagram socket bound to the given path.
149    pub fn bind_dgram_std(&self) -> io::Result<UnixDatagram> {
150        UnixDatagram::bind_addr(self)
151    }
152
153    #[cfg(feature = "feat-tokio")]
154    /// Creates a Unix datagram socket bound to the given path.
155    pub fn bind_dgram(&self) -> io::Result<tokio::net::UnixDatagram> {
156        self.bind_dgram_std()
157            .and_then(|d| {
158                d.set_nonblocking(true)?;
159                Ok(d)
160            })
161            .and_then(tokio::net::UnixDatagram::from_std)
162    }
163
164    #[inline]
165    /// Connects to the Unix socket address and returns a
166    /// [`std::os::unix::net::UnixStream`].
167    pub fn connect_std(&self) -> io::Result<UnixStream> {
168        UnixStream::connect_addr(self)
169    }
170
171    #[cfg(feature = "feat-tokio")]
172    /// Connects to the Unix socket address and returns a
173    /// [`tokio::net::UnixStream`].
174    pub fn connect(&self) -> io::Result<tokio::net::UnixStream> {
175        self.connect_std()
176            .and_then(|s| {
177                s.set_nonblocking(true)?;
178                Ok(s)
179            })
180            .and_then(tokio::net::UnixStream::from_std)
181    }
182
183    /// Serializes the Unix socket address to an `OsString`.
184    ///
185    /// # Returns
186    ///
187    /// - For abstract ones: returns the name prefixed with **`\0`**
188    /// - For pathname ones: returns the pathname
189    /// - For unnamed ones: returns an empty string.
190    pub fn to_os_string(&self) -> OsString {
191        self._to_os_string("", "\0")
192    }
193
194    /// Likes [`to_os_string`](Self::to_os_string), but returns a `String`
195    /// instead of `OsString`, performing UTF-8 verification.
196    ///
197    /// # Returns
198    ///
199    /// - For abstract ones: returns the name prefixed with **`@`**
200    /// - For pathname ones: returns the pathname
201    /// - For unnamed ones: returns an empty string.
202    pub fn to_string_ext(&self) -> Option<String> {
203        self._to_os_string("", "@").into_string().ok()
204    }
205
206    pub(crate) fn _to_os_string(&self, prefix: &str, abstract_identifier: &str) -> OsString {
207        let mut os_string = OsString::from(prefix);
208
209        if let Some(pathname) = self.as_pathname() {
210            // Notice: cannot use `extend` here
211            os_string.push(pathname);
212
213            return os_string;
214        }
215
216        #[cfg(any(target_os = "android", target_os = "linux"))]
217        {
218            use std::os::linux::net::SocketAddrExt;
219
220            if let Some(abstract_name) = self.as_abstract_name() {
221                os_string.push(abstract_identifier);
222                os_string.push(OsStr::from_bytes(abstract_name));
223
224                return os_string;
225            }
226        }
227
228        os_string
229    }
230}
231
232impl fmt::Debug for SocketAddr {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        self.as_inner().fmt(f)
235    }
236}
237
238impl PartialEq for SocketAddr {
239    fn eq(&self, other: &Self) -> bool {
240        if let Some((l, r)) = self.as_pathname().zip(other.as_pathname()) {
241            return l == r;
242        }
243
244        #[cfg(any(target_os = "android", target_os = "linux"))]
245        {
246            use std::os::linux::net::SocketAddrExt;
247
248            if let Some((l, r)) = self.as_abstract_name().zip(other.as_abstract_name()) {
249                return l == r;
250            }
251        }
252
253        if self.is_unnamed() && other.is_unnamed() {
254            return true;
255        }
256
257        false
258    }
259}
260
// Marker impl: the `PartialEq` comparison of this type is a full equivalence
// relation (every pathname, abstract or unnamed address is equal to itself),
// which allows `SocketAddr` to be used as a key in hash-based collections
// together with its `Hash` impl.
impl Eq for SocketAddr {}
262
263impl std::hash::Hash for SocketAddr {
264    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
265        if let Some(pathname) = self.as_pathname() {
266            pathname.hash(state);
267
268            return;
269        }
270
271        #[cfg(any(target_os = "android", target_os = "linux"))]
272        {
273            use std::os::linux::net::SocketAddrExt;
274
275            if let Some(abstract_name) = self.as_abstract_name() {
276                b'\0'.hash(state);
277                abstract_name.hash(state);
278
279                return;
280            }
281        }
282
283        debug_assert!(self.is_unnamed(), "SocketAddr is not unnamed one");
284
285        // `Path` cannot contain null bytes, so we can safely use it as a
286        // sentinel value.
287        b"(unnamed)\0".hash(state);
288    }
289}
290
/// Serializes the address as the string produced by
/// [`SocketAddr::to_string_ext`] (abstract addresses use the `@` prefix).
///
/// Serialization fails with a custom `invalid UTF-8` error when the address
/// cannot be represented as a `String`.
#[cfg(feature = "feat-serde")]
impl serde::Serialize for SocketAddr {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self.to_string_ext() {
            Some(repr) => serializer.serialize_str(&repr),
            None => Err(serde::ser::Error::custom("invalid UTF-8")),
        }
    }
}
304
/// Deserializes the address from its string representation, using the same
/// rules as [`SocketAddr::new`].
///
/// Parse failures are reported as custom deserialization errors.
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for SocketAddr {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // An owned `String` is requested instead of a borrowed `&str`:
        // borrowing only works when the deserializer can hand out a slice of
        // its input, which is not the case for reader-based deserializers or
        // for strings that need unescaping.
        let repr = String::deserialize(deserializer)?;

        Self::new(&repr).map_err(serde::de::Error::custom)
    }
}
314
#[cfg(test)]
mod tests {
    use core::hash::{Hash, Hasher};
    use std::hash::DefaultHasher;

    use super::*;

    /// Hashes `addr` with a fresh [`DefaultHasher`].
    fn hash_of(addr: &SocketAddr) -> u64 {
        let mut state = DefaultHasher::new();
        addr.hash(&mut state);
        state.finish()
    }

    #[test]
    fn test_unnamed() {
        const TEST_CASE: &str = "";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert!(addr.as_ref().is_unnamed());
    }

    #[test]
    fn test_pathname() {
        const TEST_CASE: &str = "/tmp/test_pathname.socket";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert_eq!(addr.to_os_string().to_str().unwrap(), TEST_CASE);
        assert_eq!(addr.to_string_ext().unwrap(), TEST_CASE);
        assert_eq!(addr.as_pathname().unwrap().to_str().unwrap(), TEST_CASE);
    }

    #[test]
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn test_abstract() {
        use std::os::linux::net::SocketAddrExt;

        const TEST_CASE_1: &[u8] = b"@abstract.socket";
        const TEST_CASE_2: &[u8] = b"\0abstract.socket";
        const TEST_CASE_3: &[u8] = b"@";
        const TEST_CASE_4: &[u8] = b"\0";

        // In every case the abstract name is the input minus the one-byte
        // marker.
        for test_case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4] {
            assert_eq!(
                SocketAddr::new(OsStr::from_bytes(test_case))
                    .unwrap()
                    .as_abstract_name()
                    .unwrap(),
                &test_case[1..]
            );
        }
    }

    #[test]
    fn test_pathname_with_null_byte() {
        assert!(SocketAddr::new_pathname("(unnamed)\0").is_err());
    }

    #[test]
    fn test_partial_eq_hash() {
        let addr_pathname_1 = SocketAddr::new("/tmp/test_pathname_1.socket").unwrap();
        let addr_pathname_2 = SocketAddr::new("/tmp/test_pathname_2.socket").unwrap();
        let addr_unnamed = SocketAddr::new_unnamed();

        assert_eq!(addr_pathname_1, addr_pathname_1);
        assert_ne!(addr_pathname_1, addr_pathname_2);
        assert_ne!(addr_pathname_2, addr_pathname_1);

        assert_eq!(addr_unnamed, addr_unnamed);
        assert_ne!(addr_pathname_1, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_1);
        assert_ne!(addr_pathname_2, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_2);

        // Equal addresses must hash equally.
        assert_eq!(hash_of(&addr_pathname_1), hash_of(&addr_pathname_1.clone()));
        assert_eq!(hash_of(&addr_unnamed), hash_of(&SocketAddr::new_unnamed()));

        #[cfg(any(target_os = "android", target_os = "linux"))]
        {
            let addr_abstract_1 = SocketAddr::new_abstract(b"/tmp/test_pathname_1.socket").unwrap();
            let addr_abstract_2 = SocketAddr::new_abstract(b"/tmp/test_pathname_2.socket").unwrap();
            let addr_abstract_empty = SocketAddr::new_abstract(&[]).unwrap();
            // Same bytes as the sentinel hashed for unnamed addresses.
            let addr_abstract_sentinel = SocketAddr::new_abstract(b"(unnamed)\0").unwrap();

            assert_eq!(addr_abstract_1, addr_abstract_1);
            assert_ne!(addr_abstract_1, addr_abstract_2);
            assert_ne!(addr_abstract_2, addr_abstract_1);

            assert_eq!(hash_of(&addr_abstract_1), hash_of(&addr_abstract_1.clone()));

            // Empty abstract addresses should NOT be equal to unnamed
            // addresses
            assert_ne!(addr_unnamed, addr_abstract_empty);

            // Abstract addresses should not be equal to pathname addresses
            assert_ne!(addr_pathname_1, addr_abstract_1);

            // The hash of the abstract address `@(unnamed)\0` should not be
            // equal to the hash of an unnamed address, even though its name
            // matches the sentinel used for unnamed addresses.
            assert_ne!(hash_of(&addr_unnamed), hash_of(&addr_abstract_sentinel));
        }
    }
}
439}