1use super::DEFAULT_RX_BUFFER_SIZE;
4use super::error::SocketError;
5use super::protocol::{
6 Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7};
8use crate::Result;
9use crate::config::read_config;
10use crate::hal::Hal;
11use crate::queue::{OwningQueue, VirtQueue};
12use crate::transport::Transport;
13use core::mem::size_of;
14use log::debug;
15use zerocopy::{FromBytes, IntoBytes};
16
17pub(crate) const RX_QUEUE_IDX: u16 = 0;
18pub(crate) const TX_QUEUE_IDX: u16 = 1;
19const EVENT_QUEUE_IDX: u16 = 2;
20
21pub(crate) const QUEUE_SIZE: usize = 8;
22const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
23 .union(Feature::RING_INDIRECT_DESC)
24 .union(Feature::VERSION_1);
25
/// State of a single vsock connection, owned by the caller and passed to each
/// [`VirtIOSocket`] operation.
///
/// Holds the peer address, the local port, and the credit-based flow-control counters that are
/// written into (and read from) every packet header.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Address of the peer; used as the destination of outgoing packets and matched against the
    /// source of incoming events.
    pub dst: VsockAddr,
    /// Local port of this connection; used as the source port of outgoing packets.
    pub src_port: u32,
    /// The last `buf_alloc` value the peer sent us (its receive buffer size, in bytes).
    peer_buf_alloc: u32,
    /// The last `fwd_cnt` value the peer sent us (bytes of our data it has consumed).
    peer_fwd_cnt: u32,
    /// Total payload bytes we have sent to the peer (free-running counter).
    tx_cnt: u32,
    /// Receive buffer space we advertise to the peer; sent as `buf_alloc` in every header.
    pub buf_alloc: u32,
    /// Total received payload bytes the application has consumed; sent as `fwd_cnt` in every
    /// header.
    fwd_cnt: u32,
    /// Whether a credit request has been sent and no `CreditUpdate` has been received since.
    ///
    /// Set when a credit request is sent because the peer lacks space, and cleared when a
    /// `CreditUpdate` event is processed; avoids flooding the peer with repeated requests.
    has_pending_credit_request: bool,
}
52
53impl ConnectionInfo {
54 pub fn new(destination: VsockAddr, src_port: u32) -> Self {
57 Self {
58 dst: destination,
59 src_port,
60 ..Default::default()
61 }
62 }
63
64 pub fn update_for_event(&mut self, event: &VsockEvent) {
67 self.peer_buf_alloc = event.buffer_status.buffer_allocation;
68 self.peer_fwd_cnt = event.buffer_status.forward_count;
69
70 if let VsockEventType::CreditUpdate = event.event_type {
71 self.has_pending_credit_request = false;
72 }
73 }
74
75 pub fn done_forwarding(&mut self, length: usize) {
80 self.fwd_cnt += length as u32;
81 }
82
83 fn peer_free(&self) -> u32 {
86 self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
87 }
88
89 fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
90 VirtioVsockHdr {
91 src_cid: src_cid.into(),
92 dst_cid: self.dst.cid.into(),
93 src_port: self.src_port.into(),
94 dst_port: self.dst.port.into(),
95 buf_alloc: self.buf_alloc.into(),
96 fwd_cnt: self.fwd_cnt.into(),
97 ..Default::default()
98 }
99 }
100}
101
/// An event decoded from a packet received on the RX virtqueue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
    /// Address the packet was sent from (the header's source CID and port).
    pub source: VsockAddr,
    /// Address the packet was sent to (the header's destination CID and port).
    pub destination: VsockAddr,
    /// The sender's credit information, copied from the header's `buf_alloc` and `fwd_cnt`.
    pub buffer_status: VsockBufferStatus,
    /// What kind of packet this was.
    pub event_type: VsockEventType,
}
114
115impl VsockEvent {
116 pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
118 self.source == connection_info.dst
119 && self.destination.cid == guest_cid
120 && self.destination.port == connection_info.src_port
121 }
122
123 fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
124 let op = header.op()?;
125 let buffer_status = VsockBufferStatus {
126 buffer_allocation: header.buf_alloc.into(),
127 forward_count: header.fwd_cnt.into(),
128 };
129 let source = header.source();
130 let destination = header.destination();
131
132 let event_type = match op {
133 VirtioVsockOp::Request => {
134 header.check_data_is_empty()?;
135 VsockEventType::ConnectionRequest
136 }
137 VirtioVsockOp::Response => {
138 header.check_data_is_empty()?;
139 VsockEventType::Connected
140 }
141 VirtioVsockOp::CreditUpdate => {
142 header.check_data_is_empty()?;
143 VsockEventType::CreditUpdate
144 }
145 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
146 header.check_data_is_empty()?;
147 debug!("Disconnected from the peer");
148 let reason = if op == VirtioVsockOp::Rst {
149 DisconnectReason::Reset
150 } else {
151 DisconnectReason::Shutdown
152 };
153 VsockEventType::Disconnected { reason }
154 }
155 VirtioVsockOp::Rw => VsockEventType::Received {
156 length: header.len() as usize,
157 },
158 VirtioVsockOp::CreditRequest => {
159 header.check_data_is_empty()?;
160 VsockEventType::CreditRequest
161 }
162 VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
163 };
164
165 Ok(VsockEvent {
166 source,
167 destination,
168 buffer_status,
169 event_type,
170 })
171 }
172}
173
/// Credit information reported by the sender of a packet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The sender's receive buffer size in bytes (the header's `buf_alloc`).
    pub buffer_allocation: u32,
    /// Bytes the sender has consumed so far (the header's `fwd_cnt`).
    pub forward_count: u32,
}
179
/// Why the peer ended a connection.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer sent a `Rst` packet.
    Reset,
    /// The peer sent a `Shutdown` packet. Its shutdown flags are not inspected when this event
    /// is decoded.
    Shutdown,
}
189
/// The kind of a [`VsockEvent`], derived from the packet's operation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
    /// The peer asked to open a connection (`Request` operation).
    ConnectionRequest,
    /// The peer accepted a connection (`Response` operation).
    Connected,
    /// The peer closed the connection (`Rst` or `Shutdown` operation).
    Disconnected {
        /// Which operation the peer used to close the connection.
        reason: DisconnectReason,
    },
    /// The peer sent data (`Rw` operation).
    Received {
        /// Payload length in bytes, as reported by the packet header.
        length: usize,
    },
    /// The peer asked for our current credit information (`CreditRequest` operation).
    CreditRequest,
    /// The peer sent its current credit information (`CreditUpdate` operation).
    CreditUpdate,
}
212
/// Low-level driver for a VirtIO socket (vsock) device.
///
/// This sends and receives raw vsock packets. Per-connection state (addresses and credit
/// counters) is kept by the caller in [`ConnectionInfo`] values passed to each operation.
///
/// `RX_BUFFER_SIZE` is presumably the size in bytes of each receive buffer held by the RX queue;
/// `new` asserts that it is larger than `size_of::<VirtioVsockHdr>()`.
pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
{
    transport: T,
    /// Queue on which packets from the device are received.
    rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
    /// Queue on which packets are sent to the device.
    tx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Event queue; it is set up in `new` but not otherwise used by this driver's methods.
    event: VirtQueue<H, { QUEUE_SIZE }>,
    /// This guest's context ID, read from the device config space in `new`.
    guest_cid: u64,
}
232
233impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
234 for VirtIOSocket<H, T, RX_BUFFER_SIZE>
235{
236 fn drop(&mut self) {
237 self.transport.queue_unset(RX_QUEUE_IDX);
240 self.transport.queue_unset(TX_QUEUE_IDX);
241 self.transport.queue_unset(EVENT_QUEUE_IDX);
242 }
243}
244
245impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
246 pub fn new(mut transport: T) -> Result<Self> {
248 assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
249
250 let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
251
252 let guest_cid = transport.read_consistent(|| {
253 Ok(
254 (read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64)
255 | ((read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32),
256 )
257 })?;
258 debug!("guest cid: {guest_cid:?}");
259
260 let rx = VirtQueue::new(
261 &mut transport,
262 RX_QUEUE_IDX,
263 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
264 negotiated_features.contains(Feature::RING_EVENT_IDX),
265 )?;
266 let tx = VirtQueue::new(
267 &mut transport,
268 TX_QUEUE_IDX,
269 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
270 negotiated_features.contains(Feature::RING_EVENT_IDX),
271 )?;
272 let event = VirtQueue::new(
273 &mut transport,
274 EVENT_QUEUE_IDX,
275 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
276 negotiated_features.contains(Feature::RING_EVENT_IDX),
277 )?;
278
279 let rx = OwningQueue::new(rx)?;
280
281 transport.finish_init();
282 if rx.should_notify() {
283 transport.notify(RX_QUEUE_IDX);
284 }
285
286 Ok(Self {
287 transport,
288 rx,
289 tx,
290 event,
291 guest_cid,
292 })
293 }
294
295 pub fn guest_cid(&self) -> u64 {
297 self.guest_cid
298 }
299
300 pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
306 let header = VirtioVsockHdr {
307 op: VirtioVsockOp::Request.into(),
308 ..connection_info.new_header(self.guest_cid)
309 };
310 self.send_packet_to_tx_queue(&header, &[])
313 }
314
315 pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
317 let header = VirtioVsockHdr {
318 op: VirtioVsockOp::Response.into(),
319 ..connection_info.new_header(self.guest_cid)
320 };
321 self.send_packet_to_tx_queue(&header, &[])
322 }
323
324 fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
326 let header = VirtioVsockHdr {
327 op: VirtioVsockOp::CreditRequest.into(),
328 ..connection_info.new_header(self.guest_cid)
329 };
330 self.send_packet_to_tx_queue(&header, &[])
331 }
332
333 pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
335 self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
336
337 let len = buffer.len() as u32;
338 let header = VirtioVsockHdr {
339 op: VirtioVsockOp::Rw.into(),
340 len: len.into(),
341 ..connection_info.new_header(self.guest_cid)
342 };
343 connection_info.tx_cnt += len;
344 self.send_packet_to_tx_queue(&header, buffer)
345 }
346
347 fn check_peer_buffer_is_sufficient(
348 &mut self,
349 connection_info: &mut ConnectionInfo,
350 buffer_len: usize,
351 ) -> Result {
352 if connection_info.peer_free() as usize >= buffer_len {
353 Ok(())
354 } else {
355 if !connection_info.has_pending_credit_request {
358 self.request_credit(connection_info)?;
359 connection_info.has_pending_credit_request = true;
360 }
361 Err(SocketError::InsufficientBufferSpaceInPeer.into())
362 }
363 }
364
365 pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
367 let header = VirtioVsockHdr {
368 op: VirtioVsockOp::CreditUpdate.into(),
369 ..connection_info.new_header(self.guest_cid)
370 };
371 self.send_packet_to_tx_queue(&header, &[])
372 }
373
374 pub fn poll(
377 &mut self,
378 handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
379 ) -> Result<Option<VsockEvent>> {
380 self.rx.poll(&mut self.transport, |buffer| {
381 let (header, body) = read_header_and_body(buffer)?;
382 VsockEvent::from_header(&header).and_then(|event| handler(event, body))
383 })
384 }
385
386 pub fn shutdown_with_hints(
393 &mut self,
394 connection_info: &ConnectionInfo,
395 hints: StreamShutdown,
396 ) -> Result {
397 let header = VirtioVsockHdr {
398 op: VirtioVsockOp::Shutdown.into(),
399 flags: hints.into(),
400 ..connection_info.new_header(self.guest_cid)
401 };
402 self.send_packet_to_tx_queue(&header, &[])
403 }
404
405 pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
412 self.shutdown_with_hints(
413 connection_info,
414 StreamShutdown::SEND | StreamShutdown::RECEIVE,
415 )
416 }
417
418 pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
420 let header = VirtioVsockHdr {
421 op: VirtioVsockOp::Rst.into(),
422 ..connection_info.new_header(self.guest_cid)
423 };
424 self.send_packet_to_tx_queue(&header, &[])?;
425 Ok(())
426 }
427
428 fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
429 let _len = if buffer.is_empty() {
430 self.tx
431 .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
432 } else {
433 self.tx.add_notify_wait_pop(
434 &[header.as_bytes(), buffer],
435 &mut [],
436 &mut self.transport,
437 )?
438 };
439 Ok(())
440 }
441}
442
443fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
444 let header = VirtioVsockHdr::read_from_prefix(buffer)
446 .map_err(|_| SocketError::BufferTooShort)?
447 .0;
448 let body_length = header.len() as usize;
449
450 let data_end = size_of::<VirtioVsockHdr>()
452 .checked_add(body_length)
453 .ok_or(SocketError::InvalidNumber)?;
454 let data = buffer
457 .get(size_of::<VirtioVsockHdr>()..data_end)
458 .ok_or(SocketError::BufferTooShort)?;
459 Ok((header, data))
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            DeviceType,
            fake::{FakeTransport, QueueStatus, State},
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// The guest CID is assembled from the low and high 32-bit config-space fields.
    #[test]
    fn config() {
        // One status entry per virtqueue: RX, TX and event.
        let queues = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(0x42),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(queues, config_space)));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: Arc::clone(&state),
        };

        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
        assert_eq!(socket.guest_cid(), 66);
    }
}