Skip to main content

libcontainer/
channel.rs

1use std::io::{IoSlice, IoSliceMut};
2use std::marker::PhantomData;
3use std::os::fd::AsRawFd;
4use std::os::unix::prelude::RawFd;
5
6use nix::sys::socket::{self, UnixAddr};
7use nix::unistd::{self};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, thiserror::Error)]
11pub enum ChannelError {
12    #[error("failed unix syscalls")]
13    Nix(#[from] nix::Error),
14    #[error("failed serde serialization")]
15    Serde(#[from] serde_json::Error),
16    #[error("channel connection broken")]
17    BrokenChannel,
18}
/// Receiving half of a channel carrying JSON-encoded messages of type `T`.
///
/// The descriptor is stored as a raw fd rather than an `OwnedFd`; it is
/// released through the explicit `close` method.
pub struct Receiver<T> {
    /// Raw socket descriptor this end reads packets from.
    receiver: RawFd,
    /// Binds the receiver to its message type without storing a `T`.
    phantom: PhantomData<T>,
}
23
/// Sending half of a channel carrying JSON-encoded messages of type `T`.
///
/// The descriptor is stored as a raw fd rather than an `OwnedFd`; it is
/// released through the explicit `close` method.
pub struct Sender<T> {
    /// Raw socket descriptor this end writes packets to.
    sender: RawFd,
    /// Binds the sender to its message type without storing a `T`.
    phantom: PhantomData<T>,
}
28
29impl<T> Sender<T>
30where
31    T: Serialize,
32{
33    fn send_iovec(
34        &mut self,
35        iov: &[IoSlice],
36        fds: Option<&[RawFd]>,
37    ) -> Result<usize, ChannelError> {
38        let cmsgs = if let Some(fds) = fds {
39            vec![socket::ControlMessage::ScmRights(fds)]
40        } else {
41            vec![]
42        };
43        socket::sendmsg::<UnixAddr>(self.sender, iov, &cmsgs, socket::MsgFlags::empty(), None)
44            .map_err(|e| e.into())
45    }
46
47    fn send_slice_with_len(
48        &mut self,
49        data: &[u8],
50        fds: Option<&[RawFd]>,
51    ) -> Result<usize, ChannelError> {
52        let len = data.len() as u64;
53        // Here we prefix the length of the data onto the serialized data.
54        let iov = [
55            IoSlice::new(unsafe {
56                std::slice::from_raw_parts(
57                    (&len as *const u64) as *const u8,
58                    std::mem::size_of::<u64>(),
59                )
60            }),
61            IoSlice::new(data),
62        ];
63        self.send_iovec(&iov[..], fds)
64    }
65
66    pub fn send(&mut self, object: T) -> Result<(), ChannelError> {
67        let payload = serde_json::to_vec(&object)?;
68        self.send_slice_with_len(&payload, None)?;
69
70        Ok(())
71    }
72
73    pub fn send_fds(&mut self, object: T, fds: &[RawFd]) -> Result<(), ChannelError> {
74        let payload = serde_json::to_vec(&object)?;
75        self.send_slice_with_len(&payload, Some(fds))?;
76
77        Ok(())
78    }
79
80    pub fn close(&self) -> Result<(), ChannelError> {
81        Ok(unistd::close(self.sender)?)
82    }
83}
84
85impl<T> Receiver<T>
86where
87    T: serde::de::DeserializeOwned,
88{
89    fn peek_size_iovec(&mut self) -> Result<u64, ChannelError> {
90        let mut len: u64 = 0;
91        let mut iov = [IoSliceMut::new(unsafe {
92            std::slice::from_raw_parts_mut(
93                (&mut len as *mut u64) as *mut u8,
94                std::mem::size_of::<u64>(),
95            )
96        })];
97        let _ =
98            socket::recvmsg::<UnixAddr>(self.receiver, &mut iov, None, socket::MsgFlags::MSG_PEEK)?;
99        match len {
100            0 => Err(ChannelError::BrokenChannel),
101            _ => Ok(len),
102        }
103    }
104
105    fn recv_into_iovec<F>(
106        &mut self,
107        iov: &mut [IoSliceMut],
108    ) -> Result<(usize, Option<F>), ChannelError>
109    where
110        F: Default + AsMut<[RawFd]>,
111    {
112        let mut cmsgspace = nix::cmsg_space!(F);
113        let msg = socket::recvmsg::<UnixAddr>(
114            self.receiver,
115            iov,
116            Some(&mut cmsgspace),
117            socket::MsgFlags::MSG_CMSG_CLOEXEC,
118        )?;
119
120        // Sending multiple SCM_RIGHTS message will led to platform dependent
121        // behavior, with some system choose to return EINVAL when sending or
122        // silently only process the first msg or send all of it. Here we assume
123        // there is only one SCM_RIGHTS message and will only process the first
124        // message.
125        let fds: Option<F> = msg
126            .cmsgs()?
127            .find_map(|cmsg| {
128                if let socket::ControlMessageOwned::ScmRights(fds) = cmsg {
129                    Some(fds)
130                } else {
131                    None
132                }
133            })
134            .map(|fds| {
135                let mut fds_array: F = Default::default();
136                <F as AsMut<[RawFd]>>::as_mut(&mut fds_array).clone_from_slice(&fds);
137                fds_array
138            });
139
140        Ok((msg.bytes, fds))
141    }
142
143    fn recv_into_buf_with_len<F>(&mut self) -> Result<(Vec<u8>, Option<F>), ChannelError>
144    where
145        F: Default + AsMut<[RawFd]>,
146    {
147        let msg_len = self.peek_size_iovec()?;
148        let mut len: u64 = 0;
149        let mut buf = vec![0u8; msg_len as usize];
150        let (bytes, fds) = {
151            let mut iov = [
152                IoSliceMut::new(unsafe {
153                    std::slice::from_raw_parts_mut(
154                        (&mut len as *mut u64) as *mut u8,
155                        std::mem::size_of::<u64>(),
156                    )
157                }),
158                IoSliceMut::new(&mut buf),
159            ];
160            self.recv_into_iovec(&mut iov)?
161        };
162
163        match bytes {
164            0 => Err(ChannelError::BrokenChannel),
165            _ => Ok((buf, fds)),
166        }
167    }
168
169    // Recv the next message of type T.
170    pub fn recv(&mut self) -> Result<T, ChannelError> {
171        let (buf, _) = self.recv_into_buf_with_len::<[RawFd; 0]>()?;
172        Ok(serde_json::from_slice(&buf[..])?)
173    }
174
175    // Works similar to `recv`, but will look for fds sent by SCM_RIGHTS
176    // message.  We use F as as `[RawFd; n]`, where `n` is the number of
177    // descriptors you want to receive.
178    pub fn recv_with_fds<F>(&mut self) -> Result<(T, Option<F>), ChannelError>
179    where
180        F: Default + AsMut<[RawFd]>,
181    {
182        let (buf, fds) = self.recv_into_buf_with_len::<F>()?;
183        Ok((serde_json::from_slice(&buf[..])?, fds))
184    }
185
186    pub fn close(&self) -> Result<(), ChannelError> {
187        Ok(unistd::close(self.receiver)?)
188    }
189}
190
191pub fn channel<T>() -> Result<(Sender<T>, Receiver<T>), ChannelError>
192where
193    T: for<'de> Deserialize<'de> + Serialize,
194{
195    let (os_sender, os_receiver) = unix_channel()?;
196    let receiver = Receiver {
197        receiver: os_receiver,
198        phantom: PhantomData,
199    };
200    let sender = Sender {
201        sender: os_sender,
202        phantom: PhantomData,
203    };
204    Ok((sender, receiver))
205}
206
207// Use socketpair as the underlying pipe.
208fn unix_channel() -> Result<(RawFd, RawFd), ChannelError> {
209    let (f1, f2) = socket::socketpair(
210        socket::AddressFamily::Unix,
211        socket::SockType::SeqPacket,
212        None,
213        socket::SockFlag::SOCK_CLOEXEC,
214    )?;
215    // It is not straightforward to share the OwnedFd across forks, so we
216    // treat them as i32. We use ManuallyDrop to keep the connection open.
217    let f1 = std::mem::ManuallyDrop::new(f1);
218    let f2 = std::mem::ManuallyDrop::new(f2);
219
220    Ok((f1.as_raw_fd(), f2.as_raw_fd()))
221}