1#![forbid(unsafe_code)]
27#![allow(clippy::disallowed_methods)]
32
33use std::{
34 os::fd::OwnedFd,
35 path::{Path, PathBuf},
36 sync::Arc,
37};
38
39use parking_lot::Mutex;
40use squib_virtio::devices::net::{Frame, NetBackend};
41use thiserror::Error;
42use tokio::{
43 io::{AsyncReadExt, AsyncWriteExt},
44 net::UnixStream,
45 process::{Child, Command},
46 sync::mpsc,
47};
48use tracing::{debug, info, warn};
49
/// Errors produced while validating parameters for, spawning, or wiring up a gvproxy child.
///
/// NOTE(review): several messages embed the underlying error's text (`{0}` / `{source}`) while
/// also exposing it through `Error::source()`; reporters that walk the source chain will print it
/// twice. Consider dropping the `: {…}` suffixes if callers use chain-aware reporting.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum GvproxyError {
    /// Creating the host/child Unix socketpair failed.
    #[error("socketpair failed: {0}")]
    SocketPair(#[source] std::io::Error),

    /// Spawning the gvproxy executable failed.
    #[error("spawn gvproxy {path}: {source}")]
    Spawn {
        /// The binary path that was handed to the process builder.
        path: PathBuf,
        /// The OS error reported by the spawn attempt.
        #[source]
        source: std::io::Error,
    },

    /// `GvproxyParams::binary_path` failed pre-spawn validation.
    #[error("invalid gvproxy binary_path: {reason}")]
    InvalidBinaryPath {
        /// Static description of the check that failed.
        reason: &'static str,
    },

    /// An entry of `GvproxyParams::extra_args`, or the list as a whole, failed validation.
    #[error("invalid gvproxy extra arg #{index}: {reason}")]
    InvalidExtraArg {
        /// Zero-based index of the offending argument. When the list itself is too long this is
        /// the argument-count limit rather than a real index.
        index: usize,
        /// Static description of the check that failed.
        reason: &'static str,
    },

    /// Any other I/O failure during setup. `From<std::io::Error>` is derived, so `?` on an
    /// `io::Result` converts into this variant.
    #[error("gvproxy io error: {0}")]
    Io(#[from] std::io::Error),

    /// A frame exceeded a size cap.
    ///
    /// NOTE(review): not constructed by the code in this file as visible; the reader task logs
    /// and stops on oversized frames instead of surfacing this error.
    #[error("gvproxy frame too large: {got} > {cap}")]
    FrameTooLarge {
        /// Length that was observed.
        got: usize,
        /// Cap that was exceeded.
        cap: usize,
    },
}
98
/// Configuration consumed by `GvproxyBackend::start`.
#[derive(Debug, Clone)]
pub struct GvproxyParams {
    /// Path to the gvproxy executable; validated before spawning.
    pub binary_path: PathBuf,
    /// Extra command-line arguments passed verbatim after the binary path.
    pub extra_args: Vec<String>,
    /// Largest length prefix accepted on frames read back from gvproxy; a larger announced
    /// length stops the reader.
    pub max_frame_bytes: usize,
}

impl GvproxyParams {
    /// Builds parameters for `binary_path` with no extra arguments and a 64 KiB frame cap.
    #[must_use]
    pub fn new(binary_path: impl Into<PathBuf>) -> Self {
        const DEFAULT_MAX_FRAME_BYTES: usize = 64 * 1024;

        let binary_path: PathBuf = binary_path.into();
        Self {
            binary_path,
            extra_args: vec![],
            max_frame_bytes: DEFAULT_MAX_FRAME_BYTES,
        }
    }
}
124
/// Network backend that exchanges Ethernet frames with a gvproxy child over a Unix socketpair.
///
/// Two detached tokio tasks (spawned by `start`) move frames between the socket and the fields
/// below.
#[derive(Debug)]
pub struct GvproxyBackend {
    /// Frames read from gvproxy by the reader task, drained wholesale by `recv`.
    rx_buffer: Arc<Mutex<Vec<Frame>>>,
    /// Sending side of the bounded queue consumed by the writer task; `send` uses `try_send`.
    tx_tx: mpsc::Sender<Frame>,
    /// The gvproxy process; taken and killed in `Drop`.
    child: Arc<Mutex<Option<Child>>>,
}
139
140impl GvproxyBackend {
141 pub fn start(params: &GvproxyParams) -> Result<Self, GvproxyError> {
149 validate_binary_path(¶ms.binary_path)?;
150 validate_extra_args(¶ms.extra_args)?;
151
152 let (host_uds, child_uds) = make_socketpair()?;
153
154 let mut cmd = Command::new(¶ms.binary_path);
155 cmd.kill_on_drop(true);
156 cmd.args(¶ms.extra_args);
157
158 let child_uds_clone = child_uds.try_clone()?;
165 let stdin_fd: OwnedFd = child_uds_clone.into();
166 let stdout_fd: OwnedFd = child_uds.into();
167 cmd.stdin(std::process::Stdio::from(stdin_fd));
168 cmd.stdout(std::process::Stdio::from(stdout_fd));
169 cmd.stderr(std::process::Stdio::null());
170
171 let child = cmd.spawn().map_err(|e| GvproxyError::Spawn {
172 path: params.binary_path.clone(),
173 source: e,
174 })?;
175 info!(path = %params.binary_path.display(), pid = ?child.id(), "gvproxy spawned");
176
177 host_uds.set_nonblocking(true).map_err(GvproxyError::Io)?;
178 let host_uds_async = UnixStream::from_std(host_uds).map_err(GvproxyError::Io)?;
179 let (read_half, write_half) = host_uds_async.into_split();
180
181 let rx_buffer: Arc<Mutex<Vec<Frame>>> = Arc::new(Mutex::new(Vec::new()));
182 let (tx_tx, tx_rx) = mpsc::channel::<Frame>(256);
183
184 let rx_buffer_for_task = Arc::clone(&rx_buffer);
187 let max_bytes = params.max_frame_bytes;
188 tokio::spawn(async move {
189 run_reader(read_half, rx_buffer_for_task, max_bytes).await;
190 });
191
192 tokio::spawn(async move {
195 run_writer(write_half, tx_rx).await;
196 });
197
198 Ok(Self {
199 rx_buffer,
200 tx_tx,
201 child: Arc::new(Mutex::new(Some(child))),
202 })
203 }
204}
205
206impl NetBackend for GvproxyBackend {
207 fn send(&self, frame: &Frame) {
208 if let Err(err) = self.tx_tx.try_send(frame.clone()) {
211 tracing::trace!(error = %err, "gvproxy tx queue full; dropping frame");
212 }
213 }
214 fn recv(&self) -> Vec<Frame> {
215 std::mem::take(&mut *self.rx_buffer.lock())
216 }
217}
218
219impl Drop for GvproxyBackend {
220 fn drop(&mut self) {
221 if let Some(mut child) = self.child.lock().take() {
222 if let Some(pid) = child.id() {
235 crate::sys::kill_pid(pid, libc::SIGKILL);
236 }
237 let _ = child.start_kill();
240 drop(child);
241 }
242 }
243}
244
245const RX_QUEUE_CAP: usize = 256;
247
248async fn run_reader(
249 mut read_half: tokio::net::unix::OwnedReadHalf,
250 rx_buffer: Arc<Mutex<Vec<Frame>>>,
251 max_frame_bytes: usize,
252) {
253 let mut hdr = [0u8; 4];
254 loop {
255 if let Err(err) = read_half.read_exact(&mut hdr).await {
256 if err.kind() != std::io::ErrorKind::UnexpectedEof {
257 debug!(error = %err, "gvproxy reader: header read failed");
258 }
259 return;
260 }
261 let len = u32::from_be_bytes(hdr) as usize;
262 if len > max_frame_bytes {
263 warn!(
264 len,
265 max_frame_bytes, "gvproxy reader: frame too large; closing"
266 );
267 return;
268 }
269 if len == 0 {
270 continue;
271 }
272 let mut buf = vec![0u8; len];
273 if let Err(err) = read_half.read_exact(&mut buf).await {
274 debug!(error = %err, "gvproxy reader: body read failed");
275 return;
276 }
277 let frame = Frame::from_bytes(bytes::Bytes::from(buf));
278 let mut guard = rx_buffer.lock();
279 if guard.len() >= RX_QUEUE_CAP {
280 guard.remove(0);
281 }
282 guard.push(frame);
283 }
284}
285
286async fn run_writer(
287 mut write_half: tokio::net::unix::OwnedWriteHalf,
288 mut tx_rx: mpsc::Receiver<Frame>,
289) {
290 while let Some(frame) = tx_rx.recv().await {
291 let len = u32::try_from(frame.bytes.len()).unwrap_or(u32::MAX);
292 let hdr = len.to_be_bytes();
293 if let Err(err) = write_half.write_all(&hdr).await {
294 debug!(error = %err, "gvproxy writer: header write failed");
295 return;
296 }
297 if let Err(err) = write_half.write_all(&frame.bytes).await {
298 debug!(error = %err, "gvproxy writer: body write failed");
299 return;
300 }
301 }
302}
303
304fn make_socketpair() -> Result<
309 (
310 std::os::unix::net::UnixStream,
311 std::os::unix::net::UnixStream,
312 ),
313 GvproxyError,
314> {
315 std::os::unix::net::UnixStream::pair().map_err(GvproxyError::SocketPair)
316}
317
318fn validate_binary_path(path: &Path) -> Result<(), GvproxyError> {
323 use std::os::unix::ffi::OsStrExt;
324 let bytes = path.as_os_str().as_bytes();
325 if bytes.is_empty() {
326 return Err(GvproxyError::InvalidBinaryPath {
327 reason: "must not be empty",
328 });
329 }
330 if bytes.contains(&0) {
331 return Err(GvproxyError::InvalidBinaryPath {
332 reason: "must not contain NUL bytes",
333 });
334 }
335 if bytes.len() > 1024 {
336 return Err(GvproxyError::InvalidBinaryPath {
337 reason: "exceeds PATH_MAX (1024 bytes)",
338 });
339 }
340 let metadata = std::fs::metadata(path).map_err(|_| GvproxyError::InvalidBinaryPath {
341 reason: "could not stat binary path",
342 })?;
343 if !metadata.is_file() {
344 return Err(GvproxyError::InvalidBinaryPath {
345 reason: "not a regular file",
346 });
347 }
348 Ok(())
349}
350
351fn validate_extra_args(args: &[String]) -> Result<(), GvproxyError> {
356 if args.len() > 32 {
357 return Err(GvproxyError::InvalidExtraArg {
358 index: 32,
359 reason: "more than 32 extra args",
360 });
361 }
362 for (i, a) in args.iter().enumerate() {
363 if a.len() > 512 {
364 return Err(GvproxyError::InvalidExtraArg {
365 index: i,
366 reason: "exceeds 512 bytes",
367 });
368 }
369 if a.contains('\0') {
370 return Err(GvproxyError::InvalidExtraArg {
371 index: i,
372 reason: "contains NUL byte",
373 });
374 }
375 }
376 Ok(())
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// `GvproxyParams::new` fills in the default frame cap and an empty argument list.
    #[test]
    fn test_should_construct_default_params() {
        let p = GvproxyParams::new("/usr/local/libexec/squib/gvproxy");
        assert_eq!(p.max_frame_bytes, 65_536);
        assert!(p.extra_args.is_empty());
    }

    /// An empty path is rejected by the lexical checks, before any filesystem access.
    #[test]
    fn test_should_reject_empty_binary_path() {
        let r = validate_binary_path(Path::new(""));
        assert!(matches!(r, Err(GvproxyError::InvalidBinaryPath { .. })));
    }

    /// An embedded NUL byte is rejected, so a path like `/etc/p\0asswd` never reaches the OS.
    #[test]
    fn test_should_reject_binary_path_with_nul_byte() {
        use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
        let p = Path::new(OsStr::from_bytes(b"/etc/p\0asswd"));
        let r = validate_binary_path(p);
        assert!(matches!(r, Err(GvproxyError::InvalidBinaryPath { .. })));
    }

    /// 40 arguments exceed the 32-argument cap.
    #[test]
    fn test_should_reject_extra_args_above_cap() {
        let args: Vec<String> = (0..40).map(|i| format!("--flag-{i}")).collect();
        let r = validate_extra_args(&args);
        assert!(matches!(r, Err(GvproxyError::InvalidExtraArg { .. })));
    }

    /// The error reports the zero-based index of the argument containing NUL.
    #[test]
    fn test_should_reject_extra_arg_with_nul_byte() {
        let args = vec!["--ok".into(), "with\0nul".into()];
        let r = validate_extra_args(&args);
        assert!(matches!(
            r,
            Err(GvproxyError::InvalidExtraArg { index: 1, .. })
        ));
    }

    /// Exercises `run_writer` and `run_reader` against a socketpair peer standing in for
    /// gvproxy, without spawning the real binary.
    #[tokio::test]
    async fn test_should_round_trip_a_frame_through_simulated_gvproxy_endpoint() {
        let (host_std, child_std) = std::os::unix::net::UnixStream::pair().unwrap();
        host_std.set_nonblocking(true).unwrap();
        child_std.set_nonblocking(true).unwrap();
        let host = UnixStream::from_std(host_std).unwrap();
        let mut child = UnixStream::from_std(child_std).unwrap();
        let (read_half, write_half) = host.into_split();
        let rx_buffer: Arc<Mutex<Vec<Frame>>> = Arc::new(Mutex::new(Vec::new()));
        let (tx_tx, tx_rx) = mpsc::channel::<Frame>(8);

        let rx_for_task = Arc::clone(&rx_buffer);
        tokio::spawn(async move {
            run_reader(read_half, rx_for_task, 65_536).await;
        });
        tokio::spawn(async move {
            run_writer(write_half, tx_rx).await;
        });

        // Host -> gvproxy: the peer should see a big-endian length prefix, then the payload.
        tx_tx.send(Frame::from_slice(b"helloeth0")).await.unwrap();
        let mut hdr = [0u8; 4];
        child.read_exact(&mut hdr).await.unwrap();
        assert_eq!(u32::from_be_bytes(hdr), b"helloeth0".len() as u32);
        let mut buf = vec![0u8; b"helloeth0".len()];
        child.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"helloeth0");

        // gvproxy -> host: write one length-prefixed frame, then poll (50 x 5 ms, ~250 ms max)
        // for the reader task to buffer it.
        let body = b"replyabc";
        let mut hdr = (body.len() as u32).to_be_bytes().to_vec();
        hdr.extend_from_slice(body);
        child.write_all(&hdr).await.unwrap();
        for _ in 0..50 {
            if !rx_buffer.lock().is_empty() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        let frames = std::mem::take(&mut *rx_buffer.lock());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].bytes.as_ref(), body);
    }
}