// systemd_connector/socket.rs
use std::fs::File;
use std::io;
use std::net::TcpListener;
use std::os::unix::prelude::*;
use std::process;

use thiserror::Error;

/// Descriptor number of the first socket passed in by systemd
/// (`SD_LISTEN_FDS_START`); descriptors 0-2 are stdin, stdout and stderr.
const SD_FD_OFFSET: RawFd = 3;
/// Environment variable holding the number of passed descriptors.
const LISTEN_FDS: &str = "LISTEN_FDS";
/// Environment variable holding a colon-separated name for each passed descriptor.
const LISTEN_FDNAMES: &str = "LISTEN_FDNAMES";
/// Environment variable holding the PID the descriptors are intended for.
const LISTEN_PID: &str = "LISTEN_PID";

16#[derive(Debug, Error)]
18pub enum SocketError {
19 #[error("{}", .0)]
21 IO(#[from] io::Error),
22
23 #[error("PID={0} but ${}={1}", LISTEN_PID)]
25 WrongPID(u32, String),
26
27 #[error("file descriptor {} is not a socket", .0)]
29 NotSocket(RawFd),
30
31 #[error("Missing ${0} variable")]
33 MissingVar(&'static str),
34
35 #[error("Invalid ${0}={1}")]
37 InvalidVar(&'static str, String),
38}
39
40pub(crate) fn var(name: &'static str) -> Result<String, SocketError> {
41 match std::env::var(name) {
42 Ok(value) => Ok(value),
43 Err(std::env::VarError::NotPresent) => Err(SocketError::MissingVar(name)),
44 Err(std::env::VarError::NotUnicode(_)) => Err(SocketError::MissingVar(name)),
45 }
46}
47
48pub fn sockets() -> Result<Vec<SystemDSocket>, SocketError> {
50 let listen_pid = var(LISTEN_PID);
51 let listen_fds = var(LISTEN_FDS);
52 let listen_fd_names = var(LISTEN_FDNAMES).ok();
53
54 construct_sockets(
55 listen_fds?.as_str(),
56 listen_fd_names.as_deref(),
57 listen_pid?.as_str(),
58 )
59}
60
61fn construct_sockets(
62 listen_fds: &str,
63 listen_fd_names: Option<&str>,
64 listen_pid: &str,
65) -> Result<Vec<SystemDSocket>, SocketError> {
66 let pid = listen_pid
67 .parse::<u32>()
68 .map_err(|_| SocketError::InvalidVar(LISTEN_PID, listen_pid.into()))?;
69
70 if process::id() != pid {
71 return Err(SocketError::WrongPID(process::id(), listen_pid.into()));
72 }
73
74 let n = listen_fds
75 .parse::<usize>()
76 .map_err(|_| SocketError::InvalidVar(LISTEN_FDS, listen_fds.into()))?;
77
78 if let Some(names_value) = listen_fd_names {
79 let names: Vec<_> = names_value.split(':').collect();
80
81 if names.len() == n {
82 return Ok((SD_FD_OFFSET..)
83 .take(n)
84 .zip(names)
85 .map(|(fd, name)| SystemDSocket::new(name, fd))
86 .collect());
87 } else if !names.is_empty() {
88 tracing::warn!("Invalid ${}={}", LISTEN_FDNAMES, names_value);
89 };
90 };
91
92 Ok((SD_FD_OFFSET..)
93 .take(n)
94 .map(SystemDSocket::unnamed)
95 .collect())
96}
97
/// A socket descriptor handed to this process through systemd socket activation.
///
/// Only the descriptor number is stored; dropping this handle does not close the
/// descriptor. [`SystemDSocket::listener`] takes ownership of it.
#[derive(Debug)]
pub struct SystemDSocket {
    /// Name taken from `$LISTEN_FDNAMES`, if names were supplied.
    name: Option<String>,
    /// Descriptor number, counting up from the first activation descriptor (3).
    fd: RawFd,
}

105impl SystemDSocket {
106 fn new<S: Into<String>>(name: S, fd: RawFd) -> Self {
107 Self {
108 name: Some(name.into()),
109 fd,
110 }
111 }
112
113 fn unnamed(fd: RawFd) -> Self {
114 Self { name: None, fd }
115 }
116
117 pub fn name(&self) -> Option<&str> {
122 self.name.as_deref()
123 }
124
125 pub fn listener(self) -> Result<TcpListener, SocketError> {
127 let file = unsafe { File::from_raw_fd(self.fd) };
130 let metadata = file.metadata()?;
131 if !metadata.file_type().is_socket() {
132 return Err(SocketError::NotSocket(file.into_raw_fd()));
133 }
134
135 let listener = unsafe { TcpListener::from_raw_fd(file.into_raw_fd()) };
142 listener.set_nonblocking(true)?;
143 Ok(listener)
144 }
145}
146
147impl AsRawFd for SystemDSocket {
148 fn as_raw_fd(&self) -> RawFd {
149 self.fd
150 }
151}
152
153impl AsFd for SystemDSocket {
154 fn as_fd(&self) -> BorrowedFd<'_> {
155 unsafe { BorrowedFd::borrow_raw(self.fd) }
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// The current process's PID, formatted like a `$LISTEN_PID` value.
    fn own_pid() -> String {
        process::id().to_string()
    }

    #[test]
    fn parse_variables() {
        let sockets = construct_sockets("3", Some("alice:bob:charlie"), &own_pid()).unwrap();

        let names: Vec<_> = sockets.iter().map(|s| s.name().unwrap()).collect();
        assert_eq!(names, vec!["alice", "bob", "charlie"]);

        let fds: Vec<_> = sockets.iter().map(|s| s.fd).collect();
        assert_eq!(fds, vec![3, 4, 5]);
    }

    #[test]
    fn missing_names_yield_unnamed_sockets() {
        let sockets = construct_sockets("2", None, &own_pid()).unwrap();

        assert!(sockets.iter().all(|s| s.name().is_none()));
        let fds: Vec<_> = sockets.iter().map(|s| s.fd).collect();
        assert_eq!(fds, vec![3, 4]);
    }

    #[test]
    fn mismatched_name_count_falls_back_to_unnamed() {
        let sockets = construct_sockets("3", Some("alice:bob"), &own_pid()).unwrap();

        assert_eq!(sockets.len(), 3);
        assert!(sockets.iter().all(|s| s.name().is_none()));
    }

    #[test]
    fn zero_fds_yield_no_sockets() {
        let sockets = construct_sockets("0", None, &own_pid()).unwrap();
        assert!(sockets.is_empty());
    }

    #[test]
    fn rejects_non_numeric_pid() {
        let err = construct_sockets("1", None, "not-a-pid").unwrap_err();
        assert!(matches!(err, SocketError::InvalidVar(name, _) if name == LISTEN_PID));
    }

    #[test]
    fn rejects_non_numeric_fd_count() {
        let err = construct_sockets("many", None, &own_pid()).unwrap_err();
        assert!(matches!(err, SocketError::InvalidVar(name, _) if name == LISTEN_FDS));
    }

    #[test]
    fn rejects_foreign_pid() {
        let other = process::id().wrapping_add(1).to_string();
        let err = construct_sockets("1", None, &other).unwrap_err();
        assert!(matches!(err, SocketError::WrongPID(..)));
    }
}