1use super::{
2 protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3 VirtIOSocketDevice, VirtIOSocketManager, VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{
6 transport::{DeviceTransport, InterruptStatus, Transport},
7 DeviceHal, Hal, Result,
8};
9use alloc::{boxed::Box, vec::Vec};
10use core::cmp::min;
11use core::convert::TryInto;
12use core::hint::spin_loop;
13use log::debug;
14use zerocopy::FromZeros;
15
/// Receive-buffer size, in bytes, given to each connection when a manager is created with
/// `new` rather than `new_with_capacity` (1 KiB).
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1 << 10;
17
18pub struct VsockConnectionManager<
50 H: Hal,
51 T: Transport,
52 const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
53>(VsockConnectionManagerCommon<VirtIOSocket<H, T, RX_BUFFER_SIZE>>);
54
55pub struct VsockDeviceConnectionManager<H: DeviceHal, T: DeviceTransport>(
57 VsockConnectionManagerCommon<VirtIOSocketDevice<H, T>>,
58);
59
/// Interface shared by the driver-side and device-side vsock connection managers.
///
/// Connections are identified by the peer's address plus the local port.
pub trait VsockManager {
    /// Requests a new outgoing connection from local port `src_port` to `destination`.
    ///
    /// NOTE(review): the device-side implementation panics (`unreachable!`) instead of
    /// connecting.
    fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result;

    /// Sends `buffer` on the connection between local port `src_port` and `dest`.
    fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result;

    /// Sends the peer a credit update for the connection between `src_port` and `peer`.
    fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result;

    /// Forcibly closes the connection between `src_port` and `dest` and stops tracking it.
    fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result;

    /// Starts accepting incoming connection requests addressed to local `port`.
    fn listen(&mut self, port: u32);

    /// Handles at most one pending event, returning it if the caller should see it.
    fn poll(&mut self) -> Result<Option<VsockEvent>>;

    /// Returns the local context ID.
    fn local_cid(&self) -> u64;

    /// Asks the peer to shut down the connection between `src_port` and `dest`.
    ///
    /// The connection remains tracked afterwards.
    fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result;

    /// Copies buffered received data for the connection between `src_port` and `peer` into
    /// `buffer`, returning the number of bytes copied.
    fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize>;

    /// Acknowledges an interrupt by delegating to the underlying driver's `ack_interrupt`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a valid `Self`. NOTE(review): any further aliasing requirements
    /// come from the wrapped driver's `ack_interrupt` contract, which is not visible here —
    /// confirm before relying on concurrent access.
    ///
    /// The device-side implementation panics.
    unsafe fn ack_interrupt(ptr: *mut Self) -> InterruptStatus;
}
114
115struct VsockConnectionManagerCommon<M: VirtIOSocketManager> {
116 driver: M,
117 per_connection_buffer_capacity: u32,
118 connections: Vec<Connection>,
119 listening_ports: Vec<u32>,
120}
121
/// Bookkeeping for one tracked vsock connection.
#[derive(Debug)]
struct Connection {
    /// Addressing and credit state handed to the driver for this connection.
    info: ConnectionInfo,
    /// Bytes received from the peer that the client has not yet read with `recv`.
    buffer: RingBuffer,
    /// Set when the peer disconnected while `buffer` still held unread data; `recv` finishes
    /// tearing the connection down once the buffer has been drained.
    peer_requested_shutdown: bool,
}
130
131impl Connection {
132 fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
133 let mut info = ConnectionInfo::new(peer, local_port);
134 info.buf_alloc = buffer_capacity;
135 Self {
136 info,
137 buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
138 peer_requested_shutdown: false,
139 }
140 }
141}
142
143impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
144 VsockConnectionManager<H, T, RX_BUFFER_SIZE>
145{
146 pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
148 Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
149 }
150
151 pub fn new_with_capacity(
154 driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
155 per_connection_buffer_capacity: u32,
156 ) -> Self {
157 Self(VsockConnectionManagerCommon {
158 driver,
159 connections: Vec::new(),
160 listening_ports: Vec::new(),
161 per_connection_buffer_capacity,
162 })
163 }
164
165 pub fn guest_cid(&self) -> u64 {
167 self.0.local_cid()
168 }
169
170 pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
176 if self.0.connections.iter().any(|connection| {
177 connection.info.dst == destination && connection.info.src_port == src_port
178 }) {
179 return Err(SocketError::ConnectionExists.into());
180 }
181
182 let new_connection =
183 Connection::new(destination, src_port, self.0.per_connection_buffer_capacity);
184
185 self.0.driver.connect(&new_connection.info)?;
186 debug!("Connection requested: {:?}", new_connection.info);
187 self.0.connections.push(new_connection);
188 Ok(())
189 }
190 pub fn listen(&mut self, port: u32) {
192 self.0.listen(port)
193 }
194
195 pub fn unlisten(&mut self, port: u32) {
197 self.0.unlisten(port)
198 }
199
200 pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
202 self.0.send(destination, src_port, buffer)
203 }
204
205 pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
207 self.0.poll()
208 }
209
210 pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
212 self.0.recv(peer, src_port, buffer)
213 }
214
215 pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
220 self.0.recv_buffer_available_bytes(peer, src_port)
221 }
222
223 pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
225 self.0.update_credit(peer, src_port)
226 }
227
228 pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
230 self.0.wait_for_event()
231 }
232
233 pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
240 self.0.shutdown(destination, src_port)
241 }
242
243 pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
245 self.0.force_close(destination, src_port)
246 }
247}
248
249impl<H: DeviceHal, T: DeviceTransport> VsockDeviceConnectionManager<H, T> {
250 pub fn new(driver: VirtIOSocketDevice<H, T>) -> Self {
252 Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
253 }
254
255 pub fn new_with_capacity(
258 driver: VirtIOSocketDevice<H, T>,
259 per_connection_buffer_capacity: u32,
260 ) -> Self {
261 Self(VsockConnectionManagerCommon {
262 driver,
263 connections: Vec::new(),
264 listening_ports: Vec::new(),
265 per_connection_buffer_capacity,
266 })
267 }
268
269 pub fn listen(&mut self, port: u32) {
271 self.0.listen(port)
272 }
273
274 pub fn unlisten(&mut self, port: u32) {
276 self.0.unlisten(port)
277 }
278
279 pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
281 self.0.send(destination, src_port, buffer)
282 }
283
284 pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
286 self.0.poll()
287 }
288
289 pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
291 self.0.recv(peer, src_port, buffer)
292 }
293
294 pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
299 self.0.recv_buffer_available_bytes(peer, src_port)
300 }
301
302 pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
304 self.0.update_credit(peer, src_port)
305 }
306
307 pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
309 self.0.wait_for_event()
310 }
311
312 pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
319 self.0.shutdown(destination, src_port)
320 }
321
322 pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
324 self.0.force_close(destination, src_port)
325 }
326}
327
328impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VsockManager
329 for VsockConnectionManager<H, T, RX_BUFFER_SIZE>
330{
331 fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
332 Self::connect(self, destination, src_port)
333 }
334 fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
335 Self::send(self, dest, src_port, buffer)
336 }
337 fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
338 Self::update_credit(self, peer, src_port)
339 }
340 fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result {
341 Self::force_close(self, dest, src_port)
342 }
343 fn listen(&mut self, port: u32) {
344 Self::listen(self, port)
345 }
346 fn poll(&mut self) -> Result<Option<VsockEvent>> {
347 Self::poll(self)
348 }
349 fn local_cid(&self) -> u64 {
350 self.0.local_cid()
351 }
352 fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result {
353 Self::shutdown(self, dest, src_port)
354 }
355 fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
356 Self::recv(self, peer, src_port, buffer)
357 }
358 unsafe fn ack_interrupt(ptr: *mut Self) -> InterruptStatus {
359 let vsock_driver_ptr = unsafe { &raw mut (*ptr).0.driver };
362 unsafe { VirtIOSocket::<H, T, RX_BUFFER_SIZE>::ack_interrupt(vsock_driver_ptr) }
364 }
365}
366
367impl<H: DeviceHal, T: DeviceTransport> VsockManager for VsockDeviceConnectionManager<H, T> {
368 fn connect(&mut self, _destination: VsockAddr, _src_port: u32) -> Result {
369 unreachable!("vsock devices should not make outgoing connections")
370 }
371 fn send(&mut self, dest: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
372 Self::send(self, dest, src_port, buffer)
373 }
374 fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
375 Self::update_credit(self, peer, src_port)
376 }
377 fn force_close(&mut self, dest: VsockAddr, src_port: u32) -> Result {
378 Self::force_close(self, dest, src_port)
379 }
380 fn listen(&mut self, port: u32) {
381 Self::listen(self, port)
382 }
383 fn poll(&mut self) -> Result<Option<VsockEvent>> {
384 Self::poll(self)
385 }
386 fn local_cid(&self) -> u64 {
387 self.0.local_cid()
388 }
389 fn shutdown(&mut self, dest: VsockAddr, src_port: u32) -> Result {
390 Self::shutdown(self, dest, src_port)
391 }
392 fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
393 Self::recv(self, peer, src_port, buffer)
394 }
395 unsafe fn ack_interrupt(_ptr: *mut Self) -> InterruptStatus {
396 panic!("vsock devices cannot acknowledge interrupts")
397 }
398}
399
400impl<M: VirtIOSocketManager> VsockConnectionManagerCommon<M> {
401 pub fn listen(&mut self, port: u32) {
403 if !self.listening_ports.contains(&port) {
404 self.listening_ports.push(port);
405 }
406 }
407
408 pub fn unlisten(&mut self, port: u32) {
410 self.listening_ports.retain(|p| *p != port);
411 }
412
413 pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
415 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
416
417 self.driver.send(buffer, &mut connection.info)
418 }
419
420 pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
422 let local_cid = self.driver.local_cid();
423 let connections = &mut self.connections;
424 let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
425
426 let result = self.driver.poll(|event, body| {
427 let connection = get_connection_for_event(connections, &event, local_cid);
428
429 let connection = if let Some((_, connection)) = connection {
432 connection
433 } else if let VsockEventType::ConnectionRequest = event.event_type {
434 if connection.is_some() || event.destination.cid != local_cid {
436 return Ok(None);
437 }
438 connections.push(Connection::new(
441 event.source,
442 event.destination.port,
443 per_connection_buffer_capacity,
444 ));
445 connections.last_mut().unwrap()
446 } else {
447 return Ok(None);
448 };
449
450 connection.info.update_for_event(&event);
452
453 if let VsockEventType::Received { length } = event.event_type {
454 if !connection.buffer.add(body) {
456 return Err(SocketError::OutputBufferTooShort(length).into());
457 }
458 }
459
460 Ok(Some(event))
461 })?;
462
463 let Some(event) = result else {
464 return Ok(None);
465 };
466
467 let (connection_index, connection) =
469 get_connection_for_event(connections, &event, local_cid).unwrap();
470
471 match event.event_type {
472 VsockEventType::ConnectionRequest => {
473 if self.listening_ports.contains(&event.destination.port) {
474 self.driver.accept(&connection.info)?;
475 } else {
476 self.driver.force_close(&connection.info)?;
478 self.connections.swap_remove(connection_index);
479
480 return Ok(None);
482 }
483 }
484 VsockEventType::Connected => {}
485 VsockEventType::Disconnected { reason } => {
486 if connection.buffer.is_empty() {
488 if reason == DisconnectReason::Shutdown {
489 self.driver.force_close(&connection.info)?;
490 }
491 self.connections.swap_remove(connection_index);
492 } else {
493 connection.peer_requested_shutdown = true;
494 }
495 }
496 VsockEventType::Received { .. } => {
497 }
499 VsockEventType::CreditRequest => {
500 self.driver.credit_update(&connection.info)?;
502 return Ok(None);
504 }
505 VsockEventType::CreditUpdate => {}
506 }
507
508 Ok(Some(event))
509 }
510
511 pub fn local_cid(&self) -> u64 {
513 self.driver.local_cid()
514 }
515
516 pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
518 let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
519
520 let bytes_read = connection.buffer.drain(buffer);
522
523 connection.info.done_forwarding(bytes_read);
524
525 if connection.peer_requested_shutdown && connection.buffer.is_empty() {
528 self.driver.force_close(&connection.info)?;
529 self.connections.swap_remove(connection_index);
530 }
531
532 Ok(bytes_read)
533 }
534
535 pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
540 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
541 Ok(connection.buffer.used())
542 }
543
544 pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
546 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
547 self.driver.credit_update(&connection.info)
548 }
549
550 pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
552 loop {
553 if let Some(event) = self.poll()? {
554 return Ok(event);
555 } else {
556 spin_loop();
557 }
558 }
559 }
560
561 pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
568 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
569
570 self.driver.shutdown(&connection.info)
571 }
572
573 pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
575 let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
576
577 self.driver.force_close(&connection.info)?;
578
579 self.connections.swap_remove(index);
580 Ok(())
581 }
582}
583
584fn get_connection(
589 connections: &mut [Connection],
590 peer: VsockAddr,
591 local_port: u32,
592) -> core::result::Result<(usize, &mut Connection), SocketError> {
593 connections
594 .iter_mut()
595 .enumerate()
596 .find(|(_, connection)| {
597 connection.info.dst == peer && connection.info.src_port == local_port
598 })
599 .ok_or(SocketError::NotConnected)
600}
601
602fn get_connection_for_event<'a>(
604 connections: &'a mut [Connection],
605 event: &VsockEvent,
606 local_cid: u64,
607) -> Option<(usize, &'a mut Connection)> {
608 connections
609 .iter_mut()
610 .enumerate()
611 .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
612}
613
614#[derive(Debug)]
615struct RingBuffer {
616 buffer: Box<[u8]>,
617 used: usize,
619 start: usize,
621}
622
623impl RingBuffer {
624 pub fn new(capacity: usize) -> Self {
625 Self {
626 buffer: FromZeros::new_box_zeroed_with_elems(capacity).unwrap(),
627 used: 0,
628 start: 0,
629 }
630 }
631
632 pub fn used(&self) -> usize {
634 self.used
635 }
636
637 pub fn is_empty(&self) -> bool {
639 self.used == 0
640 }
641
642 pub fn free(&self) -> usize {
644 self.buffer.len() - self.used
645 }
646
647 pub fn add(&mut self, bytes: &[u8]) -> bool {
651 if bytes.len() > self.free() {
652 return false;
653 }
654
655 let first_available = (self.start + self.used) % self.buffer.len();
657 let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
660 self.buffer[first_available..first_available + copy_length_before_wraparound]
661 .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
662 if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
663 self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
664 }
665 self.used += bytes.len();
666
667 true
668 }
669
670 pub fn drain(&mut self, out: &mut [u8]) -> usize {
673 let bytes_read = min(self.used, out.len());
674
675 let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
677 let read_after_wraparound = bytes_read
679 .checked_sub(read_before_wraparound)
680 .unwrap_or_default();
681
682 out[0..read_before_wraparound]
683 .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
684 out[read_before_wraparound..bytes_read]
685 .copy_from_slice(&self.buffer[0..read_after_wraparound]);
686
687 self.used -= bytes_read;
688 self.start = (self.start + bytes_read) % self.buffer.len();
689
690 bytes_read
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use core::mem::size_of;
    use std::{sync::Mutex, thread};
    use zerocopy::{FromBytes, IntoBytes};

    /// Guest-initiated connection: connect, exchange one message in each direction, then
    /// shut down. A spawned thread plays the host/device side through the fake transport.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Simulated host side.
        let handle = thread::spawn(move || {
            // Expect the guest's connection request, advertising the default 1024-byte buffer.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept the connection.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a data packet: header followed by the guest's payload.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice())
                    .unwrap()
                    .0,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Send a data packet back to the guest (header + payload in one buffer).
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice())
            .unwrap();
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect the guest's full shutdown; by now its fwd_cnt covers the bytes it read.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    /// Host-initiated connections: a request to a port nobody listens on is reset without
    /// surfacing an event, while a request to the listening port is accepted and reported.
    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        socket.listen(guest_port);

        // Simulated host side.
        let handle = thread::spawn(move || {
            // First request targets a port the guest is not listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest should answer with a reset.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Second request targets the listening port.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest should accept it with a response.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Only the accepted request is reported; `poll` swallows the rejected one.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }
}