1use std::env;
3use std::ffi::{OsString, OsStr};
4use std::os::fd::{AsRawFd, OwnedFd, RawFd};
5use std::os::unix::ffi::OsStrExt;
6use std::os::unix::net::UnixDatagram;
7
8pub struct NotifyFd {
15 fd: OwnedFd,
16 addr: Vec<libc::c_char>,
17}
18
19impl NotifyFd {
21 pub fn new() -> Option<Result<Self, std::io::Error>> {
23 let Some(addr) = env::var_os("NOTIFY_SOCKET") else {
24 return None;
25 };
26
27 Self::from_env(addr).map(Some).transpose()
28 }
29
30 pub fn from_env(name: OsString) -> Result<Self, std::io::Error> {
37 let ty = name.as_encoded_bytes().get(0).cloned();
38
39 let name_bytes = match ty {
40 Some(b'/') => {
41 name.as_encoded_bytes()
42 }
43 Some(b'@') => {
44 &name.as_encoded_bytes()[1..]
45 },
46 _ => return Err(std::io::ErrorKind::Unsupported)?,
47 };
48
49
50 let name = OsStr::from_bytes(name_bytes);
51 let dgram_socket = UnixDatagram::unbound()?;
52 dgram_socket.connect(name)?;
53
54 Ok(NotifyFd {
55 fd: dgram_socket.into(),
56 addr: name_bytes.iter().map(|&b| b as libc::c_char).collect(),
57 })
58 }
59
60 pub fn notify_with_fds(
71 self,
72 state: &str,
73 fds: &[RawFd]
74 ) -> Result<(), std::io::Error> {
75 let mut hdr: libc::msghdr = unsafe { core::mem::zeroed::<libc::msghdr>() };
76 let mut iov: libc::iovec = unsafe { core::mem::zeroed::<libc::iovec>() };
77 let mut addr: libc::sockaddr_un = unsafe { core::mem::zeroed::<libc::sockaddr_un>() };
78
79 iov.iov_base = state.as_ptr() as *mut libc::c_void;
80 iov.iov_len = state.len();
81
82 addr.sun_family = libc::AF_UNIX as libc::c_ushort;
83 let addr_len = addr.sun_path.len().min(self.addr.len());
84 addr.sun_path[..addr_len].copy_from_slice(&self.addr[..addr_len]);
85
86 hdr.msg_iov = &mut iov;
87 hdr.msg_iovlen = 1;
88 hdr.msg_namelen = core::mem::size_of_val(&addr) as libc::c_uint;
89 hdr.msg_name = &mut addr as *mut _ as *mut libc::c_void;
90
91 let len = u32::try_from(core::mem::size_of_val(fds))
93 .expect("user error");
94 let len = if len > 0 {
95 (unsafe { libc::CMSG_SPACE(len) } as usize)
96 } else { 0 };
97
98 let mut buf = vec![0; len];
99
100 hdr.msg_controllen = len;
101 hdr.msg_control = buf.as_mut_ptr() as *mut libc::c_void;
102
103 if len > 0 {
104 let cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) };
105 let cmsg = unsafe { &mut *cmsg };
106 let msg_len = core::mem::size_of_val(fds);
107
108 cmsg.cmsg_level = libc::SOL_SOCKET;
109 cmsg.cmsg_type = libc::SCM_RIGHTS;
110 cmsg.cmsg_len = unsafe { libc::CMSG_LEN(msg_len as u32) } as usize;
111
112 assert!(cmsg.cmsg_len >= msg_len);
113 let data = unsafe { libc::CMSG_DATA(cmsg) };
114
115 unsafe {
119 core::ptr::copy_nonoverlapping(
120 fds.as_ptr() as *const _ as *const u8,
121 data,
122 msg_len,
123 );
124 }
125 }
126
127 let sent = unsafe {
128 libc::sendmsg(self.fd.as_raw_fd(), &hdr, libc::MSG_NOSIGNAL)
129 };
130
131 if -1 == sent {
132 return Err(std::io::Error::last_os_error());
133 }
134
135 if sent as usize != state.len() {
136 return Err(std::io::ErrorKind::InvalidData)?;
137 }
138
139 Ok(())
140 }
141}