1use std::ffi::{CStr, OsStr, OsString};
4use std::hash::{Hash, Hasher};
5use std::os::unix::ffi::OsStrExt;
6use std::path::Path;
7use std::{fmt, fs, io};
8
// Newtype around `std::os::unix::net::SocketAddr`.
//
// The wrapper plumbing comes from the third-party `wrapper_lite::general_wrapper!`
// macro. Judging only from how this file uses it, the generated API includes at
// least a `from_inner` constructor and an `as_inner` accessor. It also includes a
// `Deref` to the inner `std` address (requested by `#[wrapper_impl(Deref)]`), which
// is how methods such as `as_pathname` and `is_unnamed` are reached on
// `SocketAddr`.
// NOTE(review): confirm the exact generated items against the `wrapper_lite` docs.
wrapper_lite::general_wrapper! {
    #[wrapper_impl(Deref)]
    #[derive(Clone)]
    pub struct SocketAddr(std::os::unix::net::SocketAddr);
}
17
18impl SocketAddr {
19 pub fn new<S: AsRef<OsStr> + ?Sized>(addr: &S) -> io::Result<Self> {
58 let addr = addr.as_ref();
59
60 match addr.as_bytes() {
61 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
62 [b'@' | b'\0', rest @ ..] => Self::new_abstract(rest),
63 #[cfg(not(any(target_os = "android", target_os = "linux", target_os = "cygwin")))]
64 [b'@' | b'\0', ..] => Err(io::Error::new(
65 io::ErrorKind::Unsupported,
66 "abstract unix socket address is not supported",
67 )),
68 _ => Self::new_pathname(addr),
69 }
70 }
71
72 pub fn new_strict<S: AsRef<OsStr> + ?Sized>(addr: &S) -> io::Result<Self> {
78 let addr = addr.as_ref();
79
80 match addr.as_bytes() {
81 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
82 [b'@' | b'\0', rest @ ..] => Self::new_abstract_strict(rest),
83 #[cfg(not(any(target_os = "android", target_os = "linux", target_os = "cygwin")))]
84 [b'@' | b'\0', ..] => Err(io::Error::new(
85 io::ErrorKind::Unsupported,
86 "abstract unix socket address is not supported",
87 )),
88 _ => Self::new_pathname(addr),
89 }
90 }
91
92 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
93 pub fn new_abstract(bytes: &[u8]) -> io::Result<Self> {
109 #[cfg(target_os = "android")]
110 use std::os::android::net::SocketAddrExt;
111 #[cfg(target_os = "cygwin")]
112 use std::os::cygwin::net::SocketAddrExt;
113 #[cfg(target_os = "linux")]
114 use std::os::linux::net::SocketAddrExt;
115
116 std::os::unix::net::SocketAddr::from_abstract_name(bytes).map(Self::from_inner)
117 }
118
119 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
120 pub fn new_abstract_strict(bytes: &[u8]) -> io::Result<Self> {
126 if bytes.is_empty() || bytes.contains(&b'\0') {
127 return Err(io::Error::new(
128 io::ErrorKind::InvalidInput,
129 "parse abstract socket name in strict mode: reject NULL bytes",
130 ));
131 }
132
133 Self::new_abstract(bytes)
134 }
135
136 pub fn new_pathname<P: AsRef<Path>>(pathname: P) -> io::Result<Self> {
144 let _ = fs::remove_file(pathname.as_ref());
145
146 std::os::unix::net::SocketAddr::from_pathname(pathname).map(Self::from_inner)
147 }
148
149 #[allow(clippy::missing_panics_doc)]
150 pub fn new_unnamed() -> Self {
152 std::os::unix::net::SocketAddr::from_pathname("")
154 .map(Self::from_inner)
155 .unwrap()
156 }
157
158 #[inline]
159 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
165 Self::new(OsStr::from_bytes(bytes))
166 }
167
168 #[inline]
169 pub fn from_bytes_until_nul(bytes: &[u8]) -> io::Result<Self> {
182 #[allow(clippy::single_match_else)]
183 match bytes {
184 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
185 [b'\0', rest @ ..] => {
186 let addr = CStr::from_bytes_until_nul(rest)
187 .map(CStr::to_bytes)
188 .unwrap_or(rest);
189
190 Self::new_abstract_strict(addr)
191 }
192 #[cfg(not(any(target_os = "android", target_os = "linux", target_os = "cygwin")))]
193 [b'\0', ..] => Err(io::Error::new(
194 io::ErrorKind::Unsupported,
195 "abstract unix socket address is not supported",
196 )),
197 _ => {
198 let addr = CStr::from_bytes_until_nul(bytes)
199 .map(CStr::to_bytes)
200 .unwrap_or(bytes);
201
202 Self::new_pathname(OsStr::from_bytes(addr))
203 }
204 }
205 }
206
207 pub fn to_os_string(&self) -> OsString {
215 self.to_os_string_impl("", "\0")
216 }
217
218 pub fn to_string_lossy(&self) -> String {
227 self.to_os_string_impl("", "@")
228 .to_string_lossy()
229 .into_owned()
230 }
231
232 #[cfg_attr(
233 not(any(target_os = "android", target_os = "linux", target_os = "cygwin")),
234 allow(unused_variables)
235 )]
236 pub(crate) fn to_os_string_impl(&self, prefix: &str, abstract_identifier: &str) -> OsString {
237 let mut os_string = OsString::from(prefix);
238
239 if let Some(pathname) = self.as_pathname() {
240 os_string.push(pathname);
242
243 return os_string;
244 }
245
246 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
247 {
248 #[cfg(target_os = "android")]
249 use std::os::android::net::SocketAddrExt;
250 #[cfg(target_os = "cygwin")]
251 use std::os::cygwin::net::SocketAddrExt;
252 #[cfg(target_os = "linux")]
253 use std::os::linux::net::SocketAddrExt;
254
255 if let Some(abstract_name) = self.as_abstract_name() {
256 os_string.push(abstract_identifier);
257 os_string.push(OsStr::from_bytes(abstract_name));
258
259 return os_string;
260 }
261 }
262
263 os_string
265 }
266}
267
268impl fmt::Debug for SocketAddr {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 self.as_inner().fmt(f)
271 }
272}
273
274impl PartialEq for SocketAddr {
275 fn eq(&self, other: &Self) -> bool {
276 if let Some((l, r)) = self.as_pathname().zip(other.as_pathname()) {
277 return l == r;
278 }
279
280 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
281 {
282 #[cfg(target_os = "android")]
283 use std::os::android::net::SocketAddrExt;
284 #[cfg(target_os = "cygwin")]
285 use std::os::cygwin::net::SocketAddrExt;
286 #[cfg(target_os = "linux")]
287 use std::os::linux::net::SocketAddrExt;
288
289 if let Some((l, r)) = self.as_abstract_name().zip(other.as_abstract_name()) {
290 return l == r;
291 }
292 }
293
294 if self.is_unnamed() && other.is_unnamed() {
295 return true;
296 }
297
298 false
299 }
300}
301
// `SocketAddr`'s `PartialEq` compares addresses by kind:
// - pathnames with `Path` equality,
// - abstract names with byte-slice equality,
// - unnamed addresses are equal to each other.
// Each of those comparisons is reflexive, so the relation is a full equivalence
// and `Eq` can be implemented.
impl Eq for SocketAddr {}
303
304impl Hash for SocketAddr {
305 fn hash<H: Hasher>(&self, state: &mut H) {
306 if let Some(pathname) = self.as_pathname() {
307 pathname.hash(state);
308
309 return;
310 }
311
312 #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
313 {
314 #[cfg(target_os = "android")]
315 use std::os::android::net::SocketAddrExt;
316 #[cfg(target_os = "cygwin")]
317 use std::os::cygwin::net::SocketAddrExt;
318 #[cfg(target_os = "linux")]
319 use std::os::linux::net::SocketAddrExt;
320
321 if let Some(abstract_name) = self.as_abstract_name() {
322 b'\0'.hash(state);
323 abstract_name.hash(state);
324
325 return;
326 }
327 }
328
329 debug_assert!(self.is_unnamed(), "SocketAddr is not unnamed one");
330
331 b"(unnamed)\0".hash(state);
334 }
335}
336
#[cfg(feature = "feat-serde")]
impl serde::Serialize for SocketAddr {
    /// Serializes the address as a string via `SocketAddr::to_string_lossy`:
    ///
    /// - pathnames verbatim,
    /// - abstract names prefixed with `@`,
    /// - unnamed addresses as `""`.
    ///
    /// Invalid UTF-8 is replaced, so such addresses do not round-trip exactly.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string_lossy();
        serializer.serialize_str(rendered.as_str())
    }
}
346
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for SocketAddr {
    /// Deserializes an address from a string and parses it with
    /// `SocketAddr::new`. A leading `@` or `\0` selects an abstract name;
    /// anything else is a pathname.
    ///
    /// The string is read as an owned `String` rather than a borrowed `&str`.
    /// Borrowing only works when the deserializer can lend data straight from
    /// its input, which fails e.g. for reader-based input or strings
    /// containing escape sequences.
    ///
    /// Note: as with `SocketAddr::new`, a valid pathname address removes any
    /// existing file at that path.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let addr = String::deserialize(deserializer)?;

        Self::new(addr.as_str()).map_err(serde::de::Error::custom)
    }
}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unnamed() {
        let addr = SocketAddr::new("").unwrap();

        assert!(addr.as_ref().is_unnamed());
    }

    #[test]
    fn test_pathname() {
        let expected = "/tmp/test_pathname.socket";
        let addr = SocketAddr::new(expected).unwrap();

        let os_string = addr.to_os_string();
        assert_eq!(os_string.to_str().unwrap(), expected);
        assert_eq!(addr.to_string_lossy(), expected);

        let pathname = addr.as_pathname().unwrap();
        assert_eq!(pathname.to_str().unwrap(), expected);
    }

    #[test]
    #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
    fn test_abstract() {
        #[cfg(target_os = "android")]
        use std::os::android::net::SocketAddrExt;
        #[cfg(target_os = "cygwin")]
        use std::os::cygwin::net::SocketAddrExt;
        #[cfg(target_os = "linux")]
        use std::os::linux::net::SocketAddrExt;

        // Both `@` and a leading NUL byte mark an abstract name. The marker
        // itself is stripped, and an empty remainder is accepted.
        const CASES: [&[u8]; 4] = [b"@abstract.socket", b"\0abstract.socket", b"@", b"\0"];

        for case in CASES {
            let addr = SocketAddr::new(OsStr::from_bytes(case)).unwrap();

            assert_eq!(addr.as_abstract_name().unwrap(), &case[1..]);
        }
    }

    #[test]
    #[should_panic]
    fn test_pathname_with_null_byte() {
        let _addr = SocketAddr::new_pathname("(unamed)\0").unwrap();
    }

    #[test]
    fn test_partial_eq_hash() {
        let pathname_1 = SocketAddr::new("/tmp/test_pathname_1.socket").unwrap();
        let pathname_2 = SocketAddr::new("/tmp/test_pathname_2.socket").unwrap();
        let unnamed = SocketAddr::new_unnamed();

        // Pathname addresses compare by path.
        assert_eq!(pathname_1, pathname_1);
        assert_ne!(pathname_1, pathname_2);
        assert_ne!(pathname_2, pathname_1);

        // An unnamed address equals only another unnamed one, in either order.
        assert_eq!(unnamed, unnamed);
        for named in [&pathname_1, &pathname_2] {
            assert_ne!(named, &unnamed);
            assert_ne!(&unnamed, named);
        }

        #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
        {
            use core::hash::{BuildHasher, Hash, Hasher};

            use foldhash::fast::RandomState;

            let abstract_1 = SocketAddr::new_abstract(b"/tmp/test_pathname_1.socket").unwrap();
            let abstract_2 = SocketAddr::new_abstract(b"/tmp/test_pathname_2.socket").unwrap();
            let abstract_empty = SocketAddr::new_abstract(&[]).unwrap();
            let abstract_unnamed = SocketAddr::new_abstract(b"(unamed)\0").unwrap();

            assert_eq!(abstract_1, abstract_1);
            assert_ne!(abstract_1, abstract_2);
            assert_ne!(abstract_2, abstract_1);

            // An empty abstract name is still distinct from an unnamed address.
            assert_ne!(unnamed, abstract_empty);

            // Same bytes but a different kind of address: not equal.
            assert_ne!(pathname_1, abstract_1);

            let state = RandomState::default();
            let hash_of = |addr: &SocketAddr| {
                let mut hasher = state.build_hasher();
                addr.hash(&mut hasher);
                hasher.finish()
            };

            // The unnamed marker must not collide with a look-alike abstract name.
            assert_ne!(hash_of(&unnamed), hash_of(&abstract_unnamed));
        }
    }
}