socket2_plus/
sockaddr.rs

1use std::hash::Hash;
2use std::mem::{self, size_of, MaybeUninit};
3use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
4use std::path::Path;
5use std::{fmt, io, ptr};
6
7#[cfg(windows)]
8use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN6_0;
9
10use crate::sys::{
11    c_int, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t, AF_INET,
12    AF_INET6, AF_UNIX,
13};
14use crate::Domain;
15
16/// The address of a socket.
17///
18/// `SockAddr`s may be constructed directly to and from the standard library
19/// [`SocketAddr`], [`SocketAddrV4`], and [`SocketAddrV6`] types.
20#[derive(Clone)]
21pub struct SockAddr {
22    storage: sockaddr_storage,
23    len: socklen_t,
24}
25
26#[allow(clippy::len_without_is_empty)]
27impl SockAddr {
28    /// Create a `SockAddr` from the underlying storage and its length.
29    ///
30    /// # Safety
31    ///
32    /// Caller must ensure that the address family and length match the type of
33    /// storage address. For example if `storage.ss_family` is set to `AF_INET`
34    /// the `storage` must be initialised as `sockaddr_in`, setting the content
35    /// and length appropriately.
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// # fn main() -> std::io::Result<()> {
41    /// # #[cfg(unix)] {
42    /// use std::io;
43    /// use std::mem;
44    /// use std::os::unix::io::AsRawFd;
45    ///
46    /// use socket2_plus::{SockAddr, Socket, Domain, Type};
47    ///
48    /// let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
49    ///
50    /// // Initialise a `SocketAddr` byte calling `getsockname(2)`.
51    /// let mut addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
52    /// let mut len = mem::size_of_val(&addr_storage) as libc::socklen_t;
53    ///
54    /// // The `getsockname(2)` system call will intiliase `storage` for
55    /// // us, setting `len` to the correct length.
56    /// let res = unsafe {
57    ///     libc::getsockname(
58    ///         socket.as_raw_fd(),
59    ///         (&mut addr_storage as *mut libc::sockaddr_storage).cast(),
60    ///         &mut len,
61    ///     )
62    /// };
63    /// if res == -1 {
64    ///     return Err(io::Error::last_os_error());
65    /// }
66    ///
67    /// let address = unsafe { SockAddr::new(addr_storage, len) };
68    /// # drop(address);
69    /// # }
70    /// # Ok(())
71    /// # }
72    /// ```
73    pub const unsafe fn new(storage: sockaddr_storage, len: socklen_t) -> SockAddr {
74        SockAddr { storage, len }
75    }
76
77    /// Initialise a `SockAddr` by calling the function `init`.
78    ///
79    /// The type of the address storage and length passed to the function `init`
80    /// is OS/architecture specific.
81    ///
82    /// The address is zeroed before `init` is called and is thus valid to
83    /// dereference and read from. The length initialised to the maximum length
84    /// of the storage.
85    ///
86    /// # Safety
87    ///
88    /// Caller must ensure that the address family and length match the type of
89    /// storage address. For example if `storage.ss_family` is set to `AF_INET`
90    /// the `storage` must be initialised as `sockaddr_in`, setting the content
91    /// and length appropriately.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// # fn main() -> std::io::Result<()> {
97    /// # #[cfg(unix)] {
98    /// use std::io;
99    /// use std::os::unix::io::AsRawFd;
100    ///
101    /// use socket2_plus::{SockAddr, Socket, Domain, Type};
102    ///
103    /// let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
104    ///
105    /// // Initialise a `SocketAddr` byte calling `getsockname(2)`.
106    /// let (_, address) = unsafe {
107    ///     SockAddr::try_init(|addr_storage, len| {
108    ///         // The `getsockname(2)` system call will intiliase `storage` for
109    ///         // us, setting `len` to the correct length.
110    ///         if libc::getsockname(socket.as_raw_fd(), addr_storage.cast(), len) == -1 {
111    ///             Err(io::Error::last_os_error())
112    ///         } else {
113    ///             Ok(())
114    ///         }
115    ///     })
116    /// }?;
117    /// # drop(address);
118    /// # }
119    /// # Ok(())
120    /// # }
121    /// ```
122    pub unsafe fn try_init<F, T>(init: F) -> io::Result<(T, SockAddr)>
123    where
124        F: FnOnce(*mut sockaddr_storage, *mut socklen_t) -> io::Result<T>,
125    {
126        const STORAGE_SIZE: socklen_t = size_of::<sockaddr_storage>() as socklen_t;
127        // NOTE: `SockAddr::unix` depends on the storage being zeroed before
128        // calling `init`.
129        // NOTE: calling `recvfrom` with an empty buffer also depends on the
130        // storage being zeroed before calling `init` as the OS might not
131        // initialise it.
132        let mut storage = MaybeUninit::<sockaddr_storage>::zeroed();
133        let mut len = STORAGE_SIZE;
134        init(storage.as_mut_ptr(), &mut len).map(|res| {
135            debug_assert!(len <= STORAGE_SIZE, "overflown address storage");
136            let addr = SockAddr {
137                // Safety: zeroed-out `sockaddr_storage` is valid, caller must
138                // ensure at least `len` bytes are valid.
139                storage: storage.assume_init(),
140                len,
141            };
142            (res, addr)
143        })
144    }
145
146    /// Create an empty `SockAddr` with the address storage initialzed with zeros, and
147    /// its `len` set to the full length of the address storage.
148    ///
149    /// This is a convenient method to create a valid `SockAddr` to be filled in as any
150    /// kind of socket addresses (e.g. IPv4 or IPv6).
151    pub fn empty() -> SockAddr {
152        // SAFETY: a `sockaddr_storage` of all zeros is valid.
153        let storage = unsafe { mem::zeroed::<sockaddr_storage>() };
154        let len = size_of::<sockaddr_storage>() as socklen_t;
155        SockAddr { storage, len }
156    }
157
158    /// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
159    ///
160    /// Returns an error if the path is longer than `SUN_LEN`.
161    pub fn unix<P>(path: P) -> io::Result<SockAddr>
162    where
163        P: AsRef<Path>,
164    {
165        crate::sys::unix_sockaddr(path.as_ref())
166    }
167
168    /// Set the length of the address.
169    ///
170    /// # Safety
171    ///
172    /// Caller must ensure that the address up to `length` bytes are properly
173    /// initialised.
174    pub unsafe fn set_length(&mut self, length: socklen_t) {
175        self.len = length;
176    }
177
178    /// Returns this address's family.
179    pub const fn family(&self) -> sa_family_t {
180        self.storage.ss_family
181    }
182
183    /// Returns this address's `Domain`.
184    pub const fn domain(&self) -> Domain {
185        Domain(self.storage.ss_family as c_int)
186    }
187
188    /// Returns the size of this address in bytes.
189    pub const fn len(&self) -> socklen_t {
190        self.len
191    }
192
193    /// Returns a raw pointer to the address.
194    pub const fn as_ptr(&self) -> *const sockaddr {
195        ptr::addr_of!(self.storage).cast()
196    }
197
198    /// Retuns the address as the storage.
199    pub const fn as_storage(self) -> sockaddr_storage {
200        self.storage
201    }
202
203    /// Returns true if this address is in the `AF_INET` (IPv4) family, false otherwise.
204    pub const fn is_ipv4(&self) -> bool {
205        self.storage.ss_family == AF_INET as sa_family_t
206    }
207
208    /// Returns true if this address is in the `AF_INET6` (IPv6) family, false
209    /// otherwise.
210    pub const fn is_ipv6(&self) -> bool {
211        self.storage.ss_family == AF_INET6 as sa_family_t
212    }
213
214    /// Returns true if this address is of a unix socket (for local interprocess communication),
215    /// i.e. it is from the `AF_UNIX` family, false otherwise.
216    pub fn is_unix(&self) -> bool {
217        self.storage.ss_family == AF_UNIX as sa_family_t
218    }
219
220    /// Returns this address as a `SocketAddr` if it is in the `AF_INET` (IPv4)
221    /// or `AF_INET6` (IPv6) family, otherwise returns `None`.
222    pub fn as_socket(&self) -> Option<SocketAddr> {
223        if self.storage.ss_family == AF_INET as sa_family_t {
224            // SAFETY: if the `ss_family` field is `AF_INET` then storage must
225            // be a `sockaddr_in`.
226            let addr = unsafe { &*(ptr::addr_of!(self.storage).cast::<sockaddr_in>()) };
227            let ip = crate::sys::from_in_addr(addr.sin_addr);
228            let port = u16::from_be(addr.sin_port);
229            Some(SocketAddr::V4(SocketAddrV4::new(ip, port)))
230        } else if self.storage.ss_family == AF_INET6 as sa_family_t {
231            // SAFETY: if the `ss_family` field is `AF_INET6` then storage must
232            // be a `sockaddr_in6`.
233            let addr = unsafe { &*(ptr::addr_of!(self.storage).cast::<sockaddr_in6>()) };
234            let ip = crate::sys::from_in6_addr(addr.sin6_addr);
235            let port = u16::from_be(addr.sin6_port);
236            Some(SocketAddr::V6(SocketAddrV6::new(
237                ip,
238                port,
239                addr.sin6_flowinfo,
240                #[cfg(unix)]
241                addr.sin6_scope_id,
242                #[cfg(windows)]
243                unsafe {
244                    addr.Anonymous.sin6_scope_id
245                },
246            )))
247        } else {
248            None
249        }
250    }
251
252    /// Returns this address as a [`SocketAddrV4`] if it is in the `AF_INET`
253    /// family.
254    pub fn as_socket_ipv4(&self) -> Option<SocketAddrV4> {
255        match self.as_socket() {
256            Some(SocketAddr::V4(addr)) => Some(addr),
257            _ => None,
258        }
259    }
260
261    /// Returns this address as a [`SocketAddrV6`] if it is in the `AF_INET6`
262    /// family.
263    pub fn as_socket_ipv6(&self) -> Option<SocketAddrV6> {
264        match self.as_socket() {
265            Some(SocketAddr::V6(addr)) => Some(addr),
266            _ => None,
267        }
268    }
269
270    /// Returns the initialised storage bytes.
271    fn as_bytes(&self) -> &[u8] {
272        // SAFETY: `self.storage` is a C struct which can always be treated a
273        // slice of bytes. Futhermore we ensure we don't read any unitialised
274        // bytes by using `self.len`.
275        unsafe { std::slice::from_raw_parts(self.as_ptr().cast(), self.len as usize) }
276    }
277}
278
279impl From<SocketAddr> for SockAddr {
280    fn from(addr: SocketAddr) -> SockAddr {
281        match addr {
282            SocketAddr::V4(addr) => addr.into(),
283            SocketAddr::V6(addr) => addr.into(),
284        }
285    }
286}
287
288impl From<SocketAddrV4> for SockAddr {
289    fn from(addr: SocketAddrV4) -> SockAddr {
290        // SAFETY: a `sockaddr_storage` of all zeros is valid.
291        let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
292        let len = {
293            let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<sockaddr_in>() };
294            storage.sin_family = AF_INET as sa_family_t;
295            storage.sin_port = addr.port().to_be();
296            storage.sin_addr = crate::sys::to_in_addr(addr.ip());
297            storage.sin_zero = Default::default();
298            mem::size_of::<sockaddr_in>() as socklen_t
299        };
300        #[cfg(any(
301            target_os = "dragonfly",
302            target_os = "freebsd",
303            target_os = "haiku",
304            target_os = "hermit",
305            target_os = "ios",
306            target_os = "macos",
307            target_os = "netbsd",
308            target_os = "nto",
309            target_os = "openbsd",
310            target_os = "tvos",
311            target_os = "vxworks",
312            target_os = "watchos",
313        ))]
314        {
315            storage.ss_len = len as u8;
316        }
317        SockAddr { storage, len }
318    }
319}
320
321impl From<SocketAddrV6> for SockAddr {
322    fn from(addr: SocketAddrV6) -> SockAddr {
323        // SAFETY: a `sockaddr_storage` of all zeros is valid.
324        let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
325        let len = {
326            let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<sockaddr_in6>() };
327            storage.sin6_family = AF_INET6 as sa_family_t;
328            storage.sin6_port = addr.port().to_be();
329            storage.sin6_addr = crate::sys::to_in6_addr(addr.ip());
330            storage.sin6_flowinfo = addr.flowinfo();
331            #[cfg(unix)]
332            {
333                storage.sin6_scope_id = addr.scope_id();
334            }
335            #[cfg(windows)]
336            {
337                storage.Anonymous = SOCKADDR_IN6_0 {
338                    sin6_scope_id: addr.scope_id(),
339                };
340            }
341            mem::size_of::<sockaddr_in6>() as socklen_t
342        };
343        #[cfg(any(
344            target_os = "dragonfly",
345            target_os = "freebsd",
346            target_os = "haiku",
347            target_os = "hermit",
348            target_os = "ios",
349            target_os = "macos",
350            target_os = "netbsd",
351            target_os = "nto",
352            target_os = "openbsd",
353            target_os = "tvos",
354            target_os = "vxworks",
355            target_os = "watchos",
356        ))]
357        {
358            storage.ss_len = len as u8;
359        }
360        SockAddr { storage, len }
361    }
362}
363
364impl fmt::Debug for SockAddr {
365    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
366        let mut f = fmt.debug_struct("SockAddr");
367        #[cfg(any(
368            target_os = "dragonfly",
369            target_os = "freebsd",
370            target_os = "haiku",
371            target_os = "hermit",
372            target_os = "ios",
373            target_os = "macos",
374            target_os = "netbsd",
375            target_os = "nto",
376            target_os = "openbsd",
377            target_os = "tvos",
378            target_os = "vxworks",
379            target_os = "watchos",
380        ))]
381        f.field("ss_len", &self.storage.ss_len);
382        f.field("ss_family", &self.storage.ss_family)
383            .field("len", &self.len)
384            .finish()
385    }
386}
387
388impl PartialEq for SockAddr {
389    fn eq(&self, other: &Self) -> bool {
390        self.as_bytes() == other.as_bytes()
391    }
392}
393
394impl Eq for SockAddr {}
395
396impl Hash for SockAddr {
397    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
398        self.as_bytes().hash(state);
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ipv4() {
        use std::net::Ipv4Addr;
        let std = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
        let addr = SockAddr::from(std);
        assert!(addr.is_ipv4());
        assert!(!addr.is_ipv6());
        assert!(!addr.is_unix());
        assert_eq!(addr.family(), AF_INET as sa_family_t);
        assert_eq!(addr.domain(), Domain::IPV4);
        assert_eq!(addr.len(), size_of::<sockaddr_in>() as socklen_t);
        assert_eq!(addr.as_socket(), Some(SocketAddr::V4(std)));
        assert_eq!(addr.as_socket_ipv4(), Some(std));
        assert!(addr.as_socket_ipv6().is_none());

        let addr = SockAddr::from(SocketAddr::from(std));
        assert_eq!(addr.family(), AF_INET as sa_family_t);
        assert_eq!(addr.len(), size_of::<sockaddr_in>() as socklen_t);
        assert_eq!(addr.as_socket(), Some(SocketAddr::V4(std)));
        assert_eq!(addr.as_socket_ipv4(), Some(std));
        assert!(addr.as_socket_ipv6().is_none());
        #[cfg(unix)]
        {
            assert!(addr.as_pathname().is_none());
            assert!(addr.as_abstract_namespace().is_none());
        }
    }

    #[test]
    fn ipv6() {
        use std::net::Ipv6Addr;
        let std = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
        let addr = SockAddr::from(std);
        assert!(addr.is_ipv6());
        assert!(!addr.is_ipv4());
        assert!(!addr.is_unix());
        assert_eq!(addr.family(), AF_INET6 as sa_family_t);
        assert_eq!(addr.domain(), Domain::IPV6);
        assert_eq!(addr.len(), size_of::<sockaddr_in6>() as socklen_t);
        assert_eq!(addr.as_socket(), Some(SocketAddr::V6(std)));
        assert!(addr.as_socket_ipv4().is_none());
        assert_eq!(addr.as_socket_ipv6(), Some(std));

        let addr = SockAddr::from(SocketAddr::from(std));
        assert_eq!(addr.family(), AF_INET6 as sa_family_t);
        assert_eq!(addr.len(), size_of::<sockaddr_in6>() as socklen_t);
        assert_eq!(addr.as_socket(), Some(SocketAddr::V6(std)));
        assert!(addr.as_socket_ipv4().is_none());
        assert_eq!(addr.as_socket_ipv6(), Some(std));
        #[cfg(unix)]
        {
            assert!(addr.as_pathname().is_none());
            assert!(addr.as_abstract_namespace().is_none());
        }
    }

    /// `SockAddr::empty` yields a zeroed, full-length storage that is neither
    /// an IP nor a unix address.
    #[test]
    fn empty() {
        let addr = SockAddr::empty();
        assert_eq!(addr.len(), size_of::<sockaddr_storage>() as socklen_t);
        assert_eq!(addr.family(), 0);
        assert!(!addr.is_ipv4());
        assert!(!addr.is_ipv6());
        assert!(!addr.is_unix());
        assert!(addr.as_socket().is_none());
        assert!(addr.as_socket_ipv4().is_none());
        assert!(addr.as_socket_ipv6().is_none());
        assert_eq!(addr, SockAddr::empty());
    }

    #[test]
    fn ipv4_eq() {
        use std::net::Ipv4Addr;

        let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
        let std2 = SocketAddrV4::new(Ipv4Addr::new(5, 6, 7, 8), 8765);

        test_eq(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );
    }

    #[test]
    fn ipv4_hash() {
        use std::net::Ipv4Addr;

        let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
        let std2 = SocketAddrV4::new(Ipv4Addr::new(5, 6, 7, 8), 8765);

        test_hash(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );
    }

    #[test]
    fn ipv6_eq() {
        use std::net::Ipv6Addr;

        let std1 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
        let std2 = SocketAddrV6::new(Ipv6Addr::new(3, 4, 5, 6, 7, 8, 9, 0), 7654, 13, 14);

        test_eq(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );
    }

    #[test]
    fn ipv6_hash() {
        use std::net::Ipv6Addr;

        let std1 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);
        let std2 = SocketAddrV6::new(Ipv6Addr::new(3, 4, 5, 6, 7, 8, 9, 0), 7654, 13, 14);

        test_hash(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );
    }

    #[test]
    fn ipv4_ipv6_eq() {
        use std::net::Ipv4Addr;
        use std::net::Ipv6Addr;

        let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
        let std2 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);

        test_eq(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );

        test_eq(
            SockAddr::from(std2),
            SockAddr::from(std2),
            SockAddr::from(std1),
        );
    }

    #[test]
    fn ipv4_ipv6_hash() {
        use std::net::Ipv4Addr;
        use std::net::Ipv6Addr;

        let std1 = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 9876);
        let std2 = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 9876, 11, 12);

        test_hash(
            SockAddr::from(std1),
            SockAddr::from(std1),
            SockAddr::from(std2),
        );

        test_hash(
            SockAddr::from(std2),
            SockAddr::from(std2),
            SockAddr::from(std1),
        );
    }

    /// Checks reflexivity and symmetry of `==` for `a0`/`a1`, and that `b`
    /// differs from `a0` in both directions.
    #[allow(clippy::eq_op)] // allow a0 == a0 check
    fn test_eq(a0: SockAddr, a1: SockAddr, b: SockAddr) {
        assert!(a0 == a0);
        assert!(a0 == a1);
        assert!(a1 == a0);
        assert!(a0 != b);
        assert!(b != a0);
    }

    /// Checks that hashing is deterministic, that equal addresses hash alike,
    /// and that the distinct address `b` hashes differently from `a0`.
    fn test_hash(a0: SockAddr, a1: SockAddr, b: SockAddr) {
        assert!(calculate_hash(&a0) == calculate_hash(&a0));
        assert!(calculate_hash(&a0) == calculate_hash(&a1));
        // Unequal values may in principle share a hash; for the inputs used in
        // these tests `a0` and `b` are known to hash differently.
        assert!(calculate_hash(&a0) != calculate_hash(&b));
    }

    /// Hashes `x` with a `DefaultHasher` created via `DefaultHasher::new()`.
    fn calculate_hash(x: &SockAddr) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut hasher = DefaultHasher::new();
        x.hash(&mut hasher);
        hasher.finish()
    }
}