1use std::ffi::{CStr, OsStr, OsString};
4use std::os::unix::ffi::OsStrExt;
5use std::os::unix::net::{SocketAddr as StdSocketAddr, UnixDatagram, UnixListener, UnixStream};
6use std::path::Path;
7use std::{fmt, fs, io};
8
9use wrapper_lite::general_wrapper;
10
// `SocketAddr`: a newtype over `std::os::unix::net::SocketAddr` (imported in
// this file as `StdSocketAddr`), generated by `wrapper_lite::general_wrapper!`.
// `Clone` is derived; `Debug`, `PartialEq`/`Eq` and `Hash` are implemented by
// hand further down in this file.
// NOTE(review): the macro presumably also generates the `const_from`,
// `as_inner` and `AsRef` helpers used elsewhere in this file, and
// `#[wrapper_impl(Deref)]` presumably yields `Deref<Target = StdSocketAddr>`,
// which is how std methods such as `as_pathname` / `is_unnamed` are called on
// the wrapper directly — confirm against the wrapper_lite documentation.
general_wrapper! {
    #[wrapper_impl(Deref)]
    #[derive(Clone)]
    pub SocketAddr(StdSocketAddr)
}
19
20impl SocketAddr {
21 pub fn new<S: AsRef<OsStr> + ?Sized>(addr: &S) -> io::Result<Self> {
55 let addr = addr.as_ref();
56
57 match addr.as_bytes() {
58 #[cfg(any(target_os = "android", target_os = "linux"))]
59 [b'@', rest @ ..] | [b'\0', rest @ ..] => {
60 use std::os::linux::net::SocketAddrExt;
61
62 StdSocketAddr::from_abstract_name(rest).map(Self::const_from)
63 }
64 #[cfg(not(any(target_os = "android", target_os = "linux")))]
65 [b'@', ..] | [b'\0', ..] => Err(io::Error::new(
66 io::ErrorKind::Unsupported,
67 "abstract unix socket address is not supported",
68 )),
69 _ => {
70 let _ = fs::remove_file(addr);
71
72 StdSocketAddr::from_pathname(addr).map(Self::const_from)
73 }
74 }
75 }
76
77 #[cfg(any(target_os = "android", target_os = "linux"))]
78 pub fn new_abstract(bytes: &[u8]) -> io::Result<Self> {
80 use std::os::linux::net::SocketAddrExt;
81
82 StdSocketAddr::from_abstract_name(bytes).map(Self::const_from)
83 }
84
85 pub fn new_pathname<P: AsRef<Path>>(pathname: P) -> io::Result<Self> {
87 StdSocketAddr::from_pathname(pathname).map(Self::const_from)
88 }
89
90 #[allow(clippy::missing_panics_doc)]
91 pub fn new_unnamed() -> Self {
93 StdSocketAddr::from_pathname("").map(Self::const_from).unwrap()
95 }
96
97 #[inline]
98 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
110 Self::new(OsStr::from_bytes(bytes))
111 }
112
113 pub fn from_bytes_until_nul(bytes: &[u8]) -> io::Result<Self> {
115 let first_nul = match bytes {
116 [b'\0', rest @ ..] => CStr::from_bytes_until_nul(rest),
117 rest => CStr::from_bytes_until_nul(rest),
118 }
119 .map_err(|_| {
120 io::Error::new(
121 io::ErrorKind::InvalidInput,
122 "bytes must be a valid C string with a null terminator",
123 )
124 })?;
125
126 Self::new(OsStr::from_bytes(first_nul.to_bytes()))
127 }
128
129 #[inline]
130 pub fn bind_std(&self) -> io::Result<UnixListener> {
132 UnixListener::bind_addr(self)
133 }
134
135 #[cfg(feature = "feat-tokio")]
136 pub fn bind(&self) -> io::Result<tokio::net::UnixListener> {
139 self.bind_std()
140 .and_then(|l| {
141 l.set_nonblocking(true)?;
142 Ok(l)
143 })
144 .and_then(tokio::net::UnixListener::from_std)
145 }
146
147 #[inline]
148 pub fn bind_dgram_std(&self) -> io::Result<UnixDatagram> {
150 UnixDatagram::bind_addr(self)
151 }
152
153 #[cfg(feature = "feat-tokio")]
154 pub fn bind_dgram(&self) -> io::Result<tokio::net::UnixDatagram> {
156 self.bind_dgram_std()
157 .and_then(|d| {
158 d.set_nonblocking(true)?;
159 Ok(d)
160 })
161 .and_then(tokio::net::UnixDatagram::from_std)
162 }
163
164 #[inline]
165 pub fn connect_std(&self) -> io::Result<UnixStream> {
168 UnixStream::connect_addr(self)
169 }
170
171 #[cfg(feature = "feat-tokio")]
172 pub fn connect(&self) -> io::Result<tokio::net::UnixStream> {
175 self.connect_std()
176 .and_then(|s| {
177 s.set_nonblocking(true)?;
178 Ok(s)
179 })
180 .and_then(tokio::net::UnixStream::from_std)
181 }
182
183 pub fn to_os_string(&self) -> OsString {
191 self._to_os_string("", "\0")
192 }
193
194 pub fn to_string_ext(&self) -> Option<String> {
203 self._to_os_string("", "@").into_string().ok()
204 }
205
206 pub(crate) fn _to_os_string(&self, prefix: &str, abstract_identifier: &str) -> OsString {
207 let mut os_string = OsString::from(prefix);
208
209 if let Some(pathname) = self.as_pathname() {
210 os_string.push(pathname);
212
213 return os_string;
214 }
215
216 #[cfg(any(target_os = "android", target_os = "linux"))]
217 {
218 use std::os::linux::net::SocketAddrExt;
219
220 if let Some(abstract_name) = self.as_abstract_name() {
221 os_string.push(abstract_identifier);
222 os_string.push(OsStr::from_bytes(abstract_name));
223
224 return os_string;
225 }
226 }
227
228 os_string
229 }
230}
231
232impl fmt::Debug for SocketAddr {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 self.as_inner().fmt(f)
235 }
236}
237
238impl PartialEq for SocketAddr {
239 fn eq(&self, other: &Self) -> bool {
240 if let Some((l, r)) = self.as_pathname().zip(other.as_pathname()) {
241 return l == r;
242 }
243
244 #[cfg(any(target_os = "android", target_os = "linux"))]
245 {
246 use std::os::linux::net::SocketAddrExt;
247
248 if let Some((l, r)) = self.as_abstract_name().zip(other.as_abstract_name()) {
249 return l == r;
250 }
251 }
252
253 if self.is_unnamed() && other.is_unnamed() {
254 return true;
255 }
256
257 false
258 }
259}
260
261impl Eq for SocketAddr {}
262
263impl std::hash::Hash for SocketAddr {
264 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
265 if let Some(pathname) = self.as_pathname() {
266 pathname.hash(state);
267
268 return;
269 }
270
271 #[cfg(any(target_os = "android", target_os = "linux"))]
272 {
273 use std::os::linux::net::SocketAddrExt;
274
275 if let Some(abstract_name) = self.as_abstract_name() {
276 b'\0'.hash(state);
277 abstract_name.hash(state);
278
279 return;
280 }
281 }
282
283 debug_assert!(self.is_unnamed(), "SocketAddr is not unnamed one");
284
285 b"(unnamed)\0".hash(state);
288 }
289}
290
#[cfg(feature = "feat-serde")]
impl serde::Serialize for SocketAddr {
    /// Serializes the address as the string produced by `to_string_ext`
    /// (abstract names prefixed with `@`), failing with "invalid UTF-8" when
    /// the address has no UTF-8 representation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Some(text) = self.to_string_ext() else {
            return Err(serde::ser::Error::custom("invalid UTF-8"));
        };

        serializer.serialize_str(&text)
    }
}
304
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for SocketAddr {
    /// Deserializes an address from the string form accepted by
    /// [`SocketAddr::new`] (abstract names prefixed with `@` or `\0`).
    ///
    /// The string is read as an owned `String` rather than a borrowed `&str`.
    /// A borrowed `&str` only works when the input can be borrowed verbatim,
    /// and fails for escaped strings, streaming readers, or in-memory values.
    ///
    /// This calls [`SocketAddr::new`], so its documented filesystem behavior
    /// for pathname addresses applies here too.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let text = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;

        Self::new(text.as_str()).map_err(serde::de::Error::custom)
    }
}
314
#[cfg(test)]
mod tests {
    use core::hash::{Hash, Hasher};
    use std::hash::DefaultHasher;

    use super::*;

    /// Hashes `addr` with a fresh `DefaultHasher`.
    fn hash_of(addr: &SocketAddr) -> u64 {
        let mut state = DefaultHasher::new();
        addr.hash(&mut state);
        state.finish()
    }

    #[test]
    fn test_unnamed() {
        const TEST_CASE: &str = "";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert!(addr.as_ref().is_unnamed());
    }

    #[test]
    fn test_pathname() {
        const TEST_CASE: &str = "/tmp/test_pathname.socket";

        let addr = SocketAddr::new(TEST_CASE).unwrap();

        assert_eq!(addr.to_os_string().to_str().unwrap(), TEST_CASE);
        assert_eq!(addr.to_string_ext().unwrap(), TEST_CASE);
        assert_eq!(addr.as_pathname().unwrap().to_str().unwrap(), TEST_CASE);
    }

    #[test]
    #[cfg(any(target_os = "android", target_os = "linux"))]
    fn test_abstract() {
        use std::os::linux::net::SocketAddrExt;

        const TEST_CASE_1: &[u8] = b"@abstract.socket";
        const TEST_CASE_2: &[u8] = b"\0abstract.socket";
        const TEST_CASE_3: &[u8] = b"@";
        const TEST_CASE_4: &[u8] = b"\0";

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_1))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_1[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_2))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_2[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_3))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_3[1..]
        );

        assert_eq!(
            SocketAddr::new(OsStr::from_bytes(TEST_CASE_4))
                .unwrap()
                .as_abstract_name()
                .unwrap(),
            &TEST_CASE_4[1..]
        );
    }

    #[test]
    #[should_panic]
    fn test_pathname_with_null_byte() {
        let _addr = SocketAddr::new_pathname("(unnamed)\0").unwrap();
    }

    #[test]
    fn test_partial_eq_hash() {
        let addr_pathname_1 = SocketAddr::new("/tmp/test_pathname_1.socket").unwrap();
        let addr_pathname_2 = SocketAddr::new("/tmp/test_pathname_2.socket").unwrap();
        let addr_unnamed = SocketAddr::new_unnamed();

        assert_eq!(addr_pathname_1, addr_pathname_1);
        assert_ne!(addr_pathname_1, addr_pathname_2);
        assert_ne!(addr_pathname_2, addr_pathname_1);

        assert_eq!(addr_unnamed, addr_unnamed);
        assert_ne!(addr_pathname_1, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_1);
        assert_ne!(addr_pathname_2, addr_unnamed);
        assert_ne!(addr_unnamed, addr_pathname_2);

        // Equal addresses must hash equally.
        assert_eq!(hash_of(&addr_pathname_1), hash_of(&addr_pathname_1.clone()));
        assert_eq!(hash_of(&addr_unnamed), hash_of(&SocketAddr::new_unnamed()));

        #[cfg(any(target_os = "android", target_os = "linux"))]
        {
            let addr_abstract_1 = SocketAddr::new_abstract(b"/tmp/test_pathname_1.socket").unwrap();
            let addr_abstract_2 = SocketAddr::new_abstract(b"/tmp/test_pathname_2.socket").unwrap();
            let addr_abstract_empty = SocketAddr::new_abstract(&[]).unwrap();
            // An abstract name spelled exactly like the marker bytes that the
            // `Hash` impl feeds for unnamed addresses (`(unnamed)\0`).
            let addr_abstract_unnamed_marker = SocketAddr::new_abstract(b"(unnamed)\0").unwrap();

            assert_eq!(addr_abstract_1, addr_abstract_1);
            assert_ne!(addr_abstract_1, addr_abstract_2);
            assert_ne!(addr_abstract_2, addr_abstract_1);
            assert_eq!(hash_of(&addr_abstract_1), hash_of(&addr_abstract_1.clone()));

            assert_ne!(addr_unnamed, addr_abstract_empty);

            assert_ne!(addr_pathname_1, addr_abstract_1);

            assert_ne!(addr_unnamed, addr_abstract_unnamed_marker);
            assert_ne!(hash_of(&addr_unnamed), hash_of(&addr_abstract_unnamed_marker));
        }
    }
}