// systemd_connector/socket.rs

//! Access sockets passed to this process by systemd (socket activation).
2
3use std::fs::File;
4use std::io;
5use std::net::TcpListener;
6use std::os::unix::prelude::*;
7use std::process;
8
9use thiserror::Error;
10
/// Descriptor number of the first socket systemd passes in; subsequent
/// sockets follow consecutively (matches `SD_LISTEN_FDS_START` in sd-daemon.h).
const SD_FD_OFFSET: RawFd = 3;
/// Environment variable holding the number of passed descriptors.
const LISTEN_FDS: &str = "LISTEN_FDS";
/// Environment variable holding the colon-separated descriptor names.
const LISTEN_FDNAMES: &str = "LISTEN_FDNAMES";
/// Environment variable holding the PID the descriptors are intended for.
const LISTEN_PID: &str = "LISTEN_PID";
15
16/// Errors that can occur when trying to access systemd-owned sockets
17#[derive(Debug, Error)]
18pub enum SocketError {
19    /// An IO error occurred communicating with the socket
20    #[error("{}", .0)]
21    IO(#[from] io::Error),
22
23    /// The PID that systemd gave us is not our PID
24    #[error("PID={0} but ${}={1}", LISTEN_PID)]
25    WrongPID(u32, String),
26
27    /// The file descriptor that systemd gave us is not a socket
28    #[error("file descriptor {} is not a socket", .0)]
29    NotSocket(RawFd),
30
31    /// Missing a systemd variable
32    #[error("Missing ${0} variable")]
33    MissingVar(&'static str),
34
35    /// Invalid value for a systemd variable
36    #[error("Invalid ${0}={1}")]
37    InvalidVar(&'static str, String),
38}
39
40pub(crate) fn var(name: &'static str) -> Result<String, SocketError> {
41    match std::env::var(name) {
42        Ok(value) => Ok(value),
43        Err(std::env::VarError::NotPresent) => Err(SocketError::MissingVar(name)),
44        Err(std::env::VarError::NotUnicode(_)) => Err(SocketError::MissingVar(name)),
45    }
46}
47
48/// Get the sockets that systemd has passed to us as file descriptors
49pub fn sockets() -> Result<Vec<SystemDSocket>, SocketError> {
50    let listen_pid = var(LISTEN_PID);
51    let listen_fds = var(LISTEN_FDS);
52    let listen_fd_names = var(LISTEN_FDNAMES).ok();
53
54    construct_sockets(
55        listen_fds?.as_str(),
56        listen_fd_names.as_deref(),
57        listen_pid?.as_str(),
58    )
59}
60
61fn construct_sockets(
62    listen_fds: &str,
63    listen_fd_names: Option<&str>,
64    listen_pid: &str,
65) -> Result<Vec<SystemDSocket>, SocketError> {
66    let pid = listen_pid
67        .parse::<u32>()
68        .map_err(|_| SocketError::InvalidVar(LISTEN_PID, listen_pid.into()))?;
69
70    if process::id() != pid {
71        return Err(SocketError::WrongPID(process::id(), listen_pid.into()));
72    }
73
74    let n = listen_fds
75        .parse::<usize>()
76        .map_err(|_| SocketError::InvalidVar(LISTEN_FDS, listen_fds.into()))?;
77
78    if let Some(names_value) = listen_fd_names {
79        let names: Vec<_> = names_value.split(':').collect();
80
81        if names.len() == n {
82            return Ok((SD_FD_OFFSET..)
83                .take(n)
84                .zip(names)
85                .map(|(fd, name)| SystemDSocket::new(name, fd))
86                .collect());
87        } else if !names.is_empty() {
88            tracing::warn!("Invalid ${}={}", LISTEN_FDNAMES, names_value);
89        };
90    };
91
92    Ok((SD_FD_OFFSET..)
93        .take(n)
94        .map(SystemDSocket::unnamed)
95        .collect())
96}
97
98/// Represents a socket that systemd has passed to us
99#[derive(Debug)]
100pub struct SystemDSocket {
101    name: Option<String>,
102    fd: RawFd,
103}
104
105impl SystemDSocket {
106    fn new<S: Into<String>>(name: S, fd: RawFd) -> Self {
107        Self {
108            name: Some(name.into()),
109            fd,
110        }
111    }
112
113    fn unnamed(fd: RawFd) -> Self {
114        Self { name: None, fd }
115    }
116
117    /// Get the name of the socket, if it has one.
118    ///
119    /// Systemd can provide names in environemnt variables, but it is not required
120    /// to. If the socket does not have a name, this will return `None`.
121    pub fn name(&self) -> Option<&str> {
122        self.name.as_deref()
123    }
124
125    /// Convert this socket into a `TcpListener`
126    pub fn listener(self) -> Result<TcpListener, SocketError> {
127        // Safety: This is how systemd rolls
128        // See: sd_listen_fds(3), the c API for accessing systemd sockets
129        let file = unsafe { File::from_raw_fd(self.fd) };
130        let metadata = file.metadata()?;
131        if !metadata.file_type().is_socket() {
132            return Err(SocketError::NotSocket(file.into_raw_fd()));
133        }
134
135        //Todo: We could manually check that this is an INET socket
136        // here, so that we don't listen on some arbitrary socket?
137
138        // Safety: Above, we know that the FD is one we should be reading,
139        // and we just checked that the socket was one which is listening
140        // over tcp;
141        let listener = unsafe { TcpListener::from_raw_fd(file.into_raw_fd()) };
142        listener.set_nonblocking(true)?;
143        Ok(listener)
144    }
145}
146
147impl AsRawFd for SystemDSocket {
148    fn as_raw_fd(&self) -> RawFd {
149        self.fd
150    }
151}
152
153impl AsFd for SystemDSocket {
154    fn as_fd(&self) -> BorrowedFd<'_> {
155        unsafe { BorrowedFd::borrow_raw(self.fd) }
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// `$LISTEN_PID` value that matches the test process.
    fn our_pid() -> String {
        process::id().to_string()
    }

    #[test]
    fn parse_variables() {
        let sockets = construct_sockets("3", Some("alice:bob:charlie"), &our_pid()).unwrap();

        let names: Vec<_> = sockets.iter().map(|s| s.name().unwrap()).collect();
        assert_eq!(names, vec!["alice", "bob", "charlie"]);

        let fds: Vec<_> = sockets.iter().map(|s| s.fd).collect();
        assert_eq!(fds, vec![3, 4, 5]);
    }

    #[test]
    fn mismatched_names_fall_back_to_unnamed() {
        let sockets = construct_sockets("2", Some("alice"), &our_pid()).unwrap();

        assert!(sockets.iter().all(|s| s.name().is_none()));
        let fds: Vec<_> = sockets.iter().map(|s| s.fd).collect();
        assert_eq!(fds, vec![3, 4]);
    }

    #[test]
    fn missing_names_yield_unnamed_sockets() {
        let sockets = construct_sockets("1", None, &our_pid()).unwrap();

        assert_eq!(sockets.len(), 1);
        assert_eq!(sockets[0].name(), None);
        assert_eq!(sockets[0].fd, SD_FD_OFFSET);
    }

    #[test]
    fn zero_fds_yield_no_sockets() {
        let sockets = construct_sockets("0", None, &our_pid()).unwrap();
        assert!(sockets.is_empty());
    }

    #[test]
    fn rejects_other_pid() {
        let other = process::id().wrapping_add(1).to_string();
        match construct_sockets("1", None, &other) {
            Err(SocketError::WrongPID(pid, value)) => {
                assert_eq!(pid, process::id());
                assert_eq!(value, other);
            }
            unexpected => panic!("expected WrongPID, got {:?}", unexpected),
        }
    }

    #[test]
    fn rejects_unparsable_values() {
        match construct_sockets("three", None, &our_pid()) {
            Err(SocketError::InvalidVar(name, value)) => {
                assert_eq!(name, LISTEN_FDS);
                assert_eq!(value, "three");
            }
            unexpected => panic!("expected InvalidVar, got {:?}", unexpected),
        }

        match construct_sockets("1", None, "not-a-pid") {
            Err(SocketError::InvalidVar(name, value)) => {
                assert_eq!(name, LISTEN_PID);
                assert_eq!(value, "not-a-pid");
            }
            unexpected => panic!("expected InvalidVar, got {:?}", unexpected),
        }
    }
}