1extern crate nix;
2#[macro_use]
3extern crate error_chain;
4#[cfg(any(feature = "ser_cbor", feature = "ser_json", feature = "ser_bincode"))]
5extern crate serde;
6#[cfg(feature = "ser_cbor")]
7extern crate serde_cbor;
8#[cfg(feature = "ser_json")]
9extern crate serde_json;
10#[cfg(feature = "ser_bincode")]
11extern crate bincode;
12#[cfg(feature = "zero_copy")]
13#[macro_use]
14extern crate zerocopy;
15
16use std::{mem, ptr, slice};
17use std::io::{IoSliceMut, IoSlice};
18use std::os::unix::io::{RawFd, FromRawFd, IntoRawFd, AsRawFd};
19use nix::{unistd, cmsg_space};
20use nix::fcntl::{self, FdFlag, FcntlArg};
21use nix::sys::socket::{
22 recvmsg, sendmsg, ControlMessageOwned, ControlMessage, MsgFlags,
23 socketpair, AddressFamily, SockFlag, SockType, UnixAddr,
24};
25
/// Error types for this crate, generated by `error_chain!`.
///
/// The generated `Error` / `ErrorKind` / `Result` wrap `nix` errors and, when the
/// corresponding cargo features are enabled, the CBOR, JSON and bincode
/// (de)serialization errors.
pub mod errors {
    error_chain!{
        foreign_links {
            // Failures from the `nix` system-call wrappers used by `Socket`.
            Nix(::nix::Error);
            // Serializer/deserializer errors; each link exists only when its feature is on.
            Cbor(::serde_cbor::error::Error) #[cfg(feature = "ser_cbor")];
            Json(::serde_json::Error) #[cfg(feature = "ser_json")];
            Bincode(::bincode::Error) #[cfg(feature = "ser_bincode")];
        }

        errors {
            // Raised when the number of bytes received does not match what the
            // receiver expected: the size of the target struct, or the sender's
            // length prefix plus the prefix itself.
            WrongRecvLength {
                description("length of received message doesn't match the struct size or received length")
            }
        }
    }
}
42
43use errors::*;
44
/// An owned socket file descriptor.
///
/// The descriptor is closed when the `Socket` is dropped. Use `IntoRawFd` to
/// take it back without closing it.
#[derive(Debug)]
pub struct Socket {
    fd: RawFd,
}
48
49impl FromRawFd for Socket {
50 unsafe fn from_raw_fd(fd: RawFd) -> Socket {
51 Socket {
52 fd,
53 }
54 }
55}
56
57impl IntoRawFd for Socket {
58 fn into_raw_fd(self) -> RawFd {
59 let fd = self.fd;
60 std::mem::forget(self);
61 fd
62 }
63}
64
65impl AsRawFd for Socket {
66 fn as_raw_fd(&self) -> RawFd {
67 self.fd
68 }
69}
70
71impl Socket {
72 pub fn new_socketpair() -> Result<(Socket, Socket)> {
76 socketpair(AddressFamily::Unix, SockType::SeqPacket, None, SockFlag::SOCK_CLOEXEC).map(|(a, b)| {
77 unsafe { (Self::from_raw_fd(a), Self::from_raw_fd(b)) }
78 }).map_err(|e| e.into())
79 }
80
81 pub fn no_cloexec(&mut self) -> Result<()> {
83 fcntl::fcntl(self.fd, FcntlArg::F_SETFD(FdFlag::empty())).map(|_| ()).map_err(|e| e.into())
84 }
85
86 pub fn recv_into_iovec<F: Default + AsMut<[RawFd]>>(&mut self, iov: &mut [IoSliceMut]) -> Result<(usize, Option<F>)> {
94 let mut rfds = None;
95 let mut cmsgspace = cmsg_space!(F);
96 let msg = recvmsg::<UnixAddr>(self.fd, iov, Some(&mut cmsgspace), MsgFlags::MSG_CMSG_CLOEXEC)?;
97 for cmsg in msg.cmsgs() {
98 if let ControlMessageOwned::ScmRights(fds) = cmsg {
99 if fds.len() >= 1 {
100 let mut fd_arr: F = Default::default();
101 <F as AsMut<[RawFd]>>::as_mut(&mut fd_arr).clone_from_slice(&fds);
102 rfds = Some(fd_arr);
103 }
104 }
105 }
106 Ok((msg.bytes, rfds))
107 }
108
109 pub fn recv_into_slice<F: Default + AsMut<[RawFd]>>(&mut self, buf: &mut [u8]) -> Result<(usize, Option<F>)> {
117 let mut iov = [IoSliceMut::new(&mut buf[..])];
118 self.recv_into_iovec(&mut iov)
119 }
120
121 pub fn recv_into_buf<F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(usize, Vec<u8>, Option<F>)> {
129 let mut buf = vec![0u8; buf_size];
130 let (bytes, rfds) = {
131 let mut iov = [IoSliceMut::new(&mut buf[..])];
132 self.recv_into_iovec(&mut iov)?
133 };
134 Ok((bytes, buf, rfds))
135 }
136
137 pub fn recv_into_buf_with_len<F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(usize, Vec<u8>, u64, Option<F>)> {
146 let mut len: u64 = 0;
147 let mut buf = vec![0u8; buf_size];
148 let (bytes, rfds) = {
149 let mut iov = [
150 IoSliceMut::new(unsafe { slice::from_raw_parts_mut((&mut len as *mut u64) as *mut u8, mem::size_of::<u64>()) }),
151 IoSliceMut::new(&mut buf[..]),
152 ];
153 self.recv_into_iovec(&mut iov)?
154 };
155 buf.truncate(len as usize);
156 Ok((bytes, buf, len, rfds))
157 }
158
159
160 pub unsafe fn recv_struct_raw<T, F: Default + AsMut<[RawFd]>>(&mut self) -> Result<(T, Option<F>)> {
168 let (bytes, buf, rfds) = self.recv_into_buf(mem::size_of::<T>())?;
169 if bytes != mem::size_of::<T>() {
170 bail!(ErrorKind::WrongRecvLength);
171 }
172 Ok((ptr::read(buf.as_slice().as_ptr() as *const _), rfds))
173 }
174
175 #[cfg(feature = "zero_copy")]
184 pub fn recv_struct<T: zerocopy::FromBytes, F: Default + AsMut<[RawFd]>>(&mut self) -> Result<(T, Option<F>)> {
185 unsafe {
186 self.recv_struct_raw()
187 }
188 }
189
190 #[cfg(feature = "ser_cbor")]
202 pub fn recv_cbor<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
203 let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
204 if bytes != len as usize + mem::size_of::<u64>() {
205 bail!(ErrorKind::WrongRecvLength);
206 }
207 Ok((serde_cbor::from_slice(&buf[..])?, rfds))
208 }
209
210 #[cfg(feature = "ser_json")]
222 pub fn recv_json<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
223 let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
224 if bytes != len as usize + mem::size_of::<u64>() {
225 bail!(ErrorKind::WrongRecvLength);
226 }
227 Ok((serde_json::from_slice(&buf[..])?, rfds))
228 }
229
230 #[cfg(feature = "ser_bincode")]
242 pub fn recv_bincode<T: serde::de::DeserializeOwned, F: Default + AsMut<[RawFd]>>(&mut self, buf_size: usize) -> Result<(T, Option<F>)> {
243 let (bytes, buf, len, rfds) = self.recv_into_buf_with_len(buf_size)?;
244 if bytes != len as usize + mem::size_of::<u64>() {
245 bail!(ErrorKind::WrongRecvLength);
246 }
247 Ok((bincode::deserialize(&buf[..])?, rfds))
248 }
249
250 pub fn send_iovec(&mut self, iov: &[IoSlice], fds: Option<&[RawFd]>) -> Result<usize> {
254 if let Some(rfds) = fds {
255 sendmsg::<UnixAddr>(self.fd, iov, &[ControlMessage::ScmRights(rfds)], MsgFlags::empty(), None).map_err(|e| e.into())
256 } else {
257 sendmsg::<UnixAddr>(self.fd, iov, &[], MsgFlags::empty(), None).map_err(|e| e.into())
258 }
259 }
260
261 pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
265 let iov = [IoSlice::new(data)];
266 self.send_iovec(&iov[..], fds)
267 }
268
269 pub fn send_slice_with_len(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
274 let len = data.len() as u64;
275 let iov = [IoSlice::new(unsafe { slice::from_raw_parts((&len as *const u64) as *const u8, mem::size_of::<u64>()) }), IoSlice::new(data)];
276 self.send_iovec(&iov[..], fds)
277 }
278
279 pub unsafe fn send_struct_raw<T>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
286 self.send_slice(slice::from_raw_parts((data as *const T) as *const u8, mem::size_of::<T>()), fds)
287 }
288
289 #[cfg(feature = "zero_copy")]
295 pub fn send_struct<T: zerocopy::AsBytes>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
296 unsafe {
297 self.send_struct_raw(data, fds)
298 }
299 }
300
301
302
303 #[cfg(feature = "ser_cbor")]
307 pub fn send_cbor<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
308 let bytes = serde_cbor::to_vec(data)?;
309 self.send_slice_with_len(&bytes[..], fds)
310 }
311
312 #[cfg(feature = "ser_json")]
316 pub fn send_json<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
317 let bytes = serde_json::to_vec(data)?;
318 self.send_slice_with_len(&bytes[..], fds)
319 }
320
321 #[cfg(feature = "ser_bincode")]
325 pub fn send_bincode<T: serde::ser::Serialize>(&mut self, data: &T, fds: Option<&[RawFd]>) -> Result<usize> {
326 let bytes = bincode::serialize(data)?;
327 self.send_slice_with_len(&bytes[..], fds)
328 }
329}
330
331impl Drop for Socket {
332 fn drop(&mut self) {
333 let _ = unistd::close(self.fd);
334 }
335}
336
#[cfg(test)]
mod tests {
    extern crate shmemfdrs;
    use super::Socket;
    use std::os::unix::io::RawFd;
    #[cfg(feature = "zero_copy")]
    use zerocopy::AsBytes;

    #[test]
    fn test_slice_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        let mut rdata = [0; 4];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 4);
        assert_eq!(rfds, None);
        assert_eq!(&rdata[..], &data[..]);
    }

    #[test]
    fn test_slice_buf_too_short() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        // SOCK_SEQPACKET truncates: the byte that does not fit is dropped.
        let mut rdata = [0; 3];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 3);
        assert_eq!(rfds, None);
        assert_eq!(&rdata[..], &data[0..3]);
    }

    #[test]
    fn test_slice_with_len_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice_with_len(&data[..], None).unwrap();
        assert_eq!(sent, 12);
        let mut rdata = [0; 12];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 0]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 12);
        assert_eq!(rfds, None);
        // The prefix is a native-endian u64, so decode it instead of assuming
        // little-endian byte order.
        let mut prefix = [0u8; 8];
        prefix.copy_from_slice(&rdata[..8]);
        assert_eq!(u64::from_ne_bytes(prefix), 4);
        assert_eq!(&rdata[8..], &data[..]);
    }

    #[cfg(feature = "zero_copy")]
    #[derive(Debug, PartialEq, FromBytes, AsBytes)]
    #[repr(C)]
    struct TestStruct {
        one: i8,
        pad: [u8; 3],
        two: u32,
    }

    #[test]
    #[cfg(feature = "zero_copy")]
    fn test_struct_success() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = TestStruct { one: -64, two: 0xDEADBEEF, pad: [0, 0, 0]};
        let _ = tx.send_struct(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_struct::<TestStruct, [RawFd; 0]>().unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "zero_copy")]
    fn test_struct_wrong_len() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], None).unwrap();
        assert_eq!(sent, 4);
        let ret = rx.recv_struct::<TestStruct, [RawFd; 0]>();
        assert!(ret.is_err());
    }

    #[test]
    fn test_fd_passing() {
        use std::ffi::CString;
        use std::fs::File;
        use std::io::{Read, Seek, SeekFrom, Write};
        use std::os::unix::io::FromRawFd;
        let fd = shmemfdrs::create_shmem(CString::new("/test").unwrap(), 6);
        // `orig_file` owns `fd` and keeps it open until the end of the test.
        let mut orig_file = unsafe { File::from_raw_fd(fd) };
        orig_file.write_all(b"hello\n").unwrap();
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = [0xDE, 0xAD, 0xBE, 0xEF];
        let sent = tx.send_slice(&data[..], Some(&[fd])).unwrap();
        assert_eq!(sent, 4);
        let mut rdata = [0; 4];
        let (recvd, rfds) = rx.recv_into_slice::<[RawFd; 1]>(&mut rdata[..]).unwrap();
        assert_eq!(recvd, 4);
        assert_eq!(&rdata[..], &data[..]);
        let new_fd = rfds.unwrap()[0];
        {
            // The received descriptor is a distinct fd for the same shared-memory object.
            let mut file = unsafe { File::from_raw_fd(new_fd) };
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "hello\n");
        }
    }

    #[test]
    #[cfg(feature = "ser_cbor")]
    fn test_cbor() {
        use serde_cbor::value::Value;
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Value::Integer(123456);
        let _ = tx.send_cbor(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_cbor::<Value, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "ser_json")]
    fn test_json() {
        use serde_json::value::Value;
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Value::String("hi".to_owned());
        let _ = tx.send_json(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_json::<Value, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }

    #[test]
    #[cfg(feature = "ser_bincode")]
    fn test_bincode() {
        let (mut rx, mut tx) = Socket::new_socketpair().unwrap();
        let data = Some("hello world".to_string());
        let _ = tx.send_bincode(&data, None).unwrap();
        let (rdata, rfds) = rx.recv_bincode::<Option<String>, [RawFd; 0]>(24).unwrap();
        assert_eq!(rfds, None);
        assert_eq!(rdata, data);
    }
}
484}