systemd_socket/
lib.rs

1//! A convenience crate for optionally supporting systemd socket activation.
2//! 
3//! ## About
4//!
//! **Important:** for several reasons (see the Soundness section below) it is recommended to call
//! the [`init`] function at the start of your program!
7//! 
8//! The goal of this crate is to make socket activation with systemd in your project trivial.
//! It provides a replacement for `std::net::SocketAddr` that parses the bind address from a string
//! just like the `std` type does, but additionally accepts the `systemd://socket_name` format,
//! which tells it to use systemd socket activation with the given socket name.
11//! Then it provides a method to bind the address which will return the socket from systemd if available.
12//!
//! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flags.
14//! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
15//! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
16//!
17//! You also don't need to worry about conditional compilation to ensure OS compatibility.
18//! This crate handles that for you by disabling systemd on non-linux systems.
19//!
20//! Further, the crate also provides methods for binding `tokio` 1.0, 0.2, 0.3, and `async_std` sockets if the appropriate features are
21//! activated.
22//! 
23//! ## Example
24//! 
25//! ```no_run
26//! use systemd_socket::SocketAddr;
27//! use std::convert::TryFrom;
28//! use std::io::Write;
29//! 
30//! systemd_socket::init().expect("Failed to initialize systemd sockets");
31//! let mut args = std::env::args_os();
32//! let program_name = args.next().expect("unknown program name");
33//! let socket_addr = args.next().expect("missing socket address");
34//! let socket_addr = SocketAddr::try_from(socket_addr).expect("failed to parse socket address");
35//! let socket = socket_addr.bind().expect("failed to bind socket");
36//!
37//! loop {
38//!     let _ = socket
39//!     .accept()
40//!     .expect("failed to accept connection")
41//!     .0
42//!     .write_all(b"Hello world!")
43//!     .map_err(|err| eprintln!("Failed to send {}", err));
44//! }
45//! ```
46//!
47//! ## Features
48//!
49//! * `enable_systemd` - on by default, the existence of this feature can allow your users to turn
50//!   off systemd support if they don't need it. Note that it's already disabled on non-linux
51//!   systems, so you don't need to care about that.
52//! * `serde` - implements `serde::Deserialize` for `SocketAddr`
53//! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
54//! * `tokio` - adds `bind_tokio` method to `SocketAddr` (tokio 1.0)
55//! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr`
56//! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
//! * `async-std` - adds `bind_async_std` method to `SocketAddr`
58//!
59//! ## Soundness
60//!
61//! The systemd file descriptors are transferred using environment variables and since they are
62//! file descriptors, they should have move semantics. However environment variables in Rust do not
63//! have move semantics and even modifying them is very dangerous.
64//!
//! Because of this, the crate's safe initialization only succeeds when there's only one thread
//! running (the `unsafe` [`init_unprotected`] function lifts this restriction).
66//! However that still doesn't prevent all possible problems: if some other code closes file
67//! descriptors stored in those environment variables you can get an invalid socket.
68//!
//! This situation is obviously ridiculous because there shouldn't be a reason to use another
//! library to do the same thing. It could also be argued that whichever code doesn't clear the
//! environment variables is broken (even if understandably so) and that it's not a fault of this
//! library.
72//!
73//! ## MSRV
74//!
75//! This crate must always compile with the latest Rust available in the latest Debian stable.
76//! That is currently Rust 1.48.0. (Debian 11 - Bullseye)
77
78#![cfg_attr(docsrs, feature(doc_auto_cfg))]
79
80#![deny(missing_docs)]
81
82pub mod error;
83mod resolv_addr;
84
85use std::convert::{TryFrom, TryInto};
86use std::fmt;
87use std::ffi::{OsStr, OsString};
88use crate::error::*;
89use crate::resolv_addr::ResolvAddr;
90
91#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
92use std::convert::Infallible as Never;
93
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
pub(crate) mod systemd_sockets {
    //! Process-global storage for the file descriptors passed in by systemd.
    //!
    //! The descriptors are received at most once (see `SYSTEMD_SOCKETS`) and every named socket
    //! can be taken out exactly once.

    use std::fmt;
    use std::sync::Mutex;
    use libsystemd::activation::FileDescriptor;
    use libsystemd::errors::SdError as LibSystemdError;

    /// Initialization failure reported by [`take`].
    ///
    /// It refers to the error stored in the global cell, so every later call of `take` reports
    /// the same failure.
    #[derive(Debug)]
    pub(crate) struct Error(&'static Mutex<InitError>);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            use std::error::Error as _;

            // The stored error is never modified, so a lock poisoned by an earlier panic (e.g. a
            // panicking writer while the guard below was held) still guards a valid value.
            // Recovering it keeps this `Display` impl from panicking on every later call.
            let guard = self.0.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            fmt::Display::fmt(&*guard, f)?;
            // The source chain borrows from the guard, so it is printed inline here instead of
            // being exposed through `std::error::Error::source`.
            let mut source_opt = guard.source();
            while let Some(source) = source_opt {
                write!(f, ": {}", source)?;
                source_opt = source.source();
            }
            Ok(())
        }
    }

    /// A received descriptor: `Ok` if it was usable as a TCP listener, `Err(())` otherwise.
    type StoredSocket = Result<Socket, ()>;

    // No `source()`: it would have to borrow from the mutex guard, which can't outlive the call.
    impl std::error::Error for Error {}

    /// Explicitly receives the systemd descriptors.
    ///
    /// If none of the `LISTEN_*` environment variables is set, an empty map is stored without any
    /// further checks. If the global cell already holds a value (including a failure recorded by
    /// an earlier implicit [`take`]), this returns `Ok(())` without doing anything.
    ///
    /// # Safety
    ///
    /// With `protected == false` the single-thread check is skipped and the `LISTEN_*` variables
    /// are left in the environment; the caller must ensure no other code uses (or closes) the
    /// descriptors they reference.
    pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> {
        SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop)
    }

    /// Removes the socket called `name` from the global storage.
    ///
    /// On first use this initializes the storage in protected mode, which requires that only one
    /// thread is running. Returns `Ok(None)` if no descriptor with that name was received (or it
    /// was already taken).
    pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> {
        let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new));
        match sockets {
            Ok(sockets) => Ok(sockets.take(name)),
            Err(error) => Err(Error(error)),
        }
    }

    /// Reasons why receiving the systemd descriptors failed.
    #[derive(Debug)]
    pub(crate) enum InitError {
        /// `/proc/self/status` could not be opened.
        OpenStatus(std::io::Error),
        /// `/proc/self/status` could not be read.
        ReadStatus(std::io::Error),
        /// `/proc/self/status` has no `Threads:` line.
        ThreadCountNotFound,
        /// More than one thread is running, so protected initialization is refused.
        MultipleThreads,
        /// `libsystemd` failed to receive the descriptors.
        LibSystemd(LibSystemdError),
    }

    impl From<LibSystemdError> for InitError {
        fn from(value: LibSystemdError) -> Self {
            Self::LibSystemd(value)
        }
    }

    impl fmt::Display for InitError {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                Self::OpenStatus(_) => write!(f, "failed to open /proc/self/status"),
                Self::ReadStatus(_) => write!(f, "failed to read /proc/self/status"),
                Self::ThreadCountNotFound => write!(f, "/proc/self/status doesn't contain Threads entry"),
                Self::MultipleThreads => write!(f, "there is more than one thread running"),
                // We have nothing to say about the error, let's flatten it
                Self::LibSystemd(error) => fmt::Display::fmt(error, f),
            }
        }
    }

    impl std::error::Error for InitError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self {
                Self::OpenStatus(error) => Some(error),
                Self::ReadStatus(error) => Some(error),
                Self::ThreadCountNotFound => None,
                Self::MultipleThreads => None,
                // Flattened in `Display`, so skip straight to the inner source.
                Self::LibSystemd(error) => error.source(),
            }
        }
    }

    /// A descriptor received from systemd that this crate knows how to use.
    pub(crate) enum Socket {
        TcpListener(std::net::TcpListener),
    }

    impl std::convert::TryFrom<FileDescriptor> for Socket {
        type Error = ();

        fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> {
            use libsystemd::activation::IsType;
            use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd};

            /// Best-effort: sets `FD_CLOEXEC` so the descriptor is not inherited across `exec`.
            fn set_cloexec(fd: std::os::unix::io::RawFd) {
                // SAFETY: The function is a harmless syscall
                let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
                if flags != -1 && flags & libc::FD_CLOEXEC == 0 {
                    // We ignore errors, since the FD is still usable
                    // SAFETY: socket is definitely a valid file descriptor and setting CLOEXEC is
                    // a sound operation.
                    unsafe {
                        libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC);
                    }
                }
            }

            if value.is_inet() {
                // SAFETY: FileDescriptor is obtained from systemd, so it should be valid.
                let socket = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) };
                set_cloexec(socket.as_raw_fd());
                Ok(Socket::TcpListener(socket))
            } else {
                // The unusable descriptor is not closed (its raw fd is released, not dropped);
                // it is only marked CLOEXEC so it doesn't leak into exec'd children.
                set_cloexec(value.into_raw_fd());
                Err(())
            }
        }
    }

    /// Received descriptors keyed by their systemd name.
    struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>);

    impl SystemdSockets {
        /// Safe constructor that always uses protected mode.
        ///
        /// `explicit` is `true` when called on behalf of [`init`]; see [`SystemdSockets::new`].
        fn new_protected(explicit: bool) -> Result<Self, InitError> {
            // SAFETY: protected mode checks for a single thread and unsets the environment
            // variables, which makes it sound to call in any situation.
            unsafe { Self::new(true, explicit) }
        }

        /// Receives the descriptors from systemd.
        ///
        /// When `explicit` is `true` and none of the `LISTEN_*` variables is set, the process is
        /// treated as not socket-activated and an empty map is returned without a thread check.
        ///
        /// # Safety
        ///
        /// With `protected == false` the single-thread check is skipped and the `LISTEN_*`
        /// variables are left in the environment; the caller must ensure nothing else uses the
        /// descriptors they reference.
        unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> {
            use std::convert::TryFrom;

            let unset = |name: &str| std::env::var_os(name).is_none();
            if explicit && unset("LISTEN_PID") && unset("LISTEN_FDS") && unset("LISTEN_FDNAMES") {
                // Systemd is not used - make the map empty
                return Ok(SystemdSockets(Mutex::new(Default::default())));
            }

            if protected { Self::check_single_thread()? }
            // Unsetting the environment variables (`protected == true`) is what keeps other code
            // from picking up the same descriptors; only the unsafe, unprotected mode skips it.
            let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?
                .into_iter()
                .map(|(fd, name)| (name, Socket::try_from(fd)))
                .collect();
            Ok(SystemdSockets(Mutex::new(map)))
        }

        /// Fails unless `/proc/self/status` reports exactly one running thread.
        fn check_single_thread() -> Result<(), InitError> {
            use std::io::BufRead;

            let status = std::fs::File::open("/proc/self/status").map_err(InitError::OpenStatus)?;
            let mut status = std::io::BufReader::new(status);
            let mut line = String::new();
            loop {
                if status.read_line(&mut line).map_err(InitError::ReadStatus)? == 0 {
                    return Err(InitError::ThreadCountNotFound);
                }
                if let Some(threads) = line.strip_prefix("Threads:") {
                    if threads.trim() == "1" {
                        break;
                    } else {
                        return Err(InitError::MultipleThreads);
                    }
                }
                line.clear();
            }
            Ok(())
        }

        /// Moves the named socket out of the map.
        fn take(&self, name: &str) -> Option<StoredSocket> {
            // The socket MUST be removed (not cloned) so that it can only ever be owned once.
            self.0.lock().expect("poisoned mutex").remove(name)
        }
    }

    /// Lazily initialized global; a failed initialization is stored (behind a mutex) in place of
    /// the sockets.
    static SYSTEMD_SOCKETS: once_cell::sync::OnceCell<Result<SystemdSockets, Mutex<InitError>>> = once_cell::sync::OnceCell::new();
}
269
270/// Socket address that can be an ordinary address or a systemd socket
271///
272/// This is the core type of this crate that abstracts possible addresses.
273/// It can be (fallibly) converted from various types of strings or deserialized with `serde`.
274/// After it's created, it can be bound as `TcpListener` from `std` or even `tokio` or `async_std`
275/// if the appropriate feature is enabled.
276///
277/// Optional dependencies on `parse_arg` and `serde` make it trivial to use with
278/// [`configure_me`](https://crates.io/crates/configure_me).
279#[derive(Debug)]
280#[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))]
281pub struct SocketAddr(SocketAddrInner);
282
283impl SocketAddr {
284    /// Creates SocketAddr from systemd name directly, without requiring `systemd://` prefix.
285    ///
286    /// Always fails with systemd unsupported error if systemd is not supported.
287    pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> {
288        Self::inner_from_systemd_name(name.into(), false)
289    }
290
291    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
292    fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> {
293        let real_systemd_name = if prefixed {
294            &name[SYSTEMD_PREFIX.len()..]
295        } else {
296            &name
297        };
298
299        let name_len = real_systemd_name.len();
300        match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
301            None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))),
302            None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))),
303            None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()),
304            Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()),
305        }
306    }
307
308
309    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
310    fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> {
311        Err(ParseError(ParseErrorInner::SystemdUnsupported(name)))
312    }
313
314    /// Creates `std::net::TcpListener`
315    ///
316    /// This method either `binds` the socket, if the address was provided or uses systemd socket
317    /// if the socket name was provided.
318    pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
319        match self.0 {
320            SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
321                Ok(socket) => Ok(socket),
322                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
323            },
324            SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
325                Ok(socket) => Ok(socket),
326                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
327            },
328            SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket),
329            SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket),
330        }
331    }
332
333    /// Creates `tokio::net::TcpListener`
334    ///
335    /// To be specific, it binds the socket or converts systemd socket to `tokio` 1.0 socket.
336    ///
337    /// This method either `binds` the socket, if the address was provided or uses systemd socket
338    /// if the socket name was provided.
339    #[cfg(feature = "tokio")]
340    pub async fn bind_tokio(self) -> Result<tokio::net::TcpListener, TokioBindError> {
341        match self.0 {
342            SocketAddrInner::Ordinary(addr) => match tokio::net::TcpListener::bind(addr).await {
343                Ok(socket) => Ok(socket),
344                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
345            },
346            SocketAddrInner::WithHostname(addr) => match tokio::net::TcpListener::bind(addr.as_str()).await {
347                Ok(socket) => Ok(socket),
348                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
349            },
350            SocketAddrInner::Systemd(socket_name) => {
351                let (socket, addr) = Self::get_systemd(socket_name, true)?;
352                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
353            },
354            SocketAddrInner::SystemdNoPrefix(socket_name) => {
355                let (socket, addr) = Self::get_systemd(socket_name, false)?;
356                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
357            },
358        }
359    }
360
361    /// Creates `tokio::net::TcpListener`
362    ///
363    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket.
364    ///
365    /// This method either `binds` the socket, if the address was provided or uses systemd socket
366    /// if the socket name was provided.
367    #[cfg(feature = "tokio_0_2")]
368    pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
369        match self.0 {
370            SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await {
371                Ok(socket) => Ok(socket),
372                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
373            },
374            SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
375                Ok(socket) => Ok(socket),
376                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
377            },
378            SocketAddrInner::Systemd(socket_name) => {
379                let (socket, addr) = Self::get_systemd(socket_name, true)?;
380                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
381            },
382            SocketAddrInner::SystemdNoPrefix(socket_name) => {
383                let (socket, addr) = Self::get_systemd(socket_name, false)?;
384                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
385            },
386        }
387    }
388
389    /// Creates `tokio::net::TcpListener`
390    ///
391    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket.
392    ///
393    /// This method either `binds` the socket, if the address was provided or uses systemd socket
394    /// if the socket name was provided.
395    #[cfg(feature = "tokio_0_3")]
396    pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
397        match self.0 {
398            SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await {
399                Ok(socket) => Ok(socket),
400                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
401            },
402            SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
403                Ok(socket) => Ok(socket),
404                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
405            },
406            SocketAddrInner::Systemd(socket_name) => {
407                let (socket, addr) = Self::get_systemd(socket_name, true)?;
408                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
409            },
410            SocketAddrInner::SystemdNoPrefix(socket_name) => {
411                let (socket, addr) = Self::get_systemd(socket_name, false)?;
412                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
413            },
414        }
415    }
416
417    /// Creates `async_std::net::TcpListener`
418    ///
419    /// To be specific, it binds the socket or converts systemd socket to `async_std` socket.
420    ///
421    /// This method either `binds` the socket, if the address was provided or uses systemd socket
422    /// if the socket name was provided.
423    #[cfg(feature = "async-std")]
424    pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
425        match self.0 {
426            SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await {
427                Ok(socket) => Ok(socket),
428                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
429            },
430            SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
431                Ok(socket) => Ok(socket),
432                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
433            },
434            SocketAddrInner::Systemd(socket_name) => {
435                let (socket, _) = Self::get_systemd(socket_name, true)?;
436                Ok(socket.into())
437            },
438            SocketAddrInner::SystemdNoPrefix(socket_name) => {
439                let (socket, _) = Self::get_systemd(socket_name, false)?;
440                Ok(socket.into())
441            },
442        }
443    }
444
445    // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
446    // rules.
447    fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
448        if string.starts_with(SYSTEMD_PREFIX) {
449            Self::inner_from_systemd_name(string.into(), true)
450        } else {
451            match string.parse() {
452                Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
453                Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
454            }
455        }
456    }
457
458    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
459    fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
460        use systemd_sockets::Socket;
461
462        let real_systemd_name = if prefixed {
463            &socket_name[SYSTEMD_PREFIX.len()..]
464        } else {
465            &socket_name
466        };
467
468        let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
469        // match instead of combinators to avoid cloning socket_name
470        match socket {
471            Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))),
472            Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
473            None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
474        }
475    }
476
477    // This approach makes the rest of the code much simpler as it doesn't require sprinkling it
478    // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute.
479    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
480    fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
481        match socket_name {}
482    }
483}
484
485/// Initializes the library while there's only a single thread.
486///
487/// Unfortunately, this library has to be initialized and, for soundness, this initialization must
488/// happen when no other threads are running. This is attempted automatically when trying to bind a
489/// systemd socket but at that time there may be other threads running and error reporting also
490/// faces some restrictions. This function provides better control over the initialization point
491/// and returns a more idiomatic error type.
492///
493/// You should generally call this at around the top of `main`, where no threads were created yet.
494/// While technically, you may spawn a thread and call this function after that thread terminated,
495/// this has the additional problem that the descriptors are still around, so if that thread (or the
496/// current one!) forks and execs the descriptors will leak into the child.
497#[inline]
498pub fn init() -> Result<(), error::InitError> {
499    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
500    {
501        // Calling with true is always sound
502        unsafe { systemd_sockets::init(true) }.map_err(error::InitError)
503    }
504    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
505    {
506        Ok(())
507    }
508}
509
510/// Initializes the library without protection against double close.
511///
512/// Unfortunately, this library has to be initialized and, because double closing file descriptors
513/// is unsound, the library has some protections against double close. However these protections
514/// come with the limitation that the library must be initailized with a single thread.
515///
516/// If for any reason you're unable to call `init` in a single thread at around the top of `main`
517/// (and this should be almost never) you may call this method if you've ensured that no other part
518/// of your codebase is operating on systemd-provided file descriptors stored in the environment 
519/// variables.
520///
521/// Note however that doing so uncovers another problem: if another thread forks and execs the
522/// systemd file descriptors will get passed into that program! In such case you somehow need to
523/// clean up the file descriptors yourself.
524pub unsafe fn init_unprotected() -> Result<(), error::InitError> {
525    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
526    {
527        systemd_sockets::init(false).map_err(error::InitError)
528    }
529    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
530    {
531        Ok(())
532    }
533}
534
535/// Displays the address in format that can be parsed again.
536///
537/// **Important: While I don't expect this impl to change, don't rely on it!**
538/// It should be used mostly for debugging/logging.
539impl fmt::Display for SocketAddr {
540    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
541        fmt::Display::fmt(&self.0, f)
542    }
543}
544
545impl fmt::Display for SocketAddrInner {
546    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
547        match self {
548            SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
549            SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
550            SocketAddrInner::SystemdNoPrefix(addr) => write!(f, "{}{}", SYSTEMD_PREFIX, addr),
551            SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
552        }
553    }
554}
555
556// PartialEq for testing, I'm not convinced it should be exposed
557#[derive(Debug, PartialEq)]
558enum SocketAddrInner {
559    Ordinary(std::net::SocketAddr),
560    WithHostname(resolv_addr::ResolvAddr),
561    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
562    Systemd(String),
563    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
564    #[allow(dead_code)]
565    Systemd(Never),
566    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
567    #[allow(dead_code)]
568    SystemdNoPrefix(String),
569    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
570    #[allow(dead_code)]
571    SystemdNoPrefix(Never),
572}
573
/// Scheme-like prefix marking an address string as a systemd socket name (`systemd://name`).
const SYSTEMD_PREFIX: &str = "systemd://";
575
576impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr {
577    fn from(value: (I, u16)) -> Self {
578        SocketAddr(SocketAddrInner::Ordinary(value.into()))
579    }
580}
581
582impl From<std::net::SocketAddr> for SocketAddr {
583    fn from(value: std::net::SocketAddr) -> Self {
584        SocketAddr(SocketAddrInner::Ordinary(value))
585    }
586}
587
588impl From<std::net::SocketAddrV4> for SocketAddr {
589    fn from(value: std::net::SocketAddrV4) -> Self {
590        SocketAddr(SocketAddrInner::Ordinary(value.into()))
591    }
592}
593
594impl From<std::net::SocketAddrV6> for SocketAddr {
595    fn from(value: std::net::SocketAddrV6) -> Self {
596        SocketAddr(SocketAddrInner::Ordinary(value.into()))
597    }
598}
599
600impl std::str::FromStr for SocketAddr {
601    type Err = ParseError;
602
603    fn from_str(s: &str) -> Result<Self, Self::Err> {
604        SocketAddr::try_from_generic(s)
605    }
606}
607
608impl<'a> TryFrom<&'a str> for SocketAddr {
609    type Error = ParseError;
610
611    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
612        SocketAddr::try_from_generic(s)
613    }
614}
615
616impl TryFrom<String> for SocketAddr {
617    type Error = ParseError;
618
619    fn try_from(s: String) -> Result<Self, Self::Error> {
620        SocketAddr::try_from_generic(s)
621    }
622}
623
624impl<'a> TryFrom<&'a OsStr> for SocketAddr {
625    type Error = ParseOsStrError;
626
627    fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> {
628        s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
629    }
630}
631
632impl TryFrom<OsString> for SocketAddr {
633    type Error = ParseOsStrError;
634
635    fn try_from(s: OsString) -> Result<Self, Self::Error> {
636        s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
637    }
638}
639
#[cfg(feature = "serde")]
impl<'a> TryFrom<serde_str_helpers::DeserBorrowStr<'a>> for SocketAddr {
    type Error = ParseError;

    /// Conversion used by the `serde(try_from = ...)` deserialization of `SocketAddr`.
    fn try_from(text: serde_str_helpers::DeserBorrowStr<'a>) -> Result<Self, Self::Error> {
        Self::try_from_generic(text)
    }
}
648
#[cfg(feature = "parse_arg")]
impl parse_arg::ParseArg for SocketAddr {
    type Error = ParseOsStrError;

    /// Describes the accepted format: std's socket address description plus the systemd form.
    fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result {
        std::net::SocketAddr::describe_type(&mut writer)?;
        write!(writer, " or a systemd socket name prefixed with systemd://")
    }

    fn parse_arg(arg: &OsStr) -> Result<Self, Self::Error> {
        <Self as TryFrom<&OsStr>>::try_from(arg)
    }

    fn parse_owned_arg(arg: OsString) -> Result<Self, Self::Error> {
        <Self as TryFrom<OsString>>::try_from(arg)
    }
}
666
#[cfg(test)]
mod tests {
    use super::{SocketAddr, SocketAddrInner};

    /// Longest systemd socket name (in bytes) the parser accepts.
    const MAX_NAME_LEN: usize = 255;

    #[test]
    fn parse_ordinary() {
        assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into()));
    }

    #[test]
    fn display_ordinary_roundtrip() {
        assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().to_string(), "127.0.0.1:42");
    }

    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn parse_systemd() {
        assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned()));
    }

    #[test]
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[should_panic]
    fn parse_systemd() {
        "systemd://foo".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn display_systemd_roundtrip() {
        assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().to_string(), "systemd://foo");
        // The unprefixed form gets the prefix back when displayed.
        assert_eq!(SocketAddr::from_systemd_name("foo").unwrap().to_string(), "systemd://foo");
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_control() {
        "systemd://foo\n".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_colon() {
        "systemd://foo:".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_non_ascii() {
        "systemd://fooá".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn parse_systemd_max_len() {
        // The prefix doesn't count towards the limit.
        format!("systemd://{}", "x".repeat(MAX_NAME_LEN)).parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_too_long() {
        format!("systemd://{}", "x".repeat(MAX_NAME_LEN + 1)).parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn no_prefix_parse_systemd() {
        SocketAddr::from_systemd_name("foo").unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_non_ascii() {
        SocketAddr::from_systemd_name("fooá").unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn no_prefix_parse_systemd_max_len() {
        SocketAddr::from_systemd_name("x".repeat(MAX_NAME_LEN)).unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_too_long() {
        SocketAddr::from_systemd_name("x".repeat(MAX_NAME_LEN + 1)).unwrap();
    }
}