1use std::sync::Arc;
19
20use parking_lot::Mutex;
21use squib_core::GuestMemory;
22
23use crate::{
24 device::{ActivateError, VirtioDevice},
25 device_id::VirtioDeviceType,
26 interrupt::IrqLine,
27 queue::Queue,
28};
29
// Well-known vsock context IDs (CIDs). The names mirror the `VMADDR_CID_*`
// constants in Linux's `vm_sockets.h`.

/// CID 0, named after Linux's `VMADDR_CID_HYPERVISOR`.
pub const VMADDR_CID_HYPERVISOR: u64 = 0;
/// CID 1, named after Linux's `VMADDR_CID_LOCAL`.
pub const VMADDR_CID_LOCAL: u64 = 1;
/// CID 2, named after Linux's `VMADDR_CID_HOST`.
pub const VMADDR_CID_HOST: u64 = 2;
/// Wildcard CID (all bits set).
///
/// NOTE(review): Linux defines `VMADDR_CID_ANY` as the 32-bit value `-1U`;
/// this constant is the 64-bit all-ones value. Confirm which one callers
/// expect before comparing it against header fields.
pub const VMADDR_CID_ANY: u64 = u64::MAX;
/// Smallest CID accepted for a guest; `VsockDevice::new` rejects anything
/// below it.
pub const MIN_GUEST_CID: u64 = 3;

/// Value of the header `type` field for stream sockets.
pub const TYPE_STREAM: u16 = 1;
43
/// Operation code stored in the `op` field of a vsock packet header.
///
/// Each discriminant is the `u16` value used on the wire, so `op as u16`
/// yields the encoded form.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum VsockOp {
    /// Wire value 0; also the fallback for any value this enum does not know.
    Invalid = 0,
    /// Wire value 1.
    Request = 1,
    /// Wire value 2.
    Response = 2,
    /// Wire value 3.
    Rst = 3,
    /// Wire value 4.
    Shutdown = 4,
    /// Wire value 5.
    Rw = 5,
    /// Wire value 6.
    CreditUpdate = 6,
    /// Wire value 7.
    CreditRequest = 7,
}

impl VsockOp {
    /// Decodes a raw `op` field.
    ///
    /// Values 1 through 7 map to their named variant; everything else
    /// (including 0) decodes to [`VsockOp::Invalid`].
    #[must_use]
    pub fn from_wire(value: u16) -> Self {
        // Every variant that has a meaningful wire encoding. `Invalid` is left
        // out on purpose: it is the result for any value not found here.
        const KNOWN: [VsockOp; 7] = [
            VsockOp::Request,
            VsockOp::Response,
            VsockOp::Rst,
            VsockOp::Shutdown,
            VsockOp::Rw,
            VsockOp::CreditUpdate,
            VsockOp::CreditRequest,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|op| *op as u16 == value)
            .unwrap_or(Self::Invalid)
    }
}
82
/// Decoded vsock packet header.
///
/// All fields are little-endian on the wire; the byte offsets noted below are
/// the ones used by `from_bytes` and `to_bytes`.
#[derive(Debug, Clone, Copy)]
pub struct VsockHeader {
    /// Source context ID (wire bytes 0..8).
    pub src_cid: u64,
    /// Destination context ID (wire bytes 8..16).
    pub dst_cid: u64,
    /// Source port (wire bytes 16..20).
    pub src_port: u32,
    /// Destination port (wire bytes 20..24).
    pub dst_port: u32,
    /// Payload length in bytes (wire bytes 24..28). The TX path clamps it to
    /// the number of bytes actually present after the header.
    pub len: u32,
    /// Socket type (wire bytes 28..30), e.g. `TYPE_STREAM`.
    pub type_: u16,
    /// Operation (wire bytes 30..32), decoded with `VsockOp::from_wire`.
    pub op: VsockOp,
    /// Flags (wire bytes 32..36); carried verbatim, not interpreted here.
    pub flags: u32,
    /// `buf_alloc` field (wire bytes 36..40); carried verbatim, not
    /// interpreted here.
    pub buf_alloc: u32,
    /// `fwd_cnt` field (wire bytes 40..44); carried verbatim, not interpreted
    /// here.
    pub fwd_cnt: u32,
}

/// Size in bytes of the encoded header.
const HDR_SIZE: usize = 44;
109
/// Errors returned when decoding a vsock packet header.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum VsockParseError {
    /// The input slice holds fewer than `HDR_SIZE` bytes.
    #[error("vsock packet shorter than 44-byte header")]
    HeaderTooShort,
}
118
119impl VsockHeader {
120 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VsockParseError> {
125 if bytes.len() < HDR_SIZE {
126 return Err(VsockParseError::HeaderTooShort);
127 }
128 let u64 = |i: usize| {
132 u64::from_le_bytes([
133 bytes[i],
134 bytes[i + 1],
135 bytes[i + 2],
136 bytes[i + 3],
137 bytes[i + 4],
138 bytes[i + 5],
139 bytes[i + 6],
140 bytes[i + 7],
141 ])
142 };
143 let u32 =
144 |i: usize| u32::from_le_bytes([bytes[i], bytes[i + 1], bytes[i + 2], bytes[i + 3]]);
145 let u16 = |i: usize| u16::from_le_bytes([bytes[i], bytes[i + 1]]);
146 Ok(Self {
147 src_cid: u64(0),
148 dst_cid: u64(8),
149 src_port: u32(16),
150 dst_port: u32(20),
151 len: u32(24),
152 type_: u16(28),
153 op: VsockOp::from_wire(u16(30)),
154 flags: u32(32),
155 buf_alloc: u32(36),
156 fwd_cnt: u32(40),
157 })
158 }
159
160 #[must_use]
162 pub fn to_bytes(&self) -> [u8; HDR_SIZE] {
163 let mut out = [0u8; HDR_SIZE];
164 out[0..8].copy_from_slice(&self.src_cid.to_le_bytes());
165 out[8..16].copy_from_slice(&self.dst_cid.to_le_bytes());
166 out[16..20].copy_from_slice(&self.src_port.to_le_bytes());
167 out[20..24].copy_from_slice(&self.dst_port.to_le_bytes());
168 out[24..28].copy_from_slice(&self.len.to_le_bytes());
169 out[28..30].copy_from_slice(&self.type_.to_le_bytes());
170 out[30..32].copy_from_slice(&(self.op as u16).to_le_bytes());
171 out[32..36].copy_from_slice(&self.flags.to_le_bytes());
172 out[36..40].copy_from_slice(&self.buf_alloc.to_le_bytes());
173 out[40..44].copy_from_slice(&self.fwd_cnt.to_le_bytes());
174 out
175 }
176}
177
/// A vsock packet: decoded header plus owned payload bytes.
#[derive(Debug, Clone)]
pub struct VsockPacket {
    /// Decoded packet header.
    pub hdr: VsockHeader,
    /// Payload bytes that follow the header on the wire.
    pub payload: Vec<u8>,
}
186
/// Host-side endpoint that the device hands guest packets to and pulls
/// guest-bound packets from.
pub trait VsockMuxer: Send + Sync + std::fmt::Debug {
    /// Handles one packet the guest placed on the TX queue.
    ///
    /// Any packets returned are buffered by the device and delivered to the
    /// guest through the RX queue.
    fn handle_tx(&self, pkt: VsockPacket) -> Vec<VsockPacket>;

    /// Returns the packets currently waiting to be delivered to the guest.
    /// The device appends them after its own buffered replies.
    fn drain_rx(&self) -> Vec<VsockPacket>;
}
200
/// Muxer that keeps everything in memory; the unit tests in this file use it.
#[derive(Debug, Default)]
pub struct InMemoryMuxer {
    /// Every packet passed to `handle_tx`, in arrival order.
    pub tx_log: Mutex<Vec<VsockPacket>>,
    /// Packets handed out (and removed) by the next `drain_rx` call.
    pub rx_queue: Mutex<Vec<VsockPacket>>,
    /// When `true`, `handle_tx` answers each `Request` with a `Response`.
    pub auto_respond: bool,
}
212
213impl VsockMuxer for InMemoryMuxer {
214 fn handle_tx(&self, pkt: VsockPacket) -> Vec<VsockPacket> {
215 let mut replies = Vec::new();
216 if self.auto_respond && pkt.hdr.op == VsockOp::Request {
217 let mut hdr = pkt.hdr;
218 std::mem::swap(&mut hdr.src_cid, &mut hdr.dst_cid);
219 std::mem::swap(&mut hdr.src_port, &mut hdr.dst_port);
220 hdr.op = VsockOp::Response;
221 hdr.len = 0;
222 replies.push(VsockPacket {
223 hdr,
224 payload: Vec::new(),
225 });
226 }
227 self.tx_log.lock().push(pkt);
228 replies
229 }
230 fn drain_rx(&self) -> Vec<VsockPacket> {
231 std::mem::take(&mut *self.rx_queue.lock())
232 }
233}
234
/// Static configuration for one vsock device.
#[derive(Debug, Clone)]
pub struct VsockConfig {
    /// Identifier for this device; included in log output.
    pub vsock_id: String,
    /// CID reported to the guest through the device config space. Must be at
    /// least `MIN_GUEST_CID`.
    pub guest_cid: u64,
    /// Path of a Unix domain socket. Nothing in this file reads it;
    /// presumably the muxer implementation uses it (confirm with the caller).
    pub uds_path: String,
    /// Requests TSI mode. In this file it only triggers a warning at
    /// construction time.
    pub tsi: bool,
}
247
/// Index of the receive (device-to-guest) queue.
const RX_QUEUE: usize = 0;
/// Index of the transmit (guest-to-device) queue.
const TX_QUEUE: usize = 1;
/// Index of the event queue. The queue is created but nothing in this file
/// services it, hence the leading underscore.
const _EVENT_QUEUE: usize = 2;
/// Maximum size advertised for each of the three queues.
const QUEUE_MAX_SIZE: u16 = 256;
255
/// virtio-vsock device model: moves packets between the guest's virtqueues
/// and a [`VsockMuxer`].
#[derive(Debug)]
pub struct VsockDevice {
    /// Feature bits offered to the driver (initialised to 0 by `new`).
    avail: u64,
    /// Feature bits the driver acknowledged.
    acked: u64,
    /// The device's three virtqueues, indexed by `RX_QUEUE`, `TX_QUEUE` and
    /// `_EVENT_QUEUE`.
    queues: Vec<Queue>,
    /// Static configuration supplied at construction.
    config: VsockConfig,
    /// Host-side packet handler.
    muxer: Arc<dyn VsockMuxer>,
    /// Guest memory and interrupt line, populated by `activate`.
    state: Arc<Mutex<ActiveState>>,
    /// Guest-bound packets that have not yet been written to the RX queue.
    rx_buffer: Arc<Mutex<Vec<VsockPacket>>>,
}
268
/// Resources handed to the device on activation; everything is unset until
/// `activate` runs.
#[derive(Debug, Default)]
struct ActiveState {
    /// Guest memory used to walk the queues and copy packet bytes.
    mem: Option<Arc<dyn GuestMemory>>,
    /// Interrupt line triggered after used-ring entries are published.
    irq: Option<IrqLine>,
    /// Set to `true` by `activate`.
    activated: bool,
}
275
276impl VsockDevice {
277 pub fn new(config: VsockConfig, muxer: Arc<dyn VsockMuxer>) -> Result<Self, std::io::Error> {
282 if config.guest_cid < MIN_GUEST_CID {
283 return Err(std::io::Error::new(
284 std::io::ErrorKind::InvalidInput,
285 format!("guest_cid must be >= {MIN_GUEST_CID}"),
286 ));
287 }
288 if config.tsi {
290 tracing::warn!(
291 vsock_id = %config.vsock_id,
292 "vsock_tsi=true requires a libkrun-patched guest kernel; \
293 stock guest kernels treat AF_VSOCK as plain vsock and the \
294 TSI proxy is inactive (see docs/macos-setup.md)"
295 );
296 }
297 Ok(Self {
298 avail: 0,
299 acked: 0,
300 queues: vec![
301 Queue::new(QUEUE_MAX_SIZE),
302 Queue::new(QUEUE_MAX_SIZE),
303 Queue::new(QUEUE_MAX_SIZE),
304 ],
305 config,
306 muxer,
307 state: Arc::new(Mutex::new(ActiveState::default())),
308 rx_buffer: Arc::new(Mutex::new(Vec::new())),
309 })
310 }
311
312 fn drain_tx(&mut self) {
313 let (mem, irq) = {
314 let state = self.state.lock();
315 match (state.mem.clone(), state.irq.clone()) {
316 (Some(m), Some(i)) => (m, i),
317 _ => return,
318 }
319 };
320 let muxer = Arc::clone(&self.muxer);
321 let rx_buffer = Arc::clone(&self.rx_buffer);
322 let queue = &mut self.queues[TX_QUEUE];
323 let mut completed = false;
324 loop {
325 let chain = match queue.pop_avail(mem.as_ref()) {
326 Ok(Some(c)) => c,
327 Ok(None) => break,
328 Err(err) => {
329 tracing::warn!(error = %err, "vsock: tx walk failed");
330 break;
331 }
332 };
333 let head = chain.head_index();
334 let descs = match chain.collect(mem.as_ref()) {
335 Ok(d) => d,
336 Err(err) => {
337 tracing::warn!(error = %err, "vsock: tx chain collect failed");
338 break;
339 }
340 };
341 let mut buf = Vec::new();
343 for desc in &descs {
344 if desc.is_write_only() {
345 continue;
346 }
347 let len = desc.len as usize;
348 let mut piece = vec![0u8; len];
349 if mem.read(desc.addr, &mut piece).is_err() {
350 continue;
351 }
352 buf.extend_from_slice(&piece);
353 }
354 if buf.len() < HDR_SIZE {
355 let _ = queue.push_used(mem.as_ref(), head, 0);
356 completed = true;
357 continue;
358 }
359 let Ok(hdr) = VsockHeader::from_bytes(&buf[..HDR_SIZE]) else {
360 continue;
361 };
362 let payload_len = (hdr.len as usize).min(buf.len() - HDR_SIZE);
363 let payload = buf[HDR_SIZE..HDR_SIZE + payload_len].to_vec();
364 let pkt = VsockPacket { hdr, payload };
365 let replies = muxer.handle_tx(pkt);
366 if !replies.is_empty() {
367 rx_buffer.lock().extend(replies);
368 }
369 if let Err(err) = queue.push_used(mem.as_ref(), head, 0) {
370 tracing::warn!(error = %err, "vsock: tx push_used failed");
371 break;
372 }
373 completed = true;
374 }
375 if completed {
376 let _ = irq.trigger_queue();
377 }
378 }
379
380 fn drain_rx(&mut self) {
381 let (mem, irq) = {
382 let state = self.state.lock();
383 match (state.mem.clone(), state.irq.clone()) {
384 (Some(m), Some(i)) => (m, i),
385 _ => return,
386 }
387 };
388 let muxer = Arc::clone(&self.muxer);
391 let mut packets: Vec<VsockPacket> = std::mem::take(&mut *self.rx_buffer.lock());
392 packets.extend(muxer.drain_rx());
393 if packets.is_empty() {
394 return;
395 }
396 let queue = &mut self.queues[RX_QUEUE];
397 let mut completed = false;
398 for pkt in packets {
399 let chain = match queue.pop_avail(mem.as_ref()) {
400 Ok(Some(c)) => c,
401 Ok(None) => {
402 self.rx_buffer.lock().push(pkt);
404 break;
405 }
406 Err(err) => {
407 tracing::warn!(error = %err, "vsock: rx walk failed");
408 break;
409 }
410 };
411 let head = chain.head_index();
412 let descs = match chain.collect(mem.as_ref()) {
413 Ok(d) => d,
414 Err(err) => {
415 tracing::warn!(error = %err, "vsock: rx chain collect failed");
416 break;
417 }
418 };
419 let mut wire = pkt.hdr.to_bytes().to_vec();
420 wire.extend_from_slice(&pkt.payload);
421 let mut written: u32 = 0;
422 let mut wire_off = 0usize;
423 for desc in descs {
424 if !desc.is_write_only() {
425 continue;
426 }
427 let len = (desc.len as usize).min(wire.len() - wire_off);
428 if len == 0 {
429 continue;
430 }
431 if mem
432 .write(desc.addr, &wire[wire_off..wire_off + len])
433 .is_err()
434 {
435 break;
436 }
437 wire_off += len;
438 written = written.saturating_add(len as u32);
439 if wire_off >= wire.len() {
440 break;
441 }
442 }
443 if let Err(err) = queue.push_used(mem.as_ref(), head, written) {
444 tracing::warn!(error = %err, "vsock: rx push_used failed");
445 break;
446 }
447 completed = true;
448 }
449 if completed {
450 let _ = irq.trigger_queue();
451 }
452 }
453}
454
455impl VirtioDevice for VsockDevice {
456 fn device_type(&self) -> VirtioDeviceType {
457 VirtioDeviceType::Vsock
458 }
459 fn avail_features(&self) -> u64 {
460 self.avail
461 }
462 fn acked_features(&self) -> u64 {
463 self.acked
464 }
465 fn set_acked_features(&mut self, value: u64) {
466 self.acked = value;
467 }
468 fn queue_max_sizes(&self) -> &[u16] {
469 const SIZES: &[u16] = &[QUEUE_MAX_SIZE, QUEUE_MAX_SIZE, QUEUE_MAX_SIZE];
470 SIZES
471 }
472 fn queues(&self) -> &[Queue] {
473 &self.queues
474 }
475 fn queues_mut(&mut self) -> &mut [Queue] {
476 &mut self.queues
477 }
478 fn read_config(&self, offset: u64, data: &mut [u8]) {
479 let bytes = self.config.guest_cid.to_le_bytes();
481 let off = offset as usize;
482 for (i, b) in data.iter_mut().enumerate() {
483 *b = bytes.get(off + i).copied().unwrap_or(0);
484 }
485 }
486 fn write_config(&mut self, _offset: u64, _data: &[u8]) {}
487 fn activate(&mut self, mem: Arc<dyn GuestMemory>, irq: IrqLine) -> Result<(), ActivateError> {
488 let mut state = self.state.lock();
489 state.mem = Some(mem);
490 state.irq = Some(irq);
491 state.activated = true;
492 Ok(())
493 }
494 fn is_activated(&self) -> bool {
495 self.state.lock().activated
496 }
497 fn process_queue(&mut self, queue_index: u16) {
498 match queue_index as usize {
499 TX_QUEUE => {
500 self.drain_tx();
501 self.drain_rx();
504 }
505 RX_QUEUE => self.drain_rx(),
506 _ => {}
509 }
510 }
511}
512
// Unit tests for the vsock device: construction checks, config-space reads,
// header encoding, and one end-to-end TX -> muxer -> RX pass over in-memory
// guest memory.
#[cfg(test)]
mod tests {
    use squib_arch::IntId;
    use squib_core::{GuestAddress, SliceGuestMemory};
    use squib_gic::Gic;

    use super::*;
    use crate::queue::VIRTQ_DESC_F_WRITE;

    /// Interrupt controller stub whose every operation succeeds and does
    /// nothing.
    #[derive(Debug, Default)]
    struct StubGic;
    impl Gic for StubGic {
        fn pulse_spi(&self, _: IntId) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
        fn set_spi_level(&self, _: IntId, _: bool) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
        fn save_state(&self) -> Result<Vec<u8>, squib_gic::GicError> {
            Ok(Vec::new())
        }
        fn restore_state(&self, _data: &[u8]) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
    }

    /// Builds an interrupt line backed by [`StubGic`].
    fn line() -> IrqLine {
        let gic: Arc<dyn Gic + Send + Sync> = Arc::new(StubGic);
        IrqLine::new(gic, IntId::from_spi_cell(18).unwrap())
    }

    /// Builds a device config with the given CID and TSI flag; the other
    /// fields are fixed.
    fn config(cid: u64, tsi: bool) -> VsockConfig {
        VsockConfig {
            vsock_id: "vsock0".into(),
            guest_cid: cid,
            uds_path: "/var/run/squib.vsock".into(),
            tsi,
        }
    }

    // A CID of 2 is below `MIN_GUEST_CID`, so construction must fail.
    #[test]
    fn test_should_reject_guest_cid_below_3() {
        let muxer = Arc::new(InMemoryMuxer::default());
        let err = VsockDevice::new(config(2, false), muxer).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // Reading eight bytes at config offset 0 returns the guest CID.
    #[test]
    fn test_should_publish_guest_cid_in_config() {
        let muxer = Arc::new(InMemoryMuxer::default());
        let dev = VsockDevice::new(config(42, false), muxer).unwrap();
        let mut got = [0u8; 8];
        dev.read_config(0, &mut got);
        assert_eq!(u64::from_le_bytes(got), 42);
    }

    // Encoding then decoding a header preserves the fields spot-checked below.
    #[test]
    fn test_should_round_trip_header_through_to_bytes_and_back() {
        let h = VsockHeader {
            src_cid: 3,
            dst_cid: 2,
            src_port: 1024,
            dst_port: 80,
            len: 7,
            type_: TYPE_STREAM,
            op: VsockOp::Request,
            flags: 0,
            buf_alloc: 4096,
            fwd_cnt: 0,
        };
        let bytes = h.to_bytes();
        let parsed = VsockHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.src_cid, 3);
        assert_eq!(parsed.dst_port, 80);
        assert_eq!(parsed.op, VsockOp::Request);
    }

    // End-to-end: a Request placed on the TX queue reaches the muxer, and the
    // auto-generated Response is written into the guest's RX buffer.
    #[test]
    fn test_should_route_tx_packet_to_muxer_and_buffer_replies() {
        let muxer = Arc::new(InMemoryMuxer {
            auto_respond: true,
            ..Default::default()
        });
        let mut dev = VsockDevice::new(config(3, false), muxer.clone()).unwrap();
        let mem = Arc::new(SliceGuestMemory::new(GuestAddress(0x4000_0000), 0x4000));
        // TX queue rings live at the start of the test memory region.
        let q = &mut dev.queues_mut()[TX_QUEUE];
        q.size = 8;
        q.desc_table_addr = GuestAddress(0x4000_0000);
        q.avail_ring_addr = GuestAddress(0x4000_0800);
        q.used_ring_addr = GuestAddress(0x4000_1000);
        q.ready = true;
        let hdr = VsockHeader {
            src_cid: 3,
            dst_cid: VMADDR_CID_HOST,
            src_port: 1024,
            dst_port: 80,
            len: 0,
            type_: TYPE_STREAM,
            op: VsockOp::Request,
            flags: 0,
            buf_alloc: 4096,
            fwd_cnt: 0,
        };
        // The outgoing packet (header only) sits at 0x4000_2000.
        mem.write(GuestAddress(0x4000_2000), &hdr.to_bytes())
            .unwrap();
        // TX descriptor 0, assuming the standard split-ring descriptor layout
        // (addr u64 at +0, len u32 at +8, flags u16 at +12, next u16 at +14):
        // device-readable, HDR_SIZE bytes at 0x4000_2000, no chaining.
        let base = 0x4000_0000u64;
        mem.write_u32_le(GuestAddress(base), 0x4000_2000).unwrap();
        mem.write_u32_le(GuestAddress(base + 4), 0).unwrap();
        mem.write_u32_le(GuestAddress(base + 8), HDR_SIZE as u32)
            .unwrap();
        mem.write_u16_le(GuestAddress(base + 12), 0).unwrap();
        mem.write_u16_le(GuestAddress(base + 14), 0).unwrap();
        // TX avail ring (assumed layout: idx at +2, ring[0] at +4): one entry
        // that points at descriptor 0.
        mem.write_u16_le(GuestAddress(0x4000_0804), 0).unwrap();
        mem.write_u16_le(GuestAddress(0x4000_0802), 1).unwrap();

        // RX queue rings are placed at separate addresses in the same region.
        let q = &mut dev.queues_mut()[RX_QUEUE];
        q.size = 8;
        q.desc_table_addr = GuestAddress(0x4000_0100);
        q.avail_ring_addr = GuestAddress(0x4000_0900);
        q.used_ring_addr = GuestAddress(0x4000_1100);
        q.ready = true;
        // RX descriptor 0: a 64-byte device-writable buffer at 0x4000_3000.
        let rxbase = 0x4000_0100u64;
        mem.write_u32_le(GuestAddress(rxbase), 0x4000_3000).unwrap();
        mem.write_u32_le(GuestAddress(rxbase + 4), 0).unwrap();
        mem.write_u32_le(GuestAddress(rxbase + 8), 64).unwrap();
        mem.write_u16_le(GuestAddress(rxbase + 12), VIRTQ_DESC_F_WRITE)
            .unwrap();
        mem.write_u16_le(GuestAddress(rxbase + 14), 0).unwrap();
        // RX avail ring: one entry that points at descriptor 0.
        mem.write_u16_le(GuestAddress(0x4000_0904), 0).unwrap();
        mem.write_u16_le(GuestAddress(0x4000_0902), 1).unwrap();

        dev.activate(mem.clone(), line()).unwrap();
        // A TX kick drains TX and then attempts RX delivery in the same call.
        dev.process_queue(TX_QUEUE as u16);
        assert_eq!(muxer.tx_log.lock().len(), 1);
        // The reply header must now be in the RX buffer, with ports swapped.
        let mut wire = [0u8; HDR_SIZE];
        mem.read(GuestAddress(0x4000_3000), &mut wire).unwrap();
        let parsed = VsockHeader::from_bytes(&wire).unwrap();
        assert_eq!(parsed.op, VsockOp::Response);
        assert_eq!(parsed.src_port, 80);
        assert_eq!(parsed.dst_port, 1024);
    }

    // Only checks that construction succeeds with `tsi = true` and that the
    // flag is retained; the emitted log record itself is not captured.
    #[test]
    fn test_should_log_tsi_warning_when_tsi_enabled() {
        let muxer = Arc::new(InMemoryMuxer::default());
        let dev = VsockDevice::new(config(3, true), muxer).unwrap();
        assert!(dev.config.tsi);
    }
}