1use std::ffi::{CStr, OsStr, OsString};
4use std::os::unix::ffi::OsStrExt;
5use std::os::unix::net::{SocketAddr as StdSocketAddr, UnixDatagram, UnixListener, UnixStream};
6use std::path::Path;
7use std::{fmt, fs, io};
8
// `SocketAddr` is a newtype around the standard library's
// `std::os::unix::net::SocketAddr` (imported here as `StdSocketAddr`), generated by
// `wrapper_lite`. The methods below rely on it being constructible via `const_from`
// and exposing the inner value via `as_inner`.
// NOTE(review): the exact set of generated items comes from the macro — confirm
// against the `wrapper_lite` documentation.
wrapper_lite::general_wrapper! {
    // Presumably makes the macro implement `Deref<Target = StdSocketAddr>`, which is
    // why std accessors such as `as_pathname`/`is_unnamed` are called directly on
    // `SocketAddr` elsewhere in this file.
    #[wrapper_impl(Deref)]
    #[derive(Clone)]
    pub struct SocketAddr(StdSocketAddr);
}
17
18impl SocketAddr {
19 pub fn new<S: AsRef<OsStr> + ?Sized>(addr: &S) -> io::Result<Self> {
53 let addr = addr.as_ref();
54
55 match addr.as_bytes() {
56 #[cfg(any(target_os = "android", target_os = "linux"))]
57 [b'@', rest @ ..] | [b'\0', rest @ ..] => {
58 use std::os::linux::net::SocketAddrExt;
59
60 StdSocketAddr::from_abstract_name(rest).map(Self::const_from)
61 }
62 #[cfg(not(any(target_os = "android", target_os = "linux")))]
63 [b'@', ..] | [b'\0', ..] => Err(io::Error::new(
64 io::ErrorKind::Unsupported,
65 "abstract unix socket address is not supported",
66 )),
67 _ => {
68 let _ = fs::remove_file(addr);
69
70 StdSocketAddr::from_pathname(addr).map(Self::const_from)
71 }
72 }
73 }
74
75 #[cfg(any(target_os = "android", target_os = "linux"))]
76 pub fn new_abstract(bytes: &[u8]) -> io::Result<Self> {
78 use std::os::linux::net::SocketAddrExt;
79
80 StdSocketAddr::from_abstract_name(bytes).map(Self::const_from)
81 }
82
83 pub fn new_pathname<P: AsRef<Path>>(pathname: P) -> io::Result<Self> {
85 StdSocketAddr::from_pathname(pathname).map(Self::const_from)
86 }
87
88 #[allow(clippy::missing_panics_doc)]
89 pub fn new_unnamed() -> Self {
91 StdSocketAddr::from_pathname("").map(Self::const_from).unwrap()
93 }
94
95 #[inline]
96 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
108 Self::new(OsStr::from_bytes(bytes))
109 }
110
111 pub fn from_bytes_until_nul(bytes: &[u8]) -> io::Result<Self> {
113 let first_nul = match bytes {
114 [b'\0', rest @ ..] => CStr::from_bytes_until_nul(rest),
115 rest => CStr::from_bytes_until_nul(rest),
116 }
117 .map_err(|_| {
118 io::Error::new(
119 io::ErrorKind::InvalidInput,
120 "bytes must be a valid C string with a null terminator",
121 )
122 })?;
123
124 Self::new(OsStr::from_bytes(first_nul.to_bytes()))
125 }
126
127 #[inline]
128 pub fn bind_std(&self) -> io::Result<UnixListener> {
130 UnixListener::bind_addr(self)
131 }
132
133 #[cfg(feature = "feat-tokio")]
134 pub fn bind(&self) -> io::Result<tokio::net::UnixListener> {
137 self.bind_std()
138 .and_then(|l| {
139 l.set_nonblocking(true)?;
140 Ok(l)
141 })
142 .and_then(tokio::net::UnixListener::from_std)
143 }
144
145 #[inline]
146 pub fn bind_dgram_std(&self) -> io::Result<UnixDatagram> {
148 UnixDatagram::bind_addr(self)
149 }
150
151 #[cfg(feature = "feat-tokio")]
152 pub fn bind_dgram(&self) -> io::Result<tokio::net::UnixDatagram> {
154 self.bind_dgram_std()
155 .and_then(|d| {
156 d.set_nonblocking(true)?;
157 Ok(d)
158 })
159 .and_then(tokio::net::UnixDatagram::from_std)
160 }
161
162 #[inline]
163 pub fn connect_std(&self) -> io::Result<UnixStream> {
166 UnixStream::connect_addr(self)
167 }
168
169 #[cfg(feature = "feat-tokio")]
170 pub fn connect(&self) -> io::Result<tokio::net::UnixStream> {
173 self.connect_std()
174 .and_then(|s| {
175 s.set_nonblocking(true)?;
176 Ok(s)
177 })
178 .and_then(tokio::net::UnixStream::from_std)
179 }
180
181 pub fn to_os_string(&self) -> OsString {
189 self._to_os_string("", "\0")
190 }
191
192 pub fn to_string_ext(&self) -> Option<String> {
201 self._to_os_string("", "@").into_string().ok()
202 }
203
204 pub(crate) fn _to_os_string(&self, prefix: &str, abstract_identifier: &str) -> OsString {
205 let mut os_string = OsString::from(prefix);
206
207 if let Some(pathname) = self.as_pathname() {
208 os_string.push(pathname);
210
211 return os_string;
212 }
213
214 #[cfg(any(target_os = "android", target_os = "linux"))]
215 {
216 use std::os::linux::net::SocketAddrExt;
217
218 if let Some(abstract_name) = self.as_abstract_name() {
219 os_string.push(abstract_identifier);
220 os_string.push(OsStr::from_bytes(abstract_name));
221
222 return os_string;
223 }
224 }
225
226 os_string
227 }
228}
229
230impl fmt::Debug for SocketAddr {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 self.as_inner().fmt(f)
233 }
234}
235
236impl PartialEq for SocketAddr {
237 fn eq(&self, other: &Self) -> bool {
238 if let Some((l, r)) = self.as_pathname().zip(other.as_pathname()) {
239 return l == r;
240 }
241
242 #[cfg(any(target_os = "android", target_os = "linux"))]
243 {
244 use std::os::linux::net::SocketAddrExt;
245
246 if let Some((l, r)) = self.as_abstract_name().zip(other.as_abstract_name()) {
247 return l == r;
248 }
249 }
250
251 if self.is_unnamed() && other.is_unnamed() {
252 return true;
253 }
254
255 false
256 }
257}
258
259impl Eq for SocketAddr {}
260
261impl std::hash::Hash for SocketAddr {
262 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
263 if let Some(pathname) = self.as_pathname() {
264 pathname.hash(state);
265
266 return;
267 }
268
269 #[cfg(any(target_os = "android", target_os = "linux"))]
270 {
271 use std::os::linux::net::SocketAddrExt;
272
273 if let Some(abstract_name) = self.as_abstract_name() {
274 b'\0'.hash(state);
275 abstract_name.hash(state);
276
277 return;
278 }
279 }
280
281 debug_assert!(self.is_unnamed(), "SocketAddr is not unnamed one");
282
283 b"(unnamed)\0".hash(state);
286 }
287}
288
#[cfg(feature = "feat-serde")]
impl serde::Serialize for SocketAddr {
    /// Serializes the address as the string produced by `to_string_ext`: pathnames
    /// as-is, abstract names prefixed with `@`, unnamed as an empty string.
    /// Addresses whose text is not valid UTF-8 produce a serialization error.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self.to_string_ext() {
            Some(text) => serializer.serialize_str(&text),
            None => Err(serde::ser::Error::custom("invalid UTF-8")),
        }
    }
}
302
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for SocketAddr {
    /// Parses a string in the form accepted by [`SocketAddr::new`].
    ///
    /// Parsing goes through [`SocketAddr::new`], so it has the same side effects for
    /// pathname addresses.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Deserialize into an owned `String`. Borrowing a `&str` from the input fails
        // whenever the deserializer cannot hand out a borrowed string. Examples are
        // strings that need unescaping or input read from a stream.
        let addr = <String as serde::Deserialize>::deserialize(deserializer)?;

        Self::new(&addr).map_err(serde::de::Error::custom)
    }
}
312
#[cfg(test)]
mod tests {
    use core::hash::{Hash, Hasher};
    use std::hash::DefaultHasher;

    use super::*;

    #[test]
    fn test_unnamed() {
        const TEST_CASE: &str = "";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert!(addr.as_ref().is_unnamed());
    }

    #[test]
    fn test_pathname() {
        const TEST_CASE: &str = "/tmp/test_pathname.socket";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert_eq!(addr.to_os_string().to_str().unwrap(), TEST_CASE);
        assert_eq!(addr.to_string_ext().unwrap(), TEST_CASE);
        assert_eq!(addr.as_pathname().unwrap().to_str().unwrap(), TEST_CASE);
    }

    #[test]
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn test_abstract() {
        use std::os::linux::net::SocketAddrExt;

        // Both `@` and `\0` mark an abstract address; the marker is stripped from the
        // name, and an empty name is allowed.
        const TEST_CASE_1: &[u8] = b"@abstract.socket";
        const TEST_CASE_2: &[u8] = b"\0abstract.socket";
        const TEST_CASE_3: &[u8] = b"@";
        const TEST_CASE_4: &[u8] = b"\0";

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_1))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_1[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_2))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_2[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_3))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_3[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_4))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_4[1..]
        );
    }

    #[test]
    fn test_pathname_with_null_byte() {
        // An interior NUL cannot be represented in a pathname address. Check the
        // error kind explicitly, since `#[should_panic]` would accept any panic.
        let err = SocketAddr::new_pathname("(unamed)\0").unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_partial_eq_hash() {
        let addr_pathname_1 = SocketAddr::new("/tmp/test_pathname_1.socket").unwrap();
        let addr_pathname_2 = SocketAddr::new("/tmp/test_pathname_2.socket").unwrap();
        let addr_unnamed = SocketAddr::new_unnamed();

        assert_eq!(addr_pathname_1, addr_pathname_1);
        assert_ne!(addr_pathname_1, addr_pathname_2);
        assert_ne!(addr_pathname_2, addr_pathname_1);

        assert_eq!(addr_unnamed, addr_unnamed);
        assert_ne!(addr_pathname_1, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_1);
        assert_ne!(addr_pathname_2, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_2);

        #[cfg(any(target_os = "android", target_os = "linux"))]
        {
            let addr_abstract_1 = SocketAddr::new_abstract(b"/tmp/test_pathname_1.socket").unwrap();
            let addr_abstract_2 = SocketAddr::new_abstract(b"/tmp/test_pathname_2.socket").unwrap();
            let addr_abstract_empty = SocketAddr::new_abstract(&[]).unwrap();
            // An abstract name spelling out exactly the marker that `Hash` feeds for
            // unnamed addresses.
            let addr_abstract_unnamed_marker = SocketAddr::new_abstract(b"(unnamed)\0").unwrap();

            assert_eq!(addr_abstract_1, addr_abstract_1);
            assert_ne!(addr_abstract_1, addr_abstract_2);
            assert_ne!(addr_abstract_2, addr_abstract_1);

            assert_ne!(addr_unnamed, addr_abstract_empty);
            assert_ne!(addr_abstract_empty, addr_unnamed);

            assert_ne!(addr_pathname_1, addr_abstract_1);
            assert_ne!(addr_abstract_1, addr_pathname_1);

            assert_ne!(addr_unnamed, addr_abstract_unnamed_marker);

            let hash_of = |addr: &SocketAddr| {
                let mut state = DefaultHasher::new();
                addr.hash(&mut state);
                state.finish()
            };

            // Abstract hashes are tagged with a leading NUL byte, so even a name that
            // matches the unnamed marker must hash differently.
            assert_ne!(hash_of(&addr_unnamed), hash_of(&addr_abstract_unnamed_marker));
        }
    }
}
437}