use super::{
protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
VsockEvent, VsockEventType,
};
use crate::{transport::Transport, Hal, Result};
use alloc::{boxed::Box, vec::Vec};
use core::cmp::min;
use core::convert::TryInto;
use core::hint::spin_loop;
use log::debug;
use zerocopy::FromBytes;
/// Capacity, in bytes, of each connection's receive ring buffer. It is also the value stored in
/// each connection's `buf_alloc`, which is sent to the peer in packet headers.
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// A higher-level interface over a [`VirtIOSocket`] driver which tracks multiple connections,
/// buffers data received on each of them until it is read with `recv`, and accepts incoming
/// connection requests on ports registered with `listen`.
pub struct VsockConnectionManager<H: Hal, T: Transport> {
/// The low-level driver used to send and receive packets.
driver: VirtIOSocket<H, T>,
/// All connections currently tracked by the manager.
connections: Vec<Connection>,
/// Local ports on which incoming connection requests are accepted; `poll` rejects requests
/// to any other port.
listening_ports: Vec<u32>,
}
/// Per-connection bookkeeping kept by [`VsockConnectionManager`].
#[derive(Debug)]
struct Connection {
    /// Protocol-level state for the connection, passed to the driver when sending packets.
    info: ConnectionInfo,
    /// Data received from the peer which has not yet been read via `recv`.
    buffer: RingBuffer,
    /// Set when the peer disconnected while unread data remained in `buffer`; the connection
    /// is closed once that data has been drained.
    peer_requested_shutdown: bool,
}

impl Connection {
    /// Creates state for a connection between local port `local_port` and `peer`, with an empty
    /// receive buffer of [`PER_CONNECTION_BUFFER_CAPACITY`] bytes.
    fn new(peer: VsockAddr, local_port: u32) -> Self {
        let info = {
            let mut fresh = ConnectionInfo::new(peer, local_port);
            // Report the size of our receive buffer as `buf_alloc` in outgoing headers.
            fresh.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
            fresh
        };
        Self {
            info,
            buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
            peer_requested_shutdown: false,
        }
    }
}
impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
pub fn new(driver: VirtIOSocket<H, T>) -> Self {
Self {
driver,
connections: Vec::new(),
listening_ports: Vec::new(),
}
}
pub fn guest_cid(&self) -> u64 {
self.driver.guest_cid()
}
pub fn listen(&mut self, port: u32) {
if !self.listening_ports.contains(&port) {
self.listening_ports.push(port);
}
}
pub fn unlisten(&mut self, port: u32) {
self.listening_ports.retain(|p| *p != port);
}
pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
if self.connections.iter().any(|connection| {
connection.info.dst == destination && connection.info.src_port == src_port
}) {
return Err(SocketError::ConnectionExists.into());
}
let new_connection = Connection::new(destination, src_port);
self.driver.connect(&new_connection.info)?;
debug!("Connection requested: {:?}", new_connection.info);
self.connections.push(new_connection);
Ok(())
}
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.send(buffer, &mut connection.info)
}
pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
let guest_cid = self.driver.guest_cid();
let connections = &mut self.connections;
let result = self.driver.poll(|event, body| {
let connection = get_connection_for_event(connections, &event, guest_cid);
let connection = if let Some((_, connection)) = connection {
connection
} else if let VsockEventType::ConnectionRequest = event.event_type {
if connection.is_some() || event.destination.cid != guest_cid {
return Ok(None);
}
connections.push(Connection::new(event.source, event.destination.port));
connections.last_mut().unwrap()
} else {
return Ok(None);
};
connection.info.update_for_event(&event);
if let VsockEventType::Received { length } = event.event_type {
if !connection.buffer.add(body) {
return Err(SocketError::OutputBufferTooShort(length).into());
}
}
Ok(Some(event))
})?;
let Some(event) = result else {
return Ok(None);
};
let (connection_index, connection) =
get_connection_for_event(connections, &event, guest_cid).unwrap();
match event.event_type {
VsockEventType::ConnectionRequest => {
if self.listening_ports.contains(&event.destination.port) {
self.driver.accept(&connection.info)?;
} else {
self.driver.force_close(&connection.info)?;
self.connections.swap_remove(connection_index);
return Ok(None);
}
}
VsockEventType::Connected => {}
VsockEventType::Disconnected { reason } => {
if connection.buffer.is_empty() {
if reason == DisconnectReason::Shutdown {
self.driver.force_close(&connection.info)?;
}
self.connections.swap_remove(connection_index);
} else {
connection.peer_requested_shutdown = true;
}
}
VsockEventType::Received { .. } => {
}
VsockEventType::CreditRequest => {
self.driver.credit_update(&connection.info)?;
return Ok(None);
}
VsockEventType::CreditUpdate => {}
}
Ok(Some(event))
}
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
let bytes_read = connection.buffer.drain(buffer);
connection.info.done_forwarding(bytes_read);
if connection.peer_requested_shutdown && connection.buffer.is_empty() {
self.driver.force_close(&connection.info)?;
self.connections.swap_remove(connection_index);
}
Ok(bytes_read)
}
pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
Ok(connection.buffer.available())
}
pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
self.driver.credit_update(&connection.info)
}
pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
loop {
if let Some(event) = self.poll()? {
return Ok(event);
} else {
spin_loop();
}
}
}
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.shutdown(&connection.info)
}
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
self.driver.force_close(&connection.info)?;
self.connections.swap_remove(index);
Ok(())
}
}
/// Finds the tracked connection between local port `local_port` and `peer`.
///
/// Returns the connection's index in `connections` together with a mutable reference to it,
/// or [`SocketError::NotConnected`] if there is no such connection. If several connections
/// match, the first one is returned.
fn get_connection(
    connections: &mut [Connection],
    peer: VsockAddr,
    local_port: u32,
) -> core::result::Result<(usize, &mut Connection), SocketError> {
    for (index, connection) in connections.iter_mut().enumerate() {
        let same_peer = connection.info.dst == peer;
        if same_peer && connection.info.src_port == local_port {
            return Ok((index, connection));
        }
    }
    Err(SocketError::NotConnected)
}
/// Finds the first tracked connection which `event` belongs to, using the event's own
/// `matches_connection` check with `local_cid` as this side's CID.
///
/// Returns the connection's index in `connections` and a mutable reference to it, or `None`
/// if no connection matches.
fn get_connection_for_event<'a>(
    connections: &'a mut [Connection],
    event: &VsockEvent,
    local_cid: u64,
) -> Option<(usize, &'a mut Connection)> {
    let index = connections
        .iter()
        .position(|candidate| event.matches_connection(&candidate.info, local_cid))?;
    Some((index, &mut connections[index]))
}
/// A fixed-capacity byte ring buffer which holds data received on a connection until it is
/// read.
#[derive(Debug)]
struct RingBuffer {
    /// Backing storage; its length is the buffer's capacity.
    buffer: Box<[u8]>,
    /// Number of bytes currently stored.
    used: usize,
    /// Index in `buffer` of the oldest stored byte.
    start: usize,
}

impl RingBuffer {
    /// Creates an empty ring buffer able to hold `capacity` bytes.
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: FromBytes::new_box_slice_zeroed(capacity),
            used: 0,
            start: 0,
        }
    }

    /// Returns the number of bytes currently stored, i.e. how many `drain` can return.
    pub fn used(&self) -> usize {
        self.used
    }

    /// Returns whether the buffer holds no data.
    pub fn is_empty(&self) -> bool {
        self.used == 0
    }

    /// Returns the amount of free space left in the buffer, in bytes.
    pub fn available(&self) -> usize {
        self.buffer.len() - self.used
    }

    /// Appends all of `bytes` to the buffer.
    ///
    /// If there is not enough free space for all of `bytes`, the buffer is left unmodified and
    /// `false` is returned. Otherwise the bytes are stored and `true` is returned.
    pub fn add(&mut self, bytes: &[u8]) -> bool {
        if bytes.len() > self.available() {
            return false;
        }
        // Nothing to copy. Returning early also avoids `% 0` below for a zero-capacity buffer.
        if bytes.is_empty() {
            return true;
        }

        // Writing starts just past the newest stored byte, wrapping around the end.
        let first_available = (self.start + self.used) % self.buffer.len();
        let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
        let (before_wraparound, after_wraparound) = bytes.split_at(copy_length_before_wraparound);
        self.buffer[first_available..first_available + copy_length_before_wraparound]
            .copy_from_slice(before_wraparound);
        // Any remainder wraps to the front. The capacity check above guarantees that it ends
        // at or before `start`.
        self.buffer[..after_wraparound.len()].copy_from_slice(after_wraparound);

        self.used += bytes.len();
        true
    }

    /// Moves up to `out.len()` of the oldest stored bytes into the front of `out`, returning
    /// how many bytes were copied.
    pub fn drain(&mut self, out: &mut [u8]) -> usize {
        let bytes_read = min(self.used, out.len());
        // Nothing to copy. Returning early also avoids `% 0` below for a zero-capacity buffer.
        if bytes_read == 0 {
            return 0;
        }

        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
        // `read_before_wraparound <= bytes_read` by construction, so this cannot underflow.
        let read_after_wraparound = bytes_read - read_before_wraparound;
        out[..read_before_wraparound]
            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
        out[read_before_wraparound..bytes_read]
            .copy_from_slice(&self.buffer[..read_after_wraparound]);

        self.used -= bytes_read;
        self.start = (self.start + bytes_read) % self.buffer.len();
        bytes_read
    }
}
// Unit tests for `VsockConnectionManager`. Each test runs the manager against a `FakeTransport`,
// with a spawned thread playing the host side of the device: it reads the packets the guest
// places on the TX queue and writes replies into the RX queue.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
device::socket::{
protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
},
hal::fake::FakeHal,
transport::{
fake::{FakeTransport, QueueStatus, State},
DeviceType,
},
volatile::ReadOnly,
};
use alloc::{sync::Arc, vec};
use core::{mem::size_of, ptr::NonNull};
use std::{sync::Mutex, thread};
use zerocopy::{AsBytes, FromBytes};
/// Guest connects to the host, sends a message, receives the host's reply, then shuts the
/// connection down. The host thread checks every packet the guest sends.
#[test]
fn send_recv() {
let host_cid = 2;
let guest_cid = 66;
let host_port = 1234;
let guest_port = 4321;
let host_address = VsockAddr {
cid: host_cid,
port: host_port,
};
let hello_from_guest = "Hello from guest";
let hello_from_host = "Hello from host";
// Config space reporting the guest CID (66) to the driver.
let mut config_space = VirtioVsockConfig {
guest_cid_low: ReadOnly::new(66),
guest_cid_high: ReadOnly::new(0),
};
// Fake device state with three virtqueues, shared between the guest and the host thread.
let state = Arc::new(Mutex::new(State {
queues: vec![
QueueStatus::default(),
QueueStatus::default(),
QueueStatus::default(),
],
..Default::default()
}));
let transport = FakeTransport {
device_type: DeviceType::Socket,
max_queue_size: 32,
device_features: 0,
config_space: NonNull::from(&mut config_space),
state: state.clone(),
};
let mut socket = VsockConnectionManager::new(
VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
);
// Host side of the conversation.
let handle = thread::spawn(move || {
// Expect the guest's connection request, advertising its 1024-byte receive buffer.
State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
assert_eq!(
VirtioVsockHdr::read_from(
state
.lock()
.unwrap()
.read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
.as_slice()
)
.unwrap(),
VirtioVsockHdr {
op: VirtioVsockOp::Request.into(),
src_cid: guest_cid.into(),
dst_cid: host_cid.into(),
src_port: guest_port.into(),
dst_port: host_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 1024.into(),
fwd_cnt: 0.into(),
}
);
// Accept the connection, advertising a 50-byte host buffer.
state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
RX_QUEUE_IDX,
VirtioVsockHdr {
op: VirtioVsockOp::Response.into(),
src_cid: host_cid.into(),
dst_cid: guest_cid.into(),
src_port: host_port.into(),
dst_port: guest_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 50.into(),
fwd_cnt: 0.into(),
}
.as_bytes(),
);
// Expect the guest's data packet: a header followed by the payload.
State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
let request = state
.lock()
.unwrap()
.read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
assert_eq!(
request.len(),
size_of::<VirtioVsockHdr>() + hello_from_guest.len()
);
assert_eq!(
VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
VirtioVsockHdr {
op: VirtioVsockOp::Rw.into(),
src_cid: guest_cid.into(),
dst_cid: host_cid.into(),
src_port: guest_port.into(),
dst_port: host_port.into(),
len: (hello_from_guest.len() as u32).into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 1024.into(),
fwd_cnt: 0.into(),
}
);
assert_eq!(
&request[size_of::<VirtioVsockHdr>()..],
hello_from_guest.as_bytes()
);
println!("Host sending");
// Reply with data; fwd_cnt reports that the host consumed the guest's whole payload.
let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
VirtioVsockHdr {
op: VirtioVsockOp::Rw.into(),
src_cid: host_cid.into(),
dst_cid: guest_cid.into(),
src_port: host_port.into(),
dst_port: guest_port.into(),
len: (hello_from_host.len() as u32).into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 50.into(),
fwd_cnt: (hello_from_guest.len() as u32).into(),
}
.write_to_prefix(response.as_mut_slice());
response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
state
.lock()
.unwrap()
.write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
// Expect the guest's shutdown; its fwd_cnt now counts the host data it has read.
State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
assert_eq!(
VirtioVsockHdr::read_from(
state
.lock()
.unwrap()
.read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
.as_slice()
)
.unwrap(),
VirtioVsockHdr {
op: VirtioVsockOp::Shutdown.into(),
src_cid: guest_cid.into(),
dst_cid: host_cid.into(),
src_port: guest_port.into(),
dst_port: host_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 1024.into(),
fwd_cnt: (hello_from_host.len() as u32).into(),
}
);
});
// Guest side: connect and wait for the host's acceptance.
socket.connect(host_address, guest_port).unwrap();
assert_eq!(
socket.wait_for_event().unwrap(),
VsockEvent {
source: host_address,
destination: VsockAddr {
cid: guest_cid,
port: guest_port,
},
event_type: VsockEventType::Connected,
buffer_status: VsockBufferStatus {
buffer_allocation: 50,
forward_count: 0,
},
}
);
println!("Guest sending");
socket
.send(host_address, guest_port, "Hello from guest".as_bytes())
.unwrap();
println!("Guest waiting to receive.");
assert_eq!(
socket.wait_for_event().unwrap(),
VsockEvent {
source: host_address,
destination: VsockAddr {
cid: guest_cid,
port: guest_port,
},
event_type: VsockEventType::Received {
length: hello_from_host.len()
},
buffer_status: VsockBufferStatus {
buffer_allocation: 50,
forward_count: hello_from_guest.len() as u32,
},
}
);
println!("Guest getting received data.");
// The received payload was buffered by `poll`; read it out with `recv`.
let mut buffer = [0u8; 64];
assert_eq!(
socket.recv(host_address, guest_port, &mut buffer).unwrap(),
hello_from_host.len()
);
assert_eq!(
&buffer[0..hello_from_host.len()],
hello_from_host.as_bytes()
);
socket.shutdown(host_address, guest_port).unwrap();
handle.join().unwrap();
}
/// The host first requests a connection to a port the guest is not listening on. The guest
/// must reject it with a reset and must not report it as an event. The host then requests a
/// connection to the listening port, which the guest must accept and report.
#[test]
fn incoming_connection() {
let host_cid = 2;
let guest_cid = 66;
let host_port = 1234;
let guest_port = 4321;
let wrong_guest_port = 4444;
let host_address = VsockAddr {
cid: host_cid,
port: host_port,
};
// Config space reporting the guest CID (66) to the driver.
let mut config_space = VirtioVsockConfig {
guest_cid_low: ReadOnly::new(66),
guest_cid_high: ReadOnly::new(0),
};
// Fake device state with three virtqueues, shared between the guest and the host thread.
let state = Arc::new(Mutex::new(State {
queues: vec![
QueueStatus::default(),
QueueStatus::default(),
QueueStatus::default(),
],
..Default::default()
}));
let transport = FakeTransport {
device_type: DeviceType::Socket,
max_queue_size: 32,
device_features: 0,
config_space: NonNull::from(&mut config_space),
state: state.clone(),
};
let mut socket = VsockConnectionManager::new(
VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
);
// Only `guest_port` accepts incoming connections.
socket.listen(guest_port);
// Host side of the conversation.
let handle = thread::spawn(move || {
println!("Host sending connection request to wrong port");
state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
RX_QUEUE_IDX,
VirtioVsockHdr {
op: VirtioVsockOp::Request.into(),
src_cid: host_cid.into(),
dst_cid: guest_cid.into(),
src_port: host_port.into(),
dst_port: wrong_guest_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 50.into(),
fwd_cnt: 0.into(),
}
.as_bytes(),
);
println!("Host waiting for rejection");
// The guest must answer the request to the unlistened port with a reset (Rst).
State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
assert_eq!(
VirtioVsockHdr::read_from(
state
.lock()
.unwrap()
.read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
.as_slice()
)
.unwrap(),
VirtioVsockHdr {
op: VirtioVsockOp::Rst.into(),
src_cid: guest_cid.into(),
dst_cid: host_cid.into(),
src_port: wrong_guest_port.into(),
dst_port: host_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 1024.into(),
fwd_cnt: 0.into(),
}
);
println!("Host sending connection request to right port");
state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
RX_QUEUE_IDX,
VirtioVsockHdr {
op: VirtioVsockOp::Request.into(),
src_cid: host_cid.into(),
dst_cid: guest_cid.into(),
src_port: host_port.into(),
dst_port: guest_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 50.into(),
fwd_cnt: 0.into(),
}
.as_bytes(),
);
println!("Host waiting for response");
// The guest must accept the request to the listened port with a Response.
State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
assert_eq!(
VirtioVsockHdr::read_from(
state
.lock()
.unwrap()
.read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
.as_slice()
)
.unwrap(),
VirtioVsockHdr {
op: VirtioVsockOp::Response.into(),
src_cid: guest_cid.into(),
dst_cid: host_cid.into(),
src_port: guest_port.into(),
dst_port: host_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
flags: 0.into(),
buf_alloc: 1024.into(),
fwd_cnt: 0.into(),
}
);
println!("Host finished");
});
println!("Guest expecting incoming connection.");
// The first event the guest sees is the accepted request; the rejected one is swallowed.
assert_eq!(
socket.wait_for_event().unwrap(),
VsockEvent {
source: host_address,
destination: VsockAddr {
cid: guest_cid,
port: guest_port,
},
event_type: VsockEventType::ConnectionRequest,
buffer_status: VsockBufferStatus {
buffer_allocation: 50,
forward_count: 0,
},
}
);
handle.join().unwrap();
}
}