// tiny_nix_ipc/lib.rs

1extern crate nix;
2#[macro_use]
3extern crate error_chain;
4#[cfg(any(feature = "ser_cbor", feature = "ser_json", feature = "ser_bincode"))]
5extern crate serde;
6#[cfg(feature = "ser_cbor")]
7extern crate serde_cbor;
8#[cfg(feature = "ser_json")]
9extern crate serde_json;
10#[cfg(feature = "ser_bincode")]
11extern crate bincode;
12#[cfg(feature = "zero_copy")]
13#[macro_use]
14extern crate zerocopy;
15
16use std::{mem, ptr, slice};
17use std::io::{IoSliceMut, IoSlice};
18use std::os::unix::io::{RawFd, FromRawFd, IntoRawFd, AsRawFd};
19use nix::{unistd, cmsg_space};
20use nix::fcntl::{self, FdFlag, FcntlArg};
21use nix::sys::socket::{
22    recvmsg, sendmsg, ControlMessageOwned, ControlMessage, MsgFlags,
23    socketpair, AddressFamily, SockFlag, SockType, UnixAddr,
24};
25
/// Error types for this crate, generated by `error_chain!`.
///
/// The crate root glob-imports this module (`use errors::*`), so the generated
/// `Result`, `Error` and `ErrorKind` names are used unqualified throughout the crate.
pub mod errors {
    error_chain!{
        foreign_links {
            // Failures of the underlying system calls (socketpair, fcntl, recvmsg, sendmsg).
            Nix(::nix::Error);
            // Serialization backends; each link exists only when its cargo feature is enabled.
            Cbor(::serde_cbor::error::Error) #[cfg(feature = "ser_cbor")];
            Json(::serde_json::Error) #[cfg(feature = "ser_json")];
            Bincode(::bincode::Error) #[cfg(feature = "ser_bincode")];
        }

        errors {
            // The number of bytes received did not match what the receiver expected:
            // `size_of::<T>()` for raw structs, or the length prefix plus its 8 bytes
            // for length-prefixed (serialized) messages.
            WrongRecvLength {
                description("length of received message doesn't match the struct size or received length")
            }
        }
    }
}
42
43use errors::*;
44
/// An owned socket descriptor, typically one end of the pair made by `Socket::new_socketpair`.
///
/// The descriptor is closed when the `Socket` is dropped; use `IntoRawFd::into_raw_fd`
/// to take the descriptor back without closing it.
#[derive(Debug)]
pub struct Socket {
    // The raw descriptor owned by this value.
    fd: RawFd,
}
48
49impl FromRawFd for Socket {
50    unsafe fn from_raw_fd(fd: RawFd) -> Socket {
51        Socket {
52            fd,
53        }
54    }
55}
56
57impl IntoRawFd for Socket {
58    fn into_raw_fd(self) -> RawFd {
59        let fd = self.fd;
60        std::mem::forget(self);
61        fd
62    }
63}
64
65impl AsRawFd for Socket {
66    fn as_raw_fd(&self) -> RawFd {
67        self.fd
68    }
69}
70
71impl Socket {
72    /// Creates a socket pair (AF_UNIX/SOCK_SEQPACKET).
73    ///
74    /// Both sockets are close-on-exec by default.
75    pub fn new_socketpair() -> Result<(Socket, Socket)> {
76        socketpair(AddressFamily::Unix, SockType::SeqPacket, None, SockFlag::SOCK_CLOEXEC).map(|(a, b)| {
77            unsafe { (Self::from_raw_fd(a), Self::from_raw_fd(b)) }
78        }).map_err(|e| e.into())
79    }
80
81    /// Disables close-on-exec on the socket (to preserve it across process forks).
82    pub fn no_cloexec(&mut self) -> Result<()> {
83        fcntl::fcntl(self.fd, FcntlArg::F_SETFD(FdFlag::empty())).map(|_| ()).map_err(|e| e.into())
84    }
85
86    /// Reads bytes from the socket into the given scatter/gather array.
87    ///
88    /// If file descriptors were passed, returns them too.
89    /// To receive file descriptors, you need to instantiate the type parameter `F`
90    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
91    ///
92    /// Received file descriptors are set close-on-exec.
93    pub fn recv_into_iovec<F: Default + AsMut<[RawFd]>>(&mut self, iov: &mut [IoSliceMut]) -> Result<(usize, Option<F>)> {
94        let mut rfds = None;
95        let mut cmsgspace = cmsg_space!(F);
96        let msg = recvmsg::<UnixAddr>(self.fd, iov, Some(&mut cmsgspace), MsgFlags::MSG_CMSG_CLOEXEC)?;
97        for cmsg in msg.cmsgs() {
98            if let ControlMessageOwned::ScmRights(fds) = cmsg {
99                if fds.len() >= 1 {
100                    let mut fd_arr: F = Default::default();
101                    <F as AsMut<[RawFd]>>::as_mut(&mut fd_arr).clone_from_slice(&fds);
102                    rfds = Some(fd_arr);
103                }
104            }
105        }
106        Ok((msg.bytes, rfds))
107    }
108
109    /// Reads bytes from the socket into the given buffer.
110    ///
111    /// If file descriptors were passed, returns them too.
112    /// To receive file descriptors, you need to instantiate the type parameter `F`
113    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
114    ///
115    /// Received file descriptors are set close-on-exec.
116    pub fn recv_into_slice<F: Default + AsMut<[RawFd]>>(&mut self, buf: &mut [u8]) -> Result<(usize, Option<F>)> {
117        let mut iov = [IoSliceMut::new(&mut buf[..])];
118        self.recv_into_iovec(&mut iov)
119    }
120
121    /// Reads bytes from the socket into a new buffer.
122    ///
123    /// If file descriptors were passed, returns them too.
124    /// To receive file descriptors, you need to instantiate the type parameter `F`
125    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
126    ///
127    /// Received file descriptors are set close-on-exec.
128    pub fn recv_into_buf<F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(usize, Vec<u8>, Option<F>)> {
129        let mut buf = vec![0u8; buf_size];
130        let (bytes, rfds) = {
131            let mut iov = [IoSliceMut::new(&mut buf[..])];
132            self.recv_into_iovec(&mut iov)?
133        };
134        Ok((bytes, buf, rfds))
135    }
136
137    /// Reads bytes from the socket into a new buffer, also reading the first 64 bits as length.
138    /// The resulting buffer is truncated to that length.
139    ///
140    /// If file descriptors were passed, returns them too.
141    /// To receive file descriptors, you need to instantiate the type parameter `F`
142    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
143    ///
144    /// Received file descriptors are set close-on-exec.
145    pub fn recv_into_buf_with_len<F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(usize, Vec<u8>, u64, Option<F>)> {
146        let mut len: u64 = 0;
147        let mut buf = vec![0u8; buf_size];
148        let (bytes, rfds) = {
149            let mut iov = [
150                IoSliceMut::new(unsafe { slice::from_raw_parts_mut((&mut len as *mut u64) as *mut u8, mem::size_of::<u64>()) }),
151                IoSliceMut::new(&mut buf[..]),
152            ];
153            self.recv_into_iovec(&mut iov)?
154        };
155        buf.truncate(len as usize);
156        Ok((bytes, buf, len, rfds))
157    }
158
159
160    /// See `recv_struct` for docs
161    ///
162    /// # Safety
163    /// - For some types (e.g.), not every bit pattern is allowed. If bytes, read from socket
164    /// aren't correct, that's UB.
165    /// - Some types mustn't change their memory location (see `std::pin::Pin`). Sending object of
166    /// such a type is UB.
167    pub unsafe fn recv_struct_raw<T, F: Default + AsMut<[RawFd]>>(&mut self) -> Result<(T, Option<F>)> {
168        let (bytes, buf, rfds) = self.recv_into_buf(mem::size_of::<T>())?;
169        if bytes != mem::size_of::<T>() {
170            bail!(ErrorKind::WrongRecvLength);
171        }
172        Ok((ptr::read(buf.as_slice().as_ptr() as *const _), rfds))
173    }
174    
175    /// Reads bytes from the socket and interprets them as a given data type.
176    /// If the size does not match, returns `WrongRecvLength`..
177    ///
178    /// If file descriptors were passed, returns them too.
179    /// To receive file descriptors, you need to instantiate the type parameter `F`
180    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
181    ///
182    /// Received file descriptors are set close-on-exec.
183    #[cfg(feature = "zero_copy")]
184    pub fn recv_struct<T: zerocopy::FromBytes, F: Default + AsMut<[RawFd]>>(&mut self) -> Result<(T, Option<F>)> {
185        unsafe {
186            self.recv_struct_raw()
187        }
188    }
189
190    /// Reads bytes from the socket and deserializes them as a given data type using CBOR.
191    /// If the size does not match, returns `WrongRecvLength`.
192    ///
193    /// You have to provide a size for the receive buffer.
194    /// It should be large enough for the data you want to receive plus 64 bits for the length.
195    ///
196    /// If file descriptors were passed, returns them too.
197    /// To receive file descriptors, you need to instantiate the type parameter `F`
198    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
199    ///
200    /// Received file descriptors are set close-on-exec.
201    #[cfg(feature = "ser_cbor")]
202    pub fn recv_cbor<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
203        let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
204        if bytes != len as usize + mem::size_of::<u64>() {
205            bail!(ErrorKind::WrongRecvLength);
206        }
207        Ok((serde_cbor::from_slice(&buf[..])?, rfds))
208    }
209
210    /// Reads bytes from the socket and deserializes them as a given data type using JSON.
211    /// If the size does not match, returns `WrongRecvLength`.
212    ///
213    /// You have to provide a size for the receive buffer.
214    /// It should be large enough for the data you want to receive plus 64 bits for the length.
215    ///
216    /// If file descriptors were passed, returns them too.
217    /// To receive file descriptors, you need to instantiate the type parameter `F`
218    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
219    ///
220    /// Received file descriptors are set close-on-exec.
221    #[cfg(feature = "ser_json")]
222    pub fn recv_json<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
223        let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
224        if bytes != len as usize + mem::size_of::<u64>() {
225            bail!(ErrorKind::WrongRecvLength);
226        }
227        Ok((serde_json::from_slice(&buf[..])?, rfds))
228    }
229
230    /// Reads bytes from the socket and deserializes them as a given data type using Bincode.
231    /// If the size does not match, returns `WrongRecvLength`.
232    ///
233    /// You have to provide a size for the receive buffer.
234    /// It should be large enough for the data you want to receive plus 64 bits for the length.
235    ///
236    /// If file descriptors were passed, returns them too.
237    /// To receive file descriptors, you need to instantiate the type parameter `F`
238    /// as `[RawFd; n]`, where `n` is the number of descriptors you want to receive.
239    ///
240    /// Received file descriptors are set close-on-exec.
241    #[cfg(feature = "ser_bincode")]
242    pub fn recv_bincode<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
243        let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
244        if bytes != len as usize + mem::size_of::<u64>() {
245            bail!(ErrorKind::WrongRecvLength);
246        }
247        Ok((bincode::deserialize(&buf[..])?, rfds))
248    }
249
250    /// Sends bytes from scatter-gather vectors over the socket.
251    ///
252    /// Optionally passes file descriptors with the message.
253    pub fn send_iovec(&mut self, iov: &[IoSlice], fds: Option<&[RawFd]>) -> Result<usize> {
254        if let Some(rfds) = fds {
255            sendmsg::<UnixAddr>(self.fd, iov, &[ControlMessage::ScmRights(rfds)], MsgFlags::empty(), None).map_err(|e| e.into())
256        } else {
257            sendmsg::<UnixAddr>(self.fd, iov, &[], MsgFlags::empty(), None).map_err(|e| e.into())
258        }
259    }
260
261    /// Sends bytes from a slice over the socket.
262    ///
263    /// Optionally passes file descriptors with the message.
264    pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
265        let iov = [IoSlice::new(data)];
266        self.send_iovec(&iov[..], fds)
267    }
268
269    /// Sends bytes from a slice over the socket, prefixing with the length
270    /// (as a 64-bit unsigned integer).
271    ///
272    /// Optionally passes file descriptors with the message.
273    pub fn send_slice_with_len(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
274        let len = data.len() as u64;
275        let iov = [IoSlice::new(unsafe { slice::from_raw_parts((&len as *const u64) as *const u8, mem::size_of::<u64>()) }), IoSlice::new(data)];
276        self.send_iovec(&iov[..], fds)
277    }
278
279    /// See `send_struct` for docs.
280    ///
281    /// # Safety
282    /// - T must not have padding bytes.
283    /// - Also, if T violates `recv_struct_raw` safety preconditions, receiving it will trigger
284    /// undefined behavior.
285    pub unsafe fn send_struct_raw<T>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
286        self.send_slice(slice::from_raw_parts((data as *const T) as *const u8, mem::size_of::<T>()), fds)
287    }
288
289    /// Sends a value of any type as its raw bytes over the socket.
290    /// (Do not use with types that contain pointers, references, boxes, etc.!
291    ///  Use serialization in that case!)
292    ///
293    /// Optionally passes file descriptors with the message.
294    #[cfg(feature = "zero_copy")]
295    pub fn send_struct<T: zerocopy::AsBytes>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
296        unsafe {
297            self.send_struct_raw(data, fds)
298        }
299    }
300
301
302
303    /// Serializes a value with CBOR and sends it over the socket.
304    ///
305    /// Optionally passes file descriptors with the message.
306    #[cfg(feature = "ser_cbor")]
307    pub fn send_cbor<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
308        let bytes = serde_cbor::to_vec(data)?;
309        self.send_slice_with_len(&bytes[..], fds)
310    }
311
312    /// Serializes a value with JSON and sends it over the socket.
313    ///
314    /// Optionally passes file descriptors with the message.
315    #[cfg(feature = "ser_json")]
316    pub fn send_json<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
317        let bytes = serde_json::to_vec(data)?;
318        self.send_slice_with_len(&bytes[..], fds)
319    }
320
321    /// Serializes a value with Bincode and sends it over the socket.
322    ///
323    /// Optionally passes file descriptors with the message.
324    #[cfg(feature = "ser_bincode")]
325    pub fn send_bincode<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
326        let bytes = bincode::serialize(data)?;
327        self.send_slice_with_len(&bytes[..], fds)
328    }
329}
330
331impl Drop for Socket {
332    fn drop(&mut self) {
333        let _ = unistd::close(self.fd);
334    }
335}
336
#[cfg(test)]
mod tests {
    extern crate shmemfdrs;
    use super::Socket;
    use std::os::unix::io::RawFd;
    #[cfg(feature = "zero_copy")]
    use zerocopy::AsBytes;

    #[test]
    fn test_slice_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        let mut rdata = [0; 4];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 4);
        assert_eq!(rfds, None);
        assert_eq!(&rdata[..], &data[..]);
    }

    #[test]
    fn test_slice_buf_too_short() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        let mut rdata = [0; 3];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 3);
        assert_eq!(rfds, None);
        assert_eq!(&rdata[..], &data[0..3]);
    }

    #[test]
    fn test_slice_with_len_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice_with_len(&data[..], None).unwrap();
        assert_eq!(sent, 12); // 4 + 8 (bytes in a 64-bit number)
        let mut rdata = [0; 12];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 12);
        assert_eq!(rfds, None);
        // The length prefix is a native-endian u64, so compare all 8 bytes rather than
        // assuming the low byte comes first (which only holds on little-endian targets).
        assert_eq!(&rdata[..8], &4u64.to_ne_bytes()[..]);
        assert_eq!(&rdata[8..], &data[..]);
    }

    #[cfg(feature = "zero_copy")]
    #[derive(Debug, PartialEq, FromBytes, AsBytes)]
    #[repr(C)]
    struct TestStruct {
        one: i8,
        // Note the explicit padding bytes here.
        // Without them, `send_struct` would read real compiler-provided padding, which is UB.
        pad: [u8; 3],
        two: u32,
    }

    #[test]
    #[cfg(feature = "zero_copy")]
    fn test_struct_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = TestStruct { one: -64, two: 0xDEADBEEF, pad: [0, 0, 0]};
        let _ = tx.send_struct(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_struct::<TestStruct, [RawFd; 0]>().unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "zero_copy")]
    fn test_struct_wrong_len() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        let ret = rx.recv_struct::<TestStruct, [RawFd; 0]>();
        assert!(ret.is_err());
    }

    #[test]
    fn test_fd_passing() {
        use std::ffi::CString;
        use std::fs::File;
        use std::io::{Read, Seek, SeekFrom, Write};
        use std::os::unix::io::FromRawFd;
        let fd = shmemfdrs::create_shmem(CString::new("/test").unwrap(), 6);
        // `orig_file` owns `fd` and stays alive until the end of the test, so the shared
        // memory is not destroyed before the received copy has been read.
        let orig_file = {
            let mut file = unsafe { File::from_raw_fd(fd) };
            file.write_all(b"hello\n").unwrap();
            file
        };
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], Some(&[fd])).unwrap();
        assert_eq!(sent, 4);
        let mut rdata = [0; 4];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 1]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 4);
        assert_eq!(&rdata[..], &data[..]);
        let new_fd = rfds.unwrap()[0];
        {
            let mut file = unsafe { File::from_raw_fd(new_fd) };
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "hello\n");
        }
        drop(orig_file);
    }

    #[test]
    #[cfg(feature = "ser_cbor")]
    fn test_cbor() {
        use serde_cbor::value::Value;
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Value::Integer(123456);
        let _ = tx.send_cbor(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_cbor::<Value, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "ser_json")]
    fn test_json() {
        use serde_json::value::Value;
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Value::String("hi".to_owned());
        let _ = tx.send_json(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_json::<Value, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "ser_bincode")]
    fn test_bincode() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Some("hello world".to_string());
        let _ = tx.send_bincode(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_bincode::<Option<String>, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }
}