smolvm_network/
frame_stream.rs1use crate::queues::NetworkFrameQueues;
42use std::io::{self, Read, Write};
43use std::net::Shutdown;
44use std::os::fd::{AsRawFd, FromRawFd, RawFd};
45use std::os::unix::net::UnixStream;
46use std::sync::Arc;
47use std::thread::{self, JoinHandle};
48
49const FRAME_HEADER_LEN: usize = 4;
50const SOCKET_SENDBUF_BYTES: libc::c_int = 16 * 1024 * 1024;
51const MAX_FRAME_LEN: usize = 64 * 1024;
52
53pub struct FrameStreamBridge {
60 control: UnixStream,
61 queues: Arc<NetworkFrameQueues>,
62 reader_handle: Option<JoinHandle<()>>,
63 writer_handle: Option<JoinHandle<()>>,
64}
65
66pub fn start_frame_stream_bridge(
68 fd: RawFd,
69 queues: Arc<NetworkFrameQueues>,
70) -> io::Result<FrameStreamBridge> {
71 let stream = unsafe { UnixStream::from_raw_fd(fd) };
73 set_socket_send_buffer(&stream)?;
74 let control = stream.try_clone()?;
76 let reader = stream.try_clone()?;
77 let writer = stream;
78
79 let reader_handle = thread::Builder::new()
80 .name("smolvm-net-reader".into())
81 .spawn({
82 let queues = queues.clone();
83 move || run_reader(reader, queues)
84 })?;
85
86 let writer_queues = queues.clone();
87 let writer_handle = thread::Builder::new()
88 .name("smolvm-net-writer".into())
89 .spawn(move || run_writer(writer, writer_queues))?;
90
91 Ok(FrameStreamBridge {
92 control,
93 queues,
94 reader_handle: Some(reader_handle),
95 writer_handle: Some(writer_handle),
96 })
97}
98
99impl Drop for FrameStreamBridge {
100 fn drop(&mut self) {
106 self.queues.begin_shutdown();
107 let _ = self.control.shutdown(Shutdown::Both);
108
109 if let Some(handle) = self.reader_handle.take() {
110 let _ = handle.join();
111 }
112 if let Some(handle) = self.writer_handle.take() {
113 let _ = handle.join();
114 }
115 }
116}
117
118fn run_reader(mut reader: UnixStream, queues: Arc<NetworkFrameQueues>) {
119 loop {
122 match read_frame(&mut reader) {
123 Ok(frame) => {
124 if queues.guest_to_host.push(frame).is_ok() {
125 queues.guest_wake.wake();
126 } else {
127 tracing::warn!("dropping guest ethernet frame because the host queue is full");
128 }
129 }
130 Err(err) => {
131 queues.begin_shutdown();
132 tracing::debug!(error = %err, "virtio-net reader thread stopped");
133 return;
134 }
135 }
136 }
137}
138
139fn run_writer(mut writer: UnixStream, queues: Arc<NetworkFrameQueues>) {
140 loop {
143 if queues.is_shutting_down() && queues.host_to_guest.is_empty() {
144 return;
145 }
146 match queues.host_wake.wait(None) {
147 Ok(true) => queues.host_wake.drain(),
148 Ok(false) => continue,
149 Err(err) => {
150 queues.begin_shutdown();
151 tracing::debug!(error = %err, "virtio-net writer wake pipe failed");
152 return;
153 }
154 }
155
156 while let Some(frame) = queues.host_to_guest.pop() {
157 if let Err(err) = write_frame(&mut writer, &frame) {
158 queues.begin_shutdown();
159 tracing::debug!(error = %err, "virtio-net writer thread stopped");
160 return;
161 }
162 }
163 }
164}
165
166pub(crate) fn read_frame<R: Read>(reader: &mut R) -> io::Result<Vec<u8>> {
186 let mut header = [0u8; FRAME_HEADER_LEN];
187 reader.read_exact(&mut header)?;
188 let frame_len = u32::from_be_bytes(header) as usize;
189
190 if frame_len == 0 || frame_len > MAX_FRAME_LEN {
191 return Err(io::Error::new(
192 io::ErrorKind::InvalidData,
193 format!("invalid ethernet frame length: {frame_len}"),
194 ));
195 }
196
197 let mut frame = vec![0u8; frame_len];
198 reader.read_exact(&mut frame)?;
199 Ok(frame)
200}
201
202pub(crate) fn write_frame<W: Write>(writer: &mut W, frame: &[u8]) -> io::Result<()> {
216 if frame.is_empty() || frame.len() > MAX_FRAME_LEN {
217 return Err(io::Error::new(
218 io::ErrorKind::InvalidInput,
219 format!("invalid ethernet frame length: {}", frame.len()),
220 ));
221 }
222
223 let header = (frame.len() as u32).to_be_bytes();
224 write_all(writer, &header)?;
225 write_all(writer, frame)?;
226 writer.flush()
227}
228
/// Writes all of `buf` to `writer`, looping over short writes.
///
/// Like [`Write::write_all`], a write that fails with [`io::ErrorKind::Interrupted`] (e.g. a
/// syscall interrupted by a signal) is retried rather than reported. This matches
/// `read_frame`, whose `Read::read_exact` also retries on `Interrupted`.
///
/// # Errors
/// - `WriteZero` if the writer accepts zero bytes while data remains.
/// - Any other error returned by `writer.write`.
fn write_all<W: Write>(writer: &mut W, mut buf: &[u8]) -> io::Result<()> {
    while !buf.is_empty() {
        match writer.write(buf) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "short write while sending ethernet frame",
                ));
            }
            Ok(written) => buf = &buf[written..],
            // EINTR is transient. Retry, as std's `write_all` does, instead of returning an error
            // that makes the writer thread shut the whole bridge down.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
    Ok(())
}
245
246fn set_socket_send_buffer(stream: &UnixStream) -> io::Result<()> {
251 let size = SOCKET_SENDBUF_BYTES;
253 let result = unsafe {
255 libc::setsockopt(
256 stream.as_raw_fd(),
258 libc::SOL_SOCKET,
260 libc::SO_SNDBUF,
262 (&size as *const libc::c_int).cast(),
264 std::mem::size_of_val(&size) as libc::socklen_t,
267 )
268 };
269 if result < 0 {
270 tracing::warn!(
271 error = %io::Error::last_os_error(),
272 "failed to increase SO_SNDBUF for virtio-net unixstream"
273 );
274 return Ok(());
275 }
276 Ok(())
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// `Write` implementation that accepts at most `chunk_size` bytes per call, used to exercise
    /// the short-write loop in `write_frame`.
    struct PartialWriter {
        written: Vec<u8>,
        chunk_size: usize,
    }

    impl Write for PartialWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let take = buf.len().min(self.chunk_size);
            self.written.extend_from_slice(&buf[..take]);
            Ok(take)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn write_frame_handles_partial_writes() {
        let mut writer = PartialWriter {
            written: Vec::new(),
            chunk_size: 3,
        };
        write_frame(&mut writer, &[1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(writer.written[..4], [0, 0, 0, 6]);
        assert_eq!(writer.written[4..], [1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn write_frame_rejects_empty_and_oversized_frames() {
        let mut out: Vec<u8> = Vec::new();
        let err = write_frame(&mut out, &[]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        let err = write_frame(&mut out, &vec![0u8; MAX_FRAME_LEN + 1]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(out.is_empty(), "rejected frames must not emit a partial header");
    }

    #[test]
    fn read_frame_decodes_length_prefix() {
        let mut input = std::io::Cursor::new(vec![0, 0, 0, 3, 7, 8, 9]);
        assert_eq!(read_frame(&mut input).unwrap(), vec![7, 8, 9]);
    }

    #[test]
    fn read_frame_rejects_zero_and_oversized_lengths() {
        let mut zero = std::io::Cursor::new(vec![0u8, 0, 0, 0]);
        assert_eq!(read_frame(&mut zero).unwrap_err().kind(), io::ErrorKind::InvalidData);

        let too_long = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
        let mut oversized = std::io::Cursor::new(too_long.to_vec());
        assert_eq!(read_frame(&mut oversized).unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_frame_reports_truncated_payload() {
        let mut input = std::io::Cursor::new(vec![0u8, 0, 0, 5, 1, 2]);
        assert_eq!(read_frame(&mut input).unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn unix_stream_round_trip_multiple_frames() {
        let (mut left, mut right) = UnixStream::pair().unwrap();
        write_frame(&mut left, &[1, 2, 3]).unwrap();
        write_frame(&mut left, &[4, 5]).unwrap();

        assert_eq!(read_frame(&mut right).unwrap(), vec![1, 2, 3]);
        assert_eq!(read_frame(&mut right).unwrap(), vec![4, 5]);
    }

    #[test]
    fn round_trip_at_max_frame_len() {
        // In-memory buffer rather than a socket pair: a single thread writing MAX_FRAME_LEN
        // bytes into a socket could block if the socket buffer is smaller than that.
        let frame: Vec<u8> = (0..MAX_FRAME_LEN).map(|i| i as u8).collect();
        let mut encoded: Vec<u8> = Vec::new();
        write_frame(&mut encoded, &frame).unwrap();
        assert_eq!(encoded.len(), FRAME_HEADER_LEN + MAX_FRAME_LEN);

        let mut cursor = std::io::Cursor::new(encoded);
        assert_eq!(read_frame(&mut cursor).unwrap(), frame);
    }
}