Skip to main content

virtio_drivers/device/socket/
vsock.rs

1//! Driver for VirtIO socket devices.
2
3use super::DEFAULT_RX_BUFFER_SIZE;
4use super::error::SocketError;
5use super::protocol::{
6    Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7};
8use crate::Result;
9use crate::config::read_config;
10use crate::hal::Hal;
11use crate::queue::{OwningQueue, VirtQueue};
12use crate::transport::Transport;
13use core::mem::size_of;
14use log::debug;
15use zerocopy::{FromBytes, IntoBytes};
16
17pub(crate) const RX_QUEUE_IDX: u16 = 0;
18pub(crate) const TX_QUEUE_IDX: u16 = 1;
19const EVENT_QUEUE_IDX: u16 = 2;
20
21pub(crate) const QUEUE_SIZE: usize = 8;
22const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
23    .union(Feature::RING_INDIRECT_DESC)
24    .union(Feature::VERSION_1);
25
/// Information about a particular vsock connection.
///
/// Besides the addresses identifying the connection, this holds the credit-based flow-control
/// counters: our own `buf_alloc`/`fwd_cnt`, which are advertised in the header of every packet we
/// send, and the most recent values the peer reported to us.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// The address of the peer.
    pub dst: VsockAddr,
    /// The local port number associated with the connection.
    pub src_port: u32,
    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
    /// bytes it has allocated for packet bodies.
    peer_buf_alloc: u32,
    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
    /// has finished processing.
    peer_fwd_cnt: u32,
    /// The number of bytes of packet bodies which we have sent to the peer.
    ///
    /// Advanced by `VirtIOSocket::send`.
    tx_cnt: u32,
    /// The number of bytes of buffer space we have allocated to receive packet bodies from the
    /// peer.
    ///
    /// Advertised to the peer in the header of every packet we send.
    pub buf_alloc: u32,
    /// The number of bytes of packet bodies which we have received from the peer and handled.
    ///
    /// Advanced by `done_forwarding` and advertised to the peer in every packet header we send.
    fwd_cnt: u32,
    /// Whether we have recently requested credit from the peer.
    ///
    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`. It prevents flooding the peer with repeated
    /// credit requests while waiting for an answer.
    has_pending_credit_request: bool,
}
52
53impl ConnectionInfo {
54    /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
55    /// for everything else.
56    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
57        Self {
58            dst: destination,
59            src_port,
60            ..Default::default()
61        }
62    }
63
64    /// Updates this connection info with the peer buffer allocation and forwarded count from the
65    /// given event.
66    pub fn update_for_event(&mut self, event: &VsockEvent) {
67        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
68        self.peer_fwd_cnt = event.buffer_status.forward_count;
69
70        if let VsockEventType::CreditUpdate = event.event_type {
71            self.has_pending_credit_request = false;
72        }
73    }
74
75    /// Increases the forwarded count recorded for this connection by the given number of bytes.
76    ///
77    /// This should be called once received data has been passed to the client, so there is buffer
78    /// space available for more.
79    pub fn done_forwarding(&mut self, length: usize) {
80        self.fwd_cnt += length as u32;
81    }
82
83    /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
84    /// data from us.
85    fn peer_free(&self) -> u32 {
86        self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
87    }
88
89    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
90        VirtioVsockHdr {
91            src_cid: src_cid.into(),
92            dst_cid: self.dst.cid.into(),
93            src_port: self.src_port.into(),
94            dst_port: self.dst.port.into(),
95            buf_alloc: self.buf_alloc.into(),
96            fwd_cnt: self.fwd_cnt.into(),
97            ..Default::default()
98        }
99    }
100}
101
/// An event received from a VirtIO socket device.
///
/// Events produced by this driver are decoded from the header of a packet received on the RX
/// virtqueue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
    /// The source of the event, i.e. the peer who sent it.
    pub source: VsockAddr,
    /// The destination of the event, i.e. the CID and port on our side.
    pub destination: VsockAddr,
    /// The peer's buffer status for the connection, as reported in the packet header.
    pub buffer_status: VsockBufferStatus,
    /// The type of event.
    pub event_type: VsockEventType,
}
114
115impl VsockEvent {
116    /// Returns whether the event matches the given connection.
117    pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
118        self.source == connection_info.dst
119            && self.destination.cid == guest_cid
120            && self.destination.port == connection_info.src_port
121    }
122
123    fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
124        let op = header.op()?;
125        let buffer_status = VsockBufferStatus {
126            buffer_allocation: header.buf_alloc.into(),
127            forward_count: header.fwd_cnt.into(),
128        };
129        let source = header.source();
130        let destination = header.destination();
131
132        let event_type = match op {
133            VirtioVsockOp::Request => {
134                header.check_data_is_empty()?;
135                VsockEventType::ConnectionRequest
136            }
137            VirtioVsockOp::Response => {
138                header.check_data_is_empty()?;
139                VsockEventType::Connected
140            }
141            VirtioVsockOp::CreditUpdate => {
142                header.check_data_is_empty()?;
143                VsockEventType::CreditUpdate
144            }
145            VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
146                header.check_data_is_empty()?;
147                debug!("Disconnected from the peer");
148                let reason = if op == VirtioVsockOp::Rst {
149                    DisconnectReason::Reset
150                } else {
151                    DisconnectReason::Shutdown
152                };
153                VsockEventType::Disconnected { reason }
154            }
155            VirtioVsockOp::Rw => VsockEventType::Received {
156                length: header.len() as usize,
157            },
158            VirtioVsockOp::CreditRequest => {
159                header.check_data_is_empty()?;
160                VsockEventType::CreditRequest
161            }
162            VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
163        };
164
165        Ok(VsockEvent {
166            source,
167            destination,
168            buffer_status,
169            event_type,
170        })
171    }
172}
173
/// The peer's flow-control (credit) state, as reported in the header of a packet it sent.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The header's `buf_alloc` field: how many bytes of receive buffer space the peer has
    /// allocated for packet bodies.
    pub buffer_allocation: u32,
    /// The header's `fwd_cnt` field: how many bytes of packet bodies the peer has finished
    /// processing.
    pub forward_count: u32,
}
179
/// The reason why a vsock connection was closed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer sent a reset (`VIRTIO_VSOCK_OP_RST`).
    ///
    /// The peer has either closed the connection in response to our shutdown request, or forcibly
    /// closed it of its own accord.
    Reset,
    /// The peer sent a shutdown (`VIRTIO_VSOCK_OP_SHUTDOWN`), asking to shut down the connection.
    Shutdown,
}
189
/// Details of the type of an event received from a VirtIO socket.
///
/// Each variant corresponds to the operation in the header of the received packet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
    /// The peer requests to establish a connection with us (`VIRTIO_VSOCK_OP_REQUEST`).
    ConnectionRequest,
    /// The connection was successfully established (`VIRTIO_VSOCK_OP_RESPONSE`).
    Connected,
    /// The connection was closed (`VIRTIO_VSOCK_OP_RST` or `VIRTIO_VSOCK_OP_SHUTDOWN`).
    Disconnected {
        /// The reason for the disconnection.
        reason: DisconnectReason,
    },
    /// Data was received on the connection (`VIRTIO_VSOCK_OP_RW`).
    ///
    /// The body bytes are passed alongside the event to the handler given to `poll`.
    Received {
        /// The length of the data in bytes.
        length: usize,
    },
    /// The peer requests us to send a credit update (`VIRTIO_VSOCK_OP_CREDIT_REQUEST`).
    CreditRequest,
    /// The peer just sent us a credit update with nothing else (`VIRTIO_VSOCK_OP_CREDIT_UPDATE`).
    CreditUpdate,
}
212
/// Low-level driver for a VirtIO socket device.
///
/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
/// using this directly.
///
/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
/// bigger than `size_of::<VirtioVsockHdr>()`; [`VirtIOSocket::new`] panics otherwise.
pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
{
    /// The transport used to initialise, configure and notify the device.
    transport: T,
    /// Virtqueue to receive packets.
    rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
    /// Virtqueue to send packets to the device.
    tx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Virtqueue to receive events from the device.
    event: VirtQueue<H, { QUEUE_SIZE }>,
    /// The guest_cid field contains the guest's context ID, which uniquely identifies
    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
    guest_cid: u64,
}
232
233impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
234    for VirtIOSocket<H, T, RX_BUFFER_SIZE>
235{
236    fn drop(&mut self) {
237        // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
238        // after they have been freed.
239        self.transport.queue_unset(RX_QUEUE_IDX);
240        self.transport.queue_unset(TX_QUEUE_IDX);
241        self.transport.queue_unset(EVENT_QUEUE_IDX);
242    }
243}
244
245impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
246    /// Create a new VirtIO Vsock driver.
247    pub fn new(mut transport: T) -> Result<Self> {
248        assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
249
250        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
251
252        let guest_cid = transport.read_consistent(|| {
253            Ok(
254                (read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64)
255                    | ((read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32),
256            )
257        })?;
258        debug!("guest cid: {guest_cid:?}");
259
260        let rx = VirtQueue::new(
261            &mut transport,
262            RX_QUEUE_IDX,
263            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
264            negotiated_features.contains(Feature::RING_EVENT_IDX),
265        )?;
266        let tx = VirtQueue::new(
267            &mut transport,
268            TX_QUEUE_IDX,
269            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
270            negotiated_features.contains(Feature::RING_EVENT_IDX),
271        )?;
272        let event = VirtQueue::new(
273            &mut transport,
274            EVENT_QUEUE_IDX,
275            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
276            negotiated_features.contains(Feature::RING_EVENT_IDX),
277        )?;
278
279        let rx = OwningQueue::new(rx)?;
280
281        transport.finish_init();
282        if rx.should_notify() {
283            transport.notify(RX_QUEUE_IDX);
284        }
285
286        Ok(Self {
287            transport,
288            rx,
289            tx,
290            event,
291            guest_cid,
292        })
293    }
294
295    /// Returns the CID which has been assigned to this guest.
296    pub fn guest_cid(&self) -> u64 {
297        self.guest_cid
298    }
299
300    /// Sends a request to connect to the given destination.
301    ///
302    /// This returns as soon as the request is sent; you should wait until `poll` returns a
303    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
304    /// before sending data.
305    pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
306        let header = VirtioVsockHdr {
307            op: VirtioVsockOp::Request.into(),
308            ..connection_info.new_header(self.guest_cid)
309        };
310        // Sends a header only packet to the TX queue to connect the device to the listening socket
311        // at the given destination.
312        self.send_packet_to_tx_queue(&header, &[])
313    }
314
315    /// Accepts the given connection from a peer.
316    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
317        let header = VirtioVsockHdr {
318            op: VirtioVsockOp::Response.into(),
319            ..connection_info.new_header(self.guest_cid)
320        };
321        self.send_packet_to_tx_queue(&header, &[])
322    }
323
324    /// Requests the peer to send us a credit update for the given connection.
325    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
326        let header = VirtioVsockHdr {
327            op: VirtioVsockOp::CreditRequest.into(),
328            ..connection_info.new_header(self.guest_cid)
329        };
330        self.send_packet_to_tx_queue(&header, &[])
331    }
332
333    /// Sends the buffer to the destination.
334    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
335        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
336
337        let len = buffer.len() as u32;
338        let header = VirtioVsockHdr {
339            op: VirtioVsockOp::Rw.into(),
340            len: len.into(),
341            ..connection_info.new_header(self.guest_cid)
342        };
343        connection_info.tx_cnt += len;
344        self.send_packet_to_tx_queue(&header, buffer)
345    }
346
347    fn check_peer_buffer_is_sufficient(
348        &mut self,
349        connection_info: &mut ConnectionInfo,
350        buffer_len: usize,
351    ) -> Result {
352        if connection_info.peer_free() as usize >= buffer_len {
353            Ok(())
354        } else {
355            // Request an update of the cached peer credit, if we haven't already done so, and tell
356            // the caller to try again later.
357            if !connection_info.has_pending_credit_request {
358                self.request_credit(connection_info)?;
359                connection_info.has_pending_credit_request = true;
360            }
361            Err(SocketError::InsufficientBufferSpaceInPeer.into())
362        }
363    }
364
365    /// Tells the peer how much buffer space we have to receive data.
366    pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
367        let header = VirtioVsockHdr {
368            op: VirtioVsockOp::CreditUpdate.into(),
369            ..connection_info.new_header(self.guest_cid)
370        };
371        self.send_packet_to_tx_queue(&header, &[])
372    }
373
374    /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
375    /// it.
376    pub fn poll(
377        &mut self,
378        handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
379    ) -> Result<Option<VsockEvent>> {
380        self.rx.poll(&mut self.transport, |buffer| {
381            let (header, body) = read_header_and_body(buffer)?;
382            VsockEvent::from_header(&header).and_then(|event| handler(event, body))
383        })
384    }
385
386    /// Requests to shut down the connection cleanly, sending hints about whether we will send or
387    /// receive more data.
388    ///
389    /// This returns as soon as the request is sent; you should wait until `poll` returns a
390    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
391    /// shutdown.
392    pub fn shutdown_with_hints(
393        &mut self,
394        connection_info: &ConnectionInfo,
395        hints: StreamShutdown,
396    ) -> Result {
397        let header = VirtioVsockHdr {
398            op: VirtioVsockOp::Shutdown.into(),
399            flags: hints.into(),
400            ..connection_info.new_header(self.guest_cid)
401        };
402        self.send_packet_to_tx_queue(&header, &[])
403    }
404
405    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
406    /// any more data.
407    ///
408    /// This returns as soon as the request is sent; you should wait until `poll` returns a
409    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
410    /// shutdown.
411    pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
412        self.shutdown_with_hints(
413            connection_info,
414            StreamShutdown::SEND | StreamShutdown::RECEIVE,
415        )
416    }
417
418    /// Forcibly closes the connection without waiting for the peer.
419    pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
420        let header = VirtioVsockHdr {
421            op: VirtioVsockOp::Rst.into(),
422            ..connection_info.new_header(self.guest_cid)
423        };
424        self.send_packet_to_tx_queue(&header, &[])?;
425        Ok(())
426    }
427
428    fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
429        let _len = if buffer.is_empty() {
430            self.tx
431                .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
432        } else {
433            self.tx.add_notify_wait_pop(
434                &[header.as_bytes(), buffer],
435                &mut [],
436                &mut self.transport,
437            )?
438        };
439        Ok(())
440    }
441}
442
443fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
444    // This could fail if the device returns a buffer used length shorter than the header size.
445    let header = VirtioVsockHdr::read_from_prefix(buffer)
446        .map_err(|_| SocketError::BufferTooShort)?
447        .0;
448    let body_length = header.len() as usize;
449
450    // This could fail if the device returns an unreasonably long body length.
451    let data_end = size_of::<VirtioVsockHdr>()
452        .checked_add(body_length)
453        .ok_or(SocketError::InvalidNumber)?;
454    // This could fail if the device returns a body length longer than buffer used length it
455    // returned.
456    let data = buffer
457        .get(size_of::<VirtioVsockHdr>()..data_end)
458        .ok_or(SocketError::BufferTooShort)?;
459    Ok((header, data))
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            DeviceType,
            fake::{FakeTransport, QueueStatus, State},
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// The guest CID is assembled from the low and high config words.
    #[test]
    fn config() {
        let shared_state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            VirtioVsockConfig {
                guest_cid_low: ReadOnly::new(0x42),
                guest_cid_high: ReadOnly::new(0),
            },
        )));
        let fake_transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: shared_state,
        };

        let driver =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(fake_transport)
                .unwrap();

        assert_eq!(driver.guest_cid(), 66);
    }
}