use std::io::{IoSlice, IoSliceMut, Read, Result as SResult, Write};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::{fs, io};

use log::{error, info, trace};
use nix::cmsg_space;
use nix::errno::Errno;
use nix::libc;
use nix::sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags};
use nix::unistd::close;
12
/// Descriptor number that `open_parent_fd` probes for a descriptor inherited
/// from the parent process (presumably left open there by whatever spawned this
/// process — confirm with the launcher).
const PARENT_FD: u16 = 3;
14
/// A connected Unix-domain stream socket, with helpers (`send_fd` / `recv_fd`)
/// for passing file descriptors over it.
pub struct UnixConnection {
    /// Socket handle; it owns its descriptor and closes it when dropped.
    stream: UnixStream,
    /// Raw descriptor number this connection was constructed from.
    /// `try_clone` copies it unchanged, so clones carry the same value.
    fd: RawFd,
}
19
20impl UnixConnection {
21 pub fn new(fd: RawFd) -> Self {
22 UnixConnection { fd, stream: unsafe { UnixStream::from_raw_fd(fd) } }
23 }
24
25 pub fn try_clone(&self) -> SResult<Self> {
26 Ok(UnixConnection { fd: self.fd, stream: self.stream.try_clone()? })
27 }
28
29 pub(crate) fn id(&self) -> u64 {
30 self.fd as u64
31 }
32
33 pub(crate) fn send_fd(&self, fd: RawFd) {
34 info!("sending fd {}", fd);
35 let fds = [fd];
36 let fd_msg = ControlMessage::ScmRights(&fds);
37 let c = [0xff];
38 let x = IoSlice::new(&c);
39 sendmsg::<()>(self.fd, &[x], &[fd_msg], MsgFlags::empty(), None).unwrap();
40 close(fd).unwrap();
41 }
42
43 pub(crate) fn recv_fd(&self) -> Result<RawFd, ()> {
44 let mut cmsgs = cmsg_space!(RawFd);
45 let mut c = [0];
46 let mut io_slice = [IoSliceMut::new(&mut c)];
47 let result =
48 recvmsg::<()>(self.fd, &mut io_slice, Some(&mut cmsgs), MsgFlags::empty()).unwrap();
49 let mut iter = result.cmsgs();
50 let cmsg = iter.next().ok_or_else(|| {
51 error!("expected a control message");
52 })?;
53 if iter.next().is_some() {
54 error!("expected exactly one control message");
55 return Err(());
56 }
57 match cmsg {
58 ControlMessageOwned::ScmRights(r) =>
59 if r.len() != 1 {
60 error!("expected exactly one fd");
61 Err(())
62 } else {
63 if c[0] != 0xff {
64 error!("expected a 0xff byte ancillary byte, got {}", c[0]);
65 return Err(());
66 }
67 Ok(r[0])
68 },
69 m => {
70 error!("unexpected cmsg {:?}", m);
71 Err(())
72 }
73 }
74 }
75}
76
77impl Read for UnixConnection {
78 fn read(&mut self, dest: &mut [u8]) -> SResult<usize> {
79 let mut cursor = 0;
80 if dest.is_empty() {
81 return Ok(0);
82 }
83 while cursor < dest.len() {
84 let res: io::Result<usize> = self.stream.read(&mut dest[cursor..]);
85 trace!("read {}: {:?} cursor={} expected={}", self.id(), res, cursor, dest.len());
86 match res {
87 Ok(n) => {
88 if n == 0 {
89 return Ok(cursor);
90 }
91 cursor = cursor + n;
92 }
93 Err(e) => return Err(e),
94 }
95 }
96 Ok(cursor)
97 }
98}
99
100impl Write for UnixConnection {
101 fn write(&mut self, buf: &[u8]) -> SResult<usize> {
102 self.stream.write(buf)
103 }
104
105 fn flush(&mut self) -> SResult<()> {
106 self.stream.flush()
107 }
108
109 fn write_all(&mut self, buf: &[u8]) -> SResult<()> {
110 self.stream.write_all(buf)
111 }
112}
113
114pub fn open_parent_fd() -> RawFd {
115 let have_parent = unsafe { libc::fcntl(PARENT_FD as libc::c_int, libc::F_GETFD) } != -1;
118
119 let dummy_file = fs::File::open("/dev/null").unwrap().into_raw_fd();
120
121 let parent_fd = if have_parent {
122 close(dummy_file).expect("close dummy");
123 RawFd::from(PARENT_FD)
124 } else {
125 error!("no parent on {}, using /dev/null", PARENT_FD);
126 dummy_file
127 };
128 parent_fd
129}