use std::io::{Error, ErrorKind, Read, Result, Write};
use std::mem::size_of;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use libc::*;
use nix::sys::socket::{SockAddr, VsockAddr};
use std::ffi::c_void;
use std::net::Shutdown;
use std::time::Duration;
/// Iterator over connections accepted by a [`VsockListener`].
///
/// Obtained from [`VsockListener::incoming`]. Each call to `next` calls
/// [`VsockListener::accept`] and yields either the accepted stream (the peer
/// address is dropped) or the error it returned; it never yields `None`.
#[derive(Debug)]
pub struct Incoming<'a> {
    listener: &'a VsockListener,
}

impl<'a> Iterator for Incoming<'a> {
    type Item = Result<VsockStream>;

    /// Accepts the next connection; always returns `Some`.
    fn next(&mut self) -> Option<Result<VsockStream>> {
        let (stream, _peer_addr) = match self.listener.accept() {
            Ok(pair) => pair,
            Err(err) => return Some(Err(err)),
        };
        Some(Ok(stream))
    }
}
/// A vsock (`AF_VSOCK`) stream socket that listens for incoming connections.
///
/// Created with [`VsockListener::bind`] or adopted from an existing
/// descriptor through `FromRawFd`.
// NOTE(review): no `Drop` impl is visible in this file, so the descriptor is
// never closed automatically, and the derived `Clone` copies the raw fd number
// rather than duplicating the descriptor — confirm the intended ownership model.
#[derive(Debug, Clone)]
pub struct VsockListener {
    /// Raw file descriptor of the listening socket.
    socket: RawFd,
}
impl VsockListener {
pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
addr.0
} else {
return Err(Error::new(
ErrorKind::Other,
"require a virtio socket address",
));
};
let socket = unsafe { socket(AF_VSOCK, SOCK_STREAM, 0) };
if socket < 0 {
return Err(Error::new(ErrorKind::Other, "socket() failed"));
}
let res = unsafe {
bind(
socket,
&mut vsock_addr as *mut _ as *mut sockaddr,
size_of::<sockaddr_vm>() as u32,
)
};
if res < 0 {
return Err(Error::new(ErrorKind::Other, "bind() failed"));
}
let res = unsafe { listen(socket, 128) };
if res < 0 {
return Err(Error::new(ErrorKind::Other, "listen() failed"));
}
Ok(Self { socket })
}
pub fn local_addr(&self) -> Result<SockAddr> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
svm_port: 0,
svm_cid: 0,
svm_zero: [0u8; 4],
};
let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
if unsafe {
getsockname(
self.socket,
&mut vsock_addr as *mut _ as *mut sockaddr,
&mut vsock_addr_len,
)
} < 0
{
Err(Error::new(ErrorKind::Other, "getsockname() failed"))
} else {
Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
}
}
pub fn try_clone(&self) -> Result<Self> {
Ok(self.clone())
}
pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
svm_port: 0,
svm_cid: 0,
svm_zero: [0u8; 4],
};
let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
let socket = unsafe {
accept(
self.socket,
&mut vsock_addr as *mut _ as *mut sockaddr,
&mut vsock_addr_len,
)
};
if socket < 0 {
Err(Error::new(ErrorKind::Other, "accept() failed"))
} else {
Ok((
unsafe { VsockStream::from_raw_fd(socket as RawFd) },
SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
))
}
}
pub fn incoming(&self) -> Incoming {
Incoming { listener: self }
}
pub fn take_error(&self) -> Result<Option<Error>> {
let mut error: i32 = 0;
let mut error_len: socklen_t = 0;
if unsafe {
getsockopt(
self.socket,
SOL_SOCKET,
SO_ERROR,
&mut error as *mut _ as *mut c_void,
&mut error_len,
)
} < 0
{
Err(Error::new(ErrorKind::Other, "getsockopt() failed"))
} else {
Ok(if error == 0 {
None
} else {
Some(Error::from_raw_os_error(error))
})
}
}
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
Err(Error::new(ErrorKind::Other, "ioctl() failed"))
} else {
Ok(())
}
}
}
impl AsRawFd for VsockListener {
fn as_raw_fd(&self) -> RawFd {
self.socket
}
}
impl FromRawFd for VsockListener {
unsafe fn from_raw_fd(socket: RawFd) -> Self {
Self { socket }
}
}
impl IntoRawFd for VsockListener {
fn into_raw_fd(self) -> RawFd {
self.socket
}
}
/// A connected vsock (`AF_VSOCK`) stream socket.
///
/// Created with [`VsockStream::connect`], returned by
/// [`VsockListener::accept`], or adopted from a descriptor through `FromRawFd`.
// NOTE(review): no `Drop` impl is visible in this file, so the descriptor is
// never closed automatically, and the derived `Clone` copies the raw fd number
// rather than duplicating the descriptor — confirm the intended ownership model.
#[derive(Debug, Clone)]
pub struct VsockStream {
    /// Raw file descriptor of the connected socket.
    socket: RawFd,
}
impl VsockStream {
    /// Opens a vsock stream socket and connects it to `addr`.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Other` error if `addr` is not a
    /// `SockAddr::Vsock`; otherwise returns the OS error reported by
    /// `socket(2)` or `connect(2)`. If `connect` fails, the newly created
    /// socket is closed before returning.
    pub fn connect(addr: &SockAddr) -> Result<Self> {
        let vsock_addr = if let SockAddr::Vsock(addr) = addr {
            addr.0
        } else {
            return Err(Error::new(
                ErrorKind::Other,
                "require a virtio socket address",
            ));
        };

        // SAFETY: plain syscall with constant arguments; the result is checked.
        let sock = unsafe { socket(AF_VSOCK, SOCK_STREAM, 0) };
        if sock < 0 {
            return Err(Error::last_os_error());
        }

        // SAFETY: `vsock_addr` is an initialised `sockaddr_vm` that outlives the
        // call, and the length passed is exactly its size.
        let res = unsafe {
            connect(
                sock,
                &vsock_addr as *const _ as *const sockaddr,
                size_of::<sockaddr_vm>() as socklen_t,
            )
        };
        if res < 0 {
            // Read errno before `close`, which may overwrite it.
            let err = Error::last_os_error();
            // SAFETY: `sock` is owned here and is abandoned after this call.
            unsafe { close(sock) };
            return Err(err);
        }

        // SAFETY: `sock` is a connected descriptor that nothing else owns.
        Ok(unsafe { VsockStream::from_raw_fd(sock) })
    }

    /// Returns the address of the remote peer, via `getpeername(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `getpeername` fails.
    pub fn peer_addr(&self) -> Result<SockAddr> {
        let mut vsock_addr = Self::empty_vsock_addr();
        let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
        // SAFETY: `vsock_addr`/`vsock_addr_len` describe a writable buffer of the
        // advertised size.
        let res = unsafe {
            getpeername(
                self.socket,
                &mut vsock_addr as *mut _ as *mut sockaddr,
                &mut vsock_addr_len,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
    }

    /// Returns the local address of this socket, via `getsockname(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `getsockname` fails.
    pub fn local_addr(&self) -> Result<SockAddr> {
        let mut vsock_addr = Self::empty_vsock_addr();
        let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
        // SAFETY: `vsock_addr`/`vsock_addr_len` describe a writable buffer of the
        // advertised size.
        let res = unsafe {
            getsockname(
                self.socket,
                &mut vsock_addr as *mut _ as *mut sockaddr,
                &mut vsock_addr_len,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
    }

    /// Shuts down the read half, the write half or both halves of the
    /// connection.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `shutdown(2)` fails.
    pub fn shutdown(&self, how: Shutdown) -> Result<()> {
        let how = match how {
            Shutdown::Write => SHUT_WR,
            Shutdown::Read => SHUT_RD,
            Shutdown::Both => SHUT_RDWR,
        };
        // SAFETY: plain syscall on our descriptor with a valid `how` constant.
        if unsafe { shutdown(self.socket, how) } < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Returns a copy of this stream.
    ///
    /// The copy refers to the *same* file descriptor; the descriptor is not
    /// duplicated.
    pub fn try_clone(&self) -> Result<Self> {
        Ok(self.clone())
    }

    /// Sets the receive timeout (`SO_RCVTIMEO`) used by reads.
    ///
    /// `None` clears the timeout. A zero `timeval` means "no timeout".
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` for a zero `Duration`, or the OS
    /// error if `setsockopt` fails.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
        self.set_timeout(SO_RCVTIMEO, dur)
    }

    /// Sets the send timeout (`SO_SNDTIMEO`) used by writes.
    ///
    /// `None` clears the timeout. A zero `timeval` means "no timeout".
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` for a zero `Duration`, or the OS
    /// error if `setsockopt` fails.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
        self.set_timeout(SO_SNDTIMEO, dur)
    }

    /// Reads and clears the pending socket error (`SO_ERROR`).
    ///
    /// Returns `Ok(None)` when no error is pending.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `getsockopt` itself fails.
    pub fn take_error(&self) -> Result<Option<Error>> {
        let mut error: i32 = 0;
        // Must be initialised to the buffer size: a length of 0 makes the kernel
        // copy nothing back, so a pending error would never be seen.
        let mut error_len = size_of::<i32>() as socklen_t;
        // SAFETY: `error`/`error_len` describe a writable buffer of the advertised
        // size.
        let res = unsafe {
            getsockopt(
                self.socket,
                SOL_SOCKET,
                SO_ERROR,
                &mut error as *mut _ as *mut c_void,
                &mut error_len,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(if error == 0 {
            None
        } else {
            Some(Error::from_raw_os_error(error))
        })
    }

    /// Switches the socket into or out of non-blocking mode via `FIONBIO`.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `ioctl` fails.
    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
        let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
        // SAFETY: FIONBIO reads an `int` through the pointer, which is valid.
        if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Applies a timeout socket option (`SO_RCVTIMEO` or `SO_SNDTIMEO`).
    fn set_timeout(&self, option: i32, dur: Option<Duration>) -> Result<()> {
        let timeout = Self::timeval_from_duration(dur)?;
        // SAFETY: `timeout` is a live `timeval` and the length passed is its size.
        let res = unsafe {
            setsockopt(
                self.socket,
                SOL_SOCKET,
                option,
                &timeout as *const _ as *const c_void,
                size_of::<timeval>() as socklen_t,
            )
        };
        if res < 0 {
            return Err(Error::last_os_error());
        }
        Ok(())
    }

    /// Returns a `sockaddr_vm` with only the family set, for use as an
    /// out-parameter of `getsockname`/`getpeername`.
    fn empty_vsock_addr() -> sockaddr_vm {
        sockaddr_vm {
            svm_family: AF_VSOCK as sa_family_t,
            svm_reserved1: 0,
            svm_port: 0,
            svm_cid: 0,
            svm_zero: [0u8; 4],
        }
    }

    /// Converts an optional timeout into a `timeval`.
    ///
    /// `None` maps to a zero `timeval`. Zero durations are rejected.
    /// Durations whose seconds do not fit in `time_t` are clamped. A non-zero
    /// duration shorter than 1µs is rounded up to 1µs, so it does not turn
    /// into "no timeout".
    fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
        match dur {
            Some(dur) => {
                if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
                    return Err(Error::new(
                        ErrorKind::InvalidInput,
                        "cannot set a zero duration timeout",
                    ));
                }
                let secs = if dur.as_secs() > time_t::max_value() as u64 {
                    time_t::max_value()
                } else {
                    dur.as_secs() as time_t
                };
                let mut timeout = timeval {
                    tv_sec: secs,
                    tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
                };
                if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
                    timeout.tv_usec = 1;
                }
                Ok(timeout)
            }
            None => Ok(timeval {
                tv_sec: 0,
                tv_usec: 0,
            }),
        }
    }
}
impl Read for VsockStream {
    /// Receives up to `buf.len()` bytes with `recv(2)`.
    ///
    /// A return of 0 from `recv` (for example at end of stream) is passed
    /// through as `Ok(0)`. On failure the OS error is returned unchanged. This
    /// lets helpers such as `read_exact` retry on `ErrorKind::Interrupted`, and
    /// lets non-blocking callers observe `ErrorKind::WouldBlock`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // SAFETY: the pointer/length pair describes `buf`, which is valid for
        // writes of `buf.len()` bytes for the duration of the call.
        let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
        if ret < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(ret as usize)
        }
    }
}
impl Write for VsockStream {
    /// Sends bytes from `buf` with `send(2)` and returns how many were
    /// accepted.
    ///
    /// `MSG_NOSIGNAL` is passed so that writing to a connection closed by the
    /// peer yields an `EPIPE` error rather than raising `SIGPIPE`. On failure
    /// the OS error is returned unchanged, which lets `write_all` retry on
    /// `ErrorKind::Interrupted`.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        // SAFETY: the pointer/length pair describes `buf`, which is valid for
        // reads of `buf.len()` bytes for the duration of the call.
        let ret = unsafe {
            send(
                self.socket,
                buf.as_ptr() as *const c_void,
                buf.len(),
                MSG_NOSIGNAL,
            )
        };
        if ret < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(ret as usize)
        }
    }

    /// No-op: `write` hands data straight to the kernel, so there is no
    /// user-space buffer to flush.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
impl Read for &VsockStream {
    /// Receives up to `buf.len()` bytes with `recv(2)` through a shared
    /// reference. Only the descriptor is needed, so `&mut` is not required.
    ///
    /// A return of 0 from `recv` is passed through as `Ok(0)`. On failure the
    /// OS error is returned unchanged, so `ErrorKind::Interrupted` and
    /// `ErrorKind::WouldBlock` can be told apart.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // SAFETY: the pointer/length pair describes `buf`, which is valid for
        // writes of `buf.len()` bytes for the duration of the call.
        let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
        if ret < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(ret as usize)
        }
    }
}
impl Write for &VsockStream {
    /// Sends bytes from `buf` with `send(2)` through a shared reference.
    ///
    /// `MSG_NOSIGNAL` is passed so that writing to a connection closed by the
    /// peer yields an `EPIPE` error rather than raising `SIGPIPE`. On failure
    /// the OS error is returned unchanged, which lets `write_all` retry on
    /// `ErrorKind::Interrupted`.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        // SAFETY: the pointer/length pair describes `buf`, which is valid for
        // reads of `buf.len()` bytes for the duration of the call.
        let ret = unsafe {
            send(
                self.socket,
                buf.as_ptr() as *const c_void,
                buf.len(),
                MSG_NOSIGNAL,
            )
        };
        if ret < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(ret as usize)
        }
    }

    /// No-op: `write` hands data straight to the kernel, so there is no
    /// user-space buffer to flush.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
impl AsRawFd for VsockStream {
fn as_raw_fd(&self) -> RawFd {
self.socket
}
}
impl FromRawFd for VsockStream {
unsafe fn from_raw_fd(socket: RawFd) -> Self {
Self { socket }
}
}
impl IntoRawFd for VsockStream {
fn into_raw_fd(self) -> RawFd {
self.socket
}
}