Skip to main content

virtio_drivers_and_devices/device/socket/
connectionmanager.rs

1use super::{
2    protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3    VirtIOSocketDevice, VirtIOSocketManager, VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{
6    transport::{DeviceTransport, InterruptStatus, Transport},
7    DeviceHal, Hal, Result,
8};
9use alloc::{boxed::Box, vec::Vec};
10use core::cmp::min;
11use core::convert::TryInto;
12use core::hint::spin_loop;
13use log::debug;
14use zerocopy::FromZeros;
15
/// Receive buffer size, in bytes (1 KiB), given to each connection when the manager is built
/// with `new` rather than `new_with_capacity`.
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1 << 10;
17
/// A higher level interface for VirtIO socket (vsock) drivers.
///
/// This keeps track of multiple vsock connections, each identified by its peer address and local
/// port. Data received on a connection is buffered until it is read with `recv`.
///
/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
/// bigger than `size_of::<VirtioVsockHdr>()`. It defaults to `DEFAULT_RX_BUFFER_SIZE`.
///
/// # Example
///
/// ```
/// # use virtio_drivers_and_devices::{Error, Hal};
/// # use virtio_drivers_and_devices::transport::Transport;
/// use virtio_drivers_and_devices::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
///
/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
///
/// // Start a thread to call `socket.poll()` and handle events.
///
/// let remote_address = VsockAddr { cid: 2, port: 42 };
/// let local_port = 1234;
/// socket.connect(remote_address, local_port)?;
///
/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
///
/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
///
/// socket.shutdown(remote_address, local_port)?;
/// # Ok(())
/// # }
/// ```
pub struct VsockConnectionManager<
    H: Hal,
    T: Transport,
    const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
>(VsockConnectionManagerCommon<VirtIOSocket<H, T, RX_BUFFER_SIZE>>);
54
55/// A high level interface for VirtIO socket (vsock) devices.
56pub struct VsockDeviceConnectionManager<H: DeviceHal, T: DeviceTransport>(
57    VsockConnectionManagerCommon<VirtIOSocketDevice<H, T>>,
58);
59
/// A trait defining shared behavior for VirtIO socket devices and drivers.
///
/// All methods are implemented for [`VsockConnectionManager`] (driver side) and
/// [`VsockDeviceConnectionManager`] (device side). The device side must not call
/// [`connect`](VsockManager::connect) or [`ack_interrupt`](VsockManager::ack_interrupt), both of
/// which panic there. The methods are equivalent to the inherent methods of the same name where
/// one exists; those are kept for backwards compatibility. On the driver side, `local_cid` is
/// also available as the inherent `guest_cid`.
pub trait VsockManager {
    /// Sends a request to connect to the given destination on the driver side.
    ///
    /// This returns as soon as the request is sent; you should wait until `poll` returns a
    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
    /// before sending data.
    ///
    /// # Errors
    ///
    /// On the driver side, returns `SocketError::ConnectionExists` if a connection from
    /// `src_port` to `destination` is already being tracked.
    ///
    /// # Panics
    ///
    /// Panics if called from the device side.
    fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result;

    /// Sends the buffer to the destination over the connection from `src_port`.
    ///
    /// # Errors
    ///
    /// Returns `SocketError::NotConnected` if there is no such connection.
    fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result;

    /// Sends a credit update to the given peer.
    ///
    /// # Errors
    ///
    /// Returns `SocketError::NotConnected` if there is no connection from `src_port` to `peer`.
    fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result;

    /// Forcibly closes the connection without waiting for the peer, and stops tracking it.
    ///
    /// # Errors
    ///
    /// Returns `SocketError::NotConnected` if there is no such connection.
    fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result;

    /// Allows incoming connections on the given port number.
    ///
    /// Listening on a port that is already being listened on has no further effect.
    fn listen(&mut self, port: u32);

    /// Polls the vsock device to receive data or other updates.
    ///
    /// Returns `Ok(None)` when there is no event to report to the caller, including events that
    /// were handled internally.
    fn poll(&mut self) -> Result<Option<VsockEvent>>;

    /// Returns the local CID, i.e. the CID of the guest on the driver side and the CID of the host
    /// on the device side.
    fn local_cid(&self) -> u64;

    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
    /// any more data.
    ///
    /// This returns as soon as the request is sent; you should wait until `poll` returns a
    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
    /// shutdown.
    fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result;

    /// Reads data received from the given connection into `buffer`, returning the number of bytes
    /// copied (0 if nothing is buffered).
    fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize>;

    /// Acknowledges an interrupt using a pointer to the VsockManager impl.
    ///
    /// This is useful when you cannot soundly get a mutable reference to the VsockManager impl. It
    /// may only be called on the driver side.
    ///
    /// # Panics
    ///
    /// Panics if called on the device side.
    ///
    /// # Safety
    ///
    /// `ptr` must point to an initialized VsockManager impl which is ready to acknowledge
    /// interrupts.
    unsafe fn ack_interrupt(ptr: *mut Self) -> InterruptStatus;
}
114
/// State and logic shared by the driver-side and device-side connection managers.
///
/// `M` is the low-level socket (driver or device) used to exchange packets.
struct VsockConnectionManagerCommon<M: VirtIOSocketManager> {
    /// The low-level socket used to send and receive packets.
    driver: M,
    /// Receive buffer size, in bytes, given to each newly created connection.
    per_connection_buffer_capacity: u32,
    /// Connections currently being tracked. Entries are removed with `swap_remove`, so the order
    /// is not meaningful.
    connections: Vec<Connection>,
    /// Local ports on which incoming connection requests are accepted.
    listening_ports: Vec<u32>,
}
121
/// Book-keeping for a single vsock connection tracked by the connection manager.
#[derive(Debug)]
struct Connection {
    /// Addressing and credit state of the connection, handed to the low-level socket on every
    /// operation.
    info: ConnectionInfo,
    /// Data received from the peer that has not yet been read with `recv`.
    buffer: RingBuffer,
    /// The peer disconnected while unread data remained in `buffer`, so the connection was kept
    /// instead of being removed. Once `recv` empties the buffer it force-closes the connection
    /// (responding with a RST) and stops tracking it.
    ///
    /// NOTE(review): despite the name, `poll` sets this for any disconnect reason, not only a
    /// peer SHUTDOWN. Confirm that the RST sent later is wanted after a peer reset too.
    peer_requested_shutdown: bool,
}
130
131impl Connection {
132    fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
133        let mut info = ConnectionInfo::new(peer, local_port);
134        info.buf_alloc = buffer_capacity;
135        Self {
136            info,
137            buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
138            peer_requested_shutdown: false,
139        }
140    }
141}
142
143impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
144    VsockConnectionManager<H, T, RX_BUFFER_SIZE>
145{
146    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
147    pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
148        Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
149    }
150
151    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
152    /// the given per-connection buffer capacity.
153    pub fn new_with_capacity(
154        driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
155        per_connection_buffer_capacity: u32,
156    ) -> Self {
157        Self(VsockConnectionManagerCommon {
158            driver,
159            connections: Vec::new(),
160            listening_ports: Vec::new(),
161            per_connection_buffer_capacity,
162        })
163    }
164
165    /// Returns the CID which has been assigned to this guest.
166    pub fn guest_cid(&self) -> u64 {
167        self.0.local_cid()
168    }
169
170    /// Sends a request to connect to the given destination.
171    ///
172    /// This returns as soon as the request is sent; you should wait until `poll` returns a
173    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
174    /// before sending data.
175    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
176        if self.0.connections.iter().any(|connection| {
177            connection.info.dst == destination && connection.info.src_port == src_port
178        }) {
179            return Err(SocketError::ConnectionExists.into());
180        }
181
182        let new_connection =
183            Connection::new(destination, src_port, self.0.per_connection_buffer_capacity);
184
185        self.0.driver.connect(&new_connection.info)?;
186        debug!("Connection requested: {:?}", new_connection.info);
187        self.0.connections.push(new_connection);
188        Ok(())
189    }
190    /// Allows incoming connections on the given port number.
191    pub fn listen(&mut self, port: u32) {
192        self.0.listen(port)
193    }
194
195    /// Stops allowing incoming connections on the given port number.
196    pub fn unlisten(&mut self, port: u32) {
197        self.0.unlisten(port)
198    }
199
200    /// Sends the buffer to the destination.
201    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
202        self.0.send(destination, src_port, buffer)
203    }
204
205    /// Polls the vsock device to receive data or other updates.
206    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
207        self.0.poll()
208    }
209
210    /// Reads data received from the given connection.
211    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
212        self.0.recv(peer, src_port, buffer)
213    }
214
215    /// Returns the number of bytes in the receive buffer available to be read by `recv`.
216    ///
217    /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
218    /// contain any data.
219    pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
220        self.0.recv_buffer_available_bytes(peer, src_port)
221    }
222
223    /// Sends a credit update to the given peer.
224    pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
225        self.0.update_credit(peer, src_port)
226    }
227
228    /// Blocks until we get some event from the vsock device.
229    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
230        self.0.wait_for_event()
231    }
232
233    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
234    /// any more data.
235    ///
236    /// This returns as soon as the request is sent; you should wait until `poll` returns a
237    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
238    /// shutdown.
239    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
240        self.0.shutdown(destination, src_port)
241    }
242
243    /// Forcibly closes the connection without waiting for the peer.
244    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
245        self.0.force_close(destination, src_port)
246    }
247}
248
249impl<H: DeviceHal, T: DeviceTransport> VsockDeviceConnectionManager<H, T> {
250    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
251    pub fn new(driver: VirtIOSocketDevice<H, T>) -> Self {
252        Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
253    }
254
255    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
256    /// the given per-connection buffer capacity.
257    pub fn new_with_capacity(
258        driver: VirtIOSocketDevice<H, T>,
259        per_connection_buffer_capacity: u32,
260    ) -> Self {
261        Self(VsockConnectionManagerCommon {
262            driver,
263            connections: Vec::new(),
264            listening_ports: Vec::new(),
265            per_connection_buffer_capacity,
266        })
267    }
268
269    /// Allows incoming connections on the given port number.
270    pub fn listen(&mut self, port: u32) {
271        self.0.listen(port)
272    }
273
274    /// Stops allowing incoming connections on the given port number.
275    pub fn unlisten(&mut self, port: u32) {
276        self.0.unlisten(port)
277    }
278
279    /// Sends the buffer to the destination.
280    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
281        self.0.send(destination, src_port, buffer)
282    }
283
284    /// Polls the vsock device to receive data or other updates.
285    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
286        self.0.poll()
287    }
288
289    /// Reads data received from the given connection.
290    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
291        self.0.recv(peer, src_port, buffer)
292    }
293
294    /// Returns the number of bytes in the receive buffer available to be read by `recv`.
295    ///
296    /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
297    /// contain any data.
298    pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
299        self.0.recv_buffer_available_bytes(peer, src_port)
300    }
301
302    /// Sends a credit update to the given peer.
303    pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
304        self.0.update_credit(peer, src_port)
305    }
306
307    /// Blocks until we get some event from the vsock device.
308    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
309        self.0.wait_for_event()
310    }
311
312    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
313    /// any more data.
314    ///
315    /// This returns as soon as the request is sent; you should wait until `poll` returns a
316    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
317    /// shutdown.
318    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
319        self.0.shutdown(destination, src_port)
320    }
321
322    /// Forcibly closes the connection without waiting for the peer.
323    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
324        self.0.force_close(destination, src_port)
325    }
326}
327
328impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VsockManager
329    for VsockConnectionManager<H, T, RX_BUFFER_SIZE>
330{
331    fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
332        Self::connect(self, destination, src_port)
333    }
334    fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
335        Self::send(self, dest, src_port, buffer)
336    }
337    fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
338        Self::update_credit(self, peer, src_port)
339    }
340    fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result {
341        Self::force_close(self, dest, src_port)
342    }
343    fn listen(&mut self, port: u32) {
344        Self::listen(self, port)
345    }
346    fn poll(&mut self) -> Result<Option<VsockEvent>> {
347        Self::poll(self)
348    }
349    fn local_cid(&self) -> u64 {
350        self.0.local_cid()
351    }
352    fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result {
353        Self::shutdown(self, dest, src_port)
354    }
355    fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
356        Self::recv(self, peer, src_port, buffer)
357    }
358    unsafe fn ack_interrupt(ptr: *mut Self) -> InterruptStatus {
359        // SAFETY: This function's safety requirements ensure that `ptr` points to a valid
360        // VsockConnectionManager so this gives a valid pointer to the field.
361        let vsock_driver_ptr = unsafe { &raw mut (*ptr).0.driver };
362        // SAFETY: delegated to the caller
363        unsafe { VirtIOSocket::<H, T, RX_BUFFER_SIZE>::ack_interrupt(vsock_driver_ptr) }
364    }
365}
366
367impl<H: DeviceHal, T: DeviceTransport> VsockManager for VsockDeviceConnectionManager<H, T> {
368    fn connect(&mut self, _destination: VsockAddr, _src_port: u32) -> Result {
369        unreachable!("vsock devices should not make outgoing connections")
370    }
371    fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
372        Self::send(self, dest, src_port, buffer)
373    }
374    fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
375        Self::update_credit(self, peer, src_port)
376    }
377    fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result {
378        Self::force_close(self, dest, src_port)
379    }
380    fn listen(&mut self, port: u32) {
381        Self::listen(self, port)
382    }
383    fn poll(&mut self) -> Result<Option<VsockEvent>> {
384        Self::poll(self)
385    }
386    fn local_cid(&self) -> u64 {
387        self.0.local_cid()
388    }
389    fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result {
390        Self::shutdown(self, dest, src_port)
391    }
392    fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
393        Self::recv(self, peer, src_port, buffer)
394    }
395    unsafe fn ack_interrupt(_ptr: *mut Self) -> InterruptStatus {
396        panic!("vsock devices cannot acknowledge interrupts")
397    }
398}
399
400impl<M: VirtIOSocketManager> VsockConnectionManagerCommon<M> {
401    /// Allows incoming connections on the given port number.
402    pub fn listen(&mut self, port: u32) {
403        if !self.listening_ports.contains(&port) {
404            self.listening_ports.push(port);
405        }
406    }
407
408    /// Stops allowing incoming connections on the given port number.
409    pub fn unlisten(&mut self, port: u32) {
410        self.listening_ports.retain(|p| *p != port);
411    }
412
413    /// Sends the buffer to the destination.
414    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
415        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
416
417        self.driver.send(buffer, &mut connection.info)
418    }
419
420    /// Polls the vsock device to receive data or other updates.
421    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
422        let local_cid = self.driver.local_cid();
423        let connections = &mut self.connections;
424        let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
425
426        let result = self.driver.poll(|event, body| {
427            let connection = get_connection_for_event(connections, &event, local_cid);
428
429            // Skip events which don't match any connection we know about, unless they are a
430            // connection request.
431            let connection = if let Some((_, connection)) = connection {
432                connection
433            } else if let VsockEventType::ConnectionRequest = event.event_type {
434                // If the requested connection already exists or the CID isn't ours, ignore it.
435                if connection.is_some() || event.destination.cid != local_cid {
436                    return Ok(None);
437                }
438                // Add the new connection to our list, at least for now. It will be removed again
439                // below if we weren't listening on the port.
440                connections.push(Connection::new(
441                    event.source,
442                    event.destination.port,
443                    per_connection_buffer_capacity,
444                ));
445                connections.last_mut().unwrap()
446            } else {
447                return Ok(None);
448            };
449
450            // Update stored connection info.
451            connection.info.update_for_event(&event);
452
453            if let VsockEventType::Received { length } = event.event_type {
454                // Copy to buffer
455                if !connection.buffer.add(body) {
456                    return Err(SocketError::OutputBufferTooShort(length).into());
457                }
458            }
459
460            Ok(Some(event))
461        })?;
462
463        let Some(event) = result else {
464            return Ok(None);
465        };
466
467        // The connection must exist because we found it above in the callback.
468        let (connection_index, connection) =
469            get_connection_for_event(connections, &event, local_cid).unwrap();
470
471        match event.event_type {
472            VsockEventType::ConnectionRequest => {
473                if self.listening_ports.contains(&event.destination.port) {
474                    self.driver.accept(&connection.info)?;
475                } else {
476                    // Reject the connection request and remove it from our list.
477                    self.driver.force_close(&connection.info)?;
478                    self.connections.swap_remove(connection_index);
479
480                    // No need to pass the request on to the client, as we've already rejected it.
481                    return Ok(None);
482                }
483            }
484            VsockEventType::Connected => {}
485            VsockEventType::Disconnected { reason } => {
486                // Wait until client reads all data before removing connection.
487                if connection.buffer.is_empty() {
488                    if reason == DisconnectReason::Shutdown {
489                        self.driver.force_close(&connection.info)?;
490                    }
491                    self.connections.swap_remove(connection_index);
492                } else {
493                    connection.peer_requested_shutdown = true;
494                }
495            }
496            VsockEventType::Received { .. } => {
497                // Already copied the buffer in the callback above.
498            }
499            VsockEventType::CreditRequest => {
500                // If the peer requested credit, send an update.
501                self.driver.credit_update(&connection.info)?;
502                // No need to pass the request on to the client, we've already handled it.
503                return Ok(None);
504            }
505            VsockEventType::CreditUpdate => {}
506        }
507
508        Ok(Some(event))
509    }
510
511    /// Returns the local CID of the vsock device.
512    pub fn local_cid(&self) -> u64 {
513        self.driver.local_cid()
514    }
515
516    /// Reads data received from the given connection.
517    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
518        let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
519
520        // Copy from ring buffer
521        let bytes_read = connection.buffer.drain(buffer);
522
523        connection.info.done_forwarding(bytes_read);
524
525        // If buffer is now empty and the peer requested shutdown, finish shutting down the
526        // connection.
527        if connection.peer_requested_shutdown && connection.buffer.is_empty() {
528            self.driver.force_close(&connection.info)?;
529            self.connections.swap_remove(connection_index);
530        }
531
532        Ok(bytes_read)
533    }
534
535    /// Returns the number of bytes in the receive buffer available to be read by `recv`.
536    ///
537    /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
538    /// contain any data.
539    pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
540        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
541        Ok(connection.buffer.used())
542    }
543
544    /// Sends a credit update to the given peer.
545    pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
546        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
547        self.driver.credit_update(&connection.info)
548    }
549
550    /// Blocks until we get some event from the vsock device.
551    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
552        loop {
553            if let Some(event) = self.poll()? {
554                return Ok(event);
555            } else {
556                spin_loop();
557            }
558        }
559    }
560
561    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
562    /// any more data.
563    ///
564    /// This returns as soon as the request is sent; you should wait until `poll` returns a
565    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
566    /// shutdown.
567    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
568        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
569
570        self.driver.shutdown(&connection.info)
571    }
572
573    /// Forcibly closes the connection without waiting for the peer.
574    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
575        let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
576
577        self.driver.force_close(&connection.info)?;
578
579        self.connections.swap_remove(index);
580        Ok(())
581    }
582}
583
584/// Returns the connection from the given list matching the given peer address and local port, and
585/// its index.
586///
587/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
588fn get_connection(
589    connections: &mut [Connection],
590    peer: VsockAddr,
591    local_port: u32,
592) -> core::result::Result<(usize, &mut Connection), SocketError> {
593    connections
594        .iter_mut()
595        .enumerate()
596        .find(|(_, connection)| {
597            connection.info.dst == peer && connection.info.src_port == local_port
598        })
599        .ok_or(SocketError::NotConnected)
600}
601
602/// Returns the connection from the given list matching the event, if any, and its index.
603fn get_connection_for_event<'a>(
604    connections: &'a mut [Connection],
605    event: &VsockEvent,
606    local_cid: u64,
607) -> Option<(usize, &'a mut Connection)> {
608    connections
609        .iter_mut()
610        .enumerate()
611        .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
612}
613
614#[derive(Debug)]
615struct RingBuffer {
616    buffer: Box<[u8]>,
617    /// The number of bytes currently in the buffer.
618    used: usize,
619    /// The index of the first used byte in the buffer.
620    start: usize,
621}
622
623impl RingBuffer {
624    pub fn new(capacity: usize) -> Self {
625        Self {
626            buffer: FromZeros::new_box_zeroed_with_elems(capacity).unwrap(),
627            used: 0,
628            start: 0,
629        }
630    }
631
632    /// Returns the number of bytes currently used in the buffer.
633    pub fn used(&self) -> usize {
634        self.used
635    }
636
637    /// Returns true iff there are currently no bytes in the buffer.
638    pub fn is_empty(&self) -> bool {
639        self.used == 0
640    }
641
642    /// Returns the number of bytes currently free in the buffer.
643    pub fn free(&self) -> usize {
644        self.buffer.len() - self.used
645    }
646
647    /// Adds the given bytes to the buffer if there is enough capacity for them all.
648    ///
649    /// Returns true if they were added, or false if they were not.
650    pub fn add(&mut self, bytes: &[u8]) -> bool {
651        if bytes.len() > self.free() {
652            return false;
653        }
654
655        // The index of the first available position in the buffer.
656        let first_available = (self.start + self.used) % self.buffer.len();
657        // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
658        // `buffer.len()`.
659        let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
660        self.buffer[first_available..first_available + copy_length_before_wraparound]
661            .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
662        if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
663            self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
664        }
665        self.used += bytes.len();
666
667        true
668    }
669
670    /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
671    /// buffer.
672    pub fn drain(&mut self, out: &mut [u8]) -> usize {
673        let bytes_read = min(self.used, out.len());
674
675        // The number of bytes to copy out between `start` and the end of the buffer.
676        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
677        // The number of bytes to copy out from the beginning of the buffer after wrapping around.
678        let read_after_wraparound = bytes_read
679            .checked_sub(read_before_wraparound)
680            .unwrap_or_default();
681
682        out[0..read_before_wraparound]
683            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
684        out[read_before_wraparound..bytes_read]
685            .copy_from_slice(&self.buffer[0..read_after_wraparound]);
686
687        self.used -= bytes_read;
688        self.start = (self.start + bytes_read) % self.buffer.len();
689
690        bytes_read
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697    use crate::{
698        config::ReadOnly,
699        device::socket::{
700            protocol::{
701                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
702            },
703            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
704        },
705        hal::fake::FakeHal,
706        transport::{
707            fake::{FakeTransport, QueueStatus, State},
708            DeviceType,
709        },
710    };
711    use alloc::{sync::Arc, vec};
712    use core::mem::size_of;
713    use std::{sync::Mutex, thread};
714    use zerocopy::{FromBytes, IntoBytes};
715
716    #[test]
717    fn send_recv() {
718        let host_cid = 2;
719        let guest_cid = 66;
720        let host_port = 1234;
721        let guest_port = 4321;
722        let host_address = VsockAddr {
723            cid: host_cid,
724            port: host_port,
725        };
726        let hello_from_guest = "Hello from guest";
727        let hello_from_host = "Hello from host";
728
729        let config_space = VirtioVsockConfig {
730            guest_cid_low: ReadOnly::new(66),
731            guest_cid_high: ReadOnly::new(0),
732        };
733        let state = Arc::new(Mutex::new(State::new(
734            vec![
735                QueueStatus::default(),
736                QueueStatus::default(),
737                QueueStatus::default(),
738            ],
739            config_space,
740        )));
741        let transport = FakeTransport {
742            device_type: DeviceType::Socket,
743            max_queue_size: 32,
744            device_features: 0,
745            state: state.clone(),
746        };
747        let mut socket = VsockConnectionManager::new(
748            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
749        );
750
751        // Start a thread to simulate the device.
752        let handle = thread::spawn(move || {
753            // Wait for connection request.
754            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
755            assert_eq!(
756                VirtioVsockHdr::read_from_bytes(
757                    state
758                        .lock()
759                        .unwrap()
760                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
761                        .as_slice()
762                )
763                .unwrap(),
764                VirtioVsockHdr {
765                    op: VirtioVsockOp::Request.into(),
766                    src_cid: guest_cid.into(),
767                    dst_cid: host_cid.into(),
768                    src_port: guest_port.into(),
769                    dst_port: host_port.into(),
770                    len: 0.into(),
771                    socket_type: SocketType::Stream.into(),
772                    flags: 0.into(),
773                    buf_alloc: 1024.into(),
774                    fwd_cnt: 0.into(),
775                }
776            );
777
778            // Accept connection and give the peer enough credit to send the message.
779            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
780                RX_QUEUE_IDX,
781                VirtioVsockHdr {
782                    op: VirtioVsockOp::Response.into(),
783                    src_cid: host_cid.into(),
784                    dst_cid: guest_cid.into(),
785                    src_port: host_port.into(),
786                    dst_port: guest_port.into(),
787                    len: 0.into(),
788                    socket_type: SocketType::Stream.into(),
789                    flags: 0.into(),
790                    buf_alloc: 50.into(),
791                    fwd_cnt: 0.into(),
792                }
793                .as_bytes(),
794            );
795
796            // Expect the guest to send some data.
797            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
798            let request = state
799                .lock()
800                .unwrap()
801                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
802            assert_eq!(
803                request.len(),
804                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
805            );
806            assert_eq!(
807                VirtioVsockHdr::read_from_prefix(request.as_slice())
808                    .unwrap()
809                    .0,
810                VirtioVsockHdr {
811                    op: VirtioVsockOp::Rw.into(),
812                    src_cid: guest_cid.into(),
813                    dst_cid: host_cid.into(),
814                    src_port: guest_port.into(),
815                    dst_port: host_port.into(),
816                    len: (hello_from_guest.len() as u32).into(),
817                    socket_type: SocketType::Stream.into(),
818                    flags: 0.into(),
819                    buf_alloc: 1024.into(),
820                    fwd_cnt: 0.into(),
821                }
822            );
823            assert_eq!(
824                &request[size_of::<VirtioVsockHdr>()..],
825                hello_from_guest.as_bytes()
826            );
827
828            println!("Host sending");
829
830            // Send a response.
831            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
832            VirtioVsockHdr {
833                op: VirtioVsockOp::Rw.into(),
834                src_cid: host_cid.into(),
835                dst_cid: guest_cid.into(),
836                src_port: host_port.into(),
837                dst_port: guest_port.into(),
838                len: (hello_from_host.len() as u32).into(),
839                socket_type: SocketType::Stream.into(),
840                flags: 0.into(),
841                buf_alloc: 50.into(),
842                fwd_cnt: (hello_from_guest.len() as u32).into(),
843            }
844            .write_to_prefix(response.as_mut_slice())
845            .unwrap();
846            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
847            state
848                .lock()
849                .unwrap()
850                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
851
852            // Expect a shutdown.
853            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
854            assert_eq!(
855                VirtioVsockHdr::read_from_bytes(
856                    state
857                        .lock()
858                        .unwrap()
859                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
860                        .as_slice()
861                )
862                .unwrap(),
863                VirtioVsockHdr {
864                    op: VirtioVsockOp::Shutdown.into(),
865                    src_cid: guest_cid.into(),
866                    dst_cid: host_cid.into(),
867                    src_port: guest_port.into(),
868                    dst_port: host_port.into(),
869                    len: 0.into(),
870                    socket_type: SocketType::Stream.into(),
871                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
872                    buf_alloc: 1024.into(),
873                    fwd_cnt: (hello_from_host.len() as u32).into(),
874                }
875            );
876        });
877
878        socket.connect(host_address, guest_port).unwrap();
879        assert_eq!(
880            socket.wait_for_event().unwrap(),
881            VsockEvent {
882                source: host_address,
883                destination: VsockAddr {
884                    cid: guest_cid,
885                    port: guest_port,
886                },
887                event_type: VsockEventType::Connected,
888                buffer_status: VsockBufferStatus {
889                    buffer_allocation: 50,
890                    forward_count: 0,
891                },
892            }
893        );
894        println!("Guest sending");
895        socket
896            .send(host_address, guest_port, "Hello from guest".as_bytes())
897            .unwrap();
898        println!("Guest waiting to receive.");
899        assert_eq!(
900            socket.wait_for_event().unwrap(),
901            VsockEvent {
902                source: host_address,
903                destination: VsockAddr {
904                    cid: guest_cid,
905                    port: guest_port,
906                },
907                event_type: VsockEventType::Received {
908                    length: hello_from_host.len()
909                },
910                buffer_status: VsockBufferStatus {
911                    buffer_allocation: 50,
912                    forward_count: hello_from_guest.len() as u32,
913                },
914            }
915        );
916        println!("Guest getting received data.");
917        let mut buffer = [0u8; 64];
918        assert_eq!(
919            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
920            hello_from_host.len()
921        );
922        assert_eq!(
923            &buffer[0..hello_from_host.len()],
924            hello_from_host.as_bytes()
925        );
926        socket.shutdown(host_address, guest_port).unwrap();
927
928        handle.join().unwrap();
929    }
930
931    #[test]
932    fn incoming_connection() {
933        let host_cid = 2;
934        let guest_cid = 66;
935        let host_port = 1234;
936        let guest_port = 4321;
937        let wrong_guest_port = 4444;
938        let host_address = VsockAddr {
939            cid: host_cid,
940            port: host_port,
941        };
942
943        let config_space = VirtioVsockConfig {
944            guest_cid_low: ReadOnly::new(66),
945            guest_cid_high: ReadOnly::new(0),
946        };
947        let state = Arc::new(Mutex::new(State::new(
948            vec![
949                QueueStatus::default(),
950                QueueStatus::default(),
951                QueueStatus::default(),
952            ],
953            config_space,
954        )));
955        let transport = FakeTransport {
956            device_type: DeviceType::Socket,
957            max_queue_size: 32,
958            device_features: 0,
959            state: state.clone(),
960        };
961        let mut socket = VsockConnectionManager::new(
962            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
963        );
964
965        socket.listen(guest_port);
966
967        // Start a thread to simulate the device.
968        let handle = thread::spawn(move || {
969            // Send a connection request for a port the guest isn't listening on.
970            println!("Host sending connection request to wrong port");
971            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
972                RX_QUEUE_IDX,
973                VirtioVsockHdr {
974                    op: VirtioVsockOp::Request.into(),
975                    src_cid: host_cid.into(),
976                    dst_cid: guest_cid.into(),
977                    src_port: host_port.into(),
978                    dst_port: wrong_guest_port.into(),
979                    len: 0.into(),
980                    socket_type: SocketType::Stream.into(),
981                    flags: 0.into(),
982                    buf_alloc: 50.into(),
983                    fwd_cnt: 0.into(),
984                }
985                .as_bytes(),
986            );
987
988            // Expect a rejection.
989            println!("Host waiting for rejection");
990            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
991            assert_eq!(
992                VirtioVsockHdr::read_from_bytes(
993                    state
994                        .lock()
995                        .unwrap()
996                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
997                        .as_slice()
998                )
999                .unwrap(),
1000                VirtioVsockHdr {
1001                    op: VirtioVsockOp::Rst.into(),
1002                    src_cid: guest_cid.into(),
1003                    dst_cid: host_cid.into(),
1004                    src_port: wrong_guest_port.into(),
1005                    dst_port: host_port.into(),
1006                    len: 0.into(),
1007                    socket_type: SocketType::Stream.into(),
1008                    flags: 0.into(),
1009                    buf_alloc: 1024.into(),
1010                    fwd_cnt: 0.into(),
1011                }
1012            );
1013
1014            // Send a connection request for a port the guest is listening on.
1015            println!("Host sending connection request to right port");
1016            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
1017                RX_QUEUE_IDX,
1018                VirtioVsockHdr {
1019                    op: VirtioVsockOp::Request.into(),
1020                    src_cid: host_cid.into(),
1021                    dst_cid: guest_cid.into(),
1022                    src_port: host_port.into(),
1023                    dst_port: guest_port.into(),
1024                    len: 0.into(),
1025                    socket_type: SocketType::Stream.into(),
1026                    flags: 0.into(),
1027                    buf_alloc: 50.into(),
1028                    fwd_cnt: 0.into(),
1029                }
1030                .as_bytes(),
1031            );
1032
1033            // Expect a response.
1034            println!("Host waiting for response");
1035            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
1036            assert_eq!(
1037                VirtioVsockHdr::read_from_bytes(
1038                    state
1039                        .lock()
1040                        .unwrap()
1041                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
1042                        .as_slice()
1043                )
1044                .unwrap(),
1045                VirtioVsockHdr {
1046                    op: VirtioVsockOp::Response.into(),
1047                    src_cid: guest_cid.into(),
1048                    dst_cid: host_cid.into(),
1049                    src_port: guest_port.into(),
1050                    dst_port: host_port.into(),
1051                    len: 0.into(),
1052                    socket_type: SocketType::Stream.into(),
1053                    flags: 0.into(),
1054                    buf_alloc: 1024.into(),
1055                    fwd_cnt: 0.into(),
1056                }
1057            );
1058
1059            println!("Host finished");
1060        });
1061
1062        // Expect an incoming connection.
1063        println!("Guest expecting incoming connection.");
1064        assert_eq!(
1065            socket.wait_for_event().unwrap(),
1066            VsockEvent {
1067                source: host_address,
1068                destination: VsockAddr {
1069                    cid: guest_cid,
1070                    port: guest_port,
1071                },
1072                event_type: VsockEventType::ConnectionRequest,
1073                buffer_status: VsockBufferStatus {
1074                    buffer_allocation: 50,
1075                    forward_count: 0,
1076                },
1077            }
1078        );
1079
1080        handle.join().unwrap();
1081    }
1082}