1use std::collections::HashMap;
12use std::io;
13use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::Instant;
16
17use crate::config::VmSocketEndpoint;
18
/// Smallest datagram accepted as a frame: an Ethernet header made of the
/// destination MAC (6 bytes), source MAC (6 bytes) and EtherType (2 bytes).
const MIN_FRAME_SIZE: usize = 6 + 6 + 2;
/// Largest frame the switch forwards: header plus a 1500-byte payload plus 4 more
/// bytes (presumably room for an FCS or an 802.1Q tag — confirm against the VMM).
const MAX_FRAME_SIZE: usize = MIN_FRAME_SIZE + 1500 + 4;
23
/// A 48-bit MAC address, stored as the six octets in wire order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MacAddress([u8; 6]);

impl MacAddress {
    /// Builds an address from the first six bytes of `bytes`.
    ///
    /// Bytes past the sixth are ignored; returns `None` when fewer than six are given.
    pub fn from_bytes(bytes: &[u8]) -> Option<MacAddress> {
        let octets = bytes.get(..6)?;
        let mut addr = [0u8; 6];
        addr.copy_from_slice(octets);
        Some(MacAddress(addr))
    }

    /// Returns `true` only for the all-ones address `ff:ff:ff:ff:ff:ff`.
    pub fn is_broadcast(&self) -> bool {
        self.0.iter().all(|&octet| octet == 0xff)
    }

    /// Returns `true` when the least-significant bit of the first octet (the
    /// group bit) is set. The broadcast address has it set, so it also counts.
    pub fn is_multicast(&self) -> bool {
        (self.0[0] & 0x01) == 0x01
    }
}

impl std::fmt::Display for MacAddress {
    /// Renders six lowercase, zero-padded hex octets joined by colons,
    /// e.g. `02:42:ac:11:00:02`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (position, octet) in self.0.iter().enumerate() {
            if position != 0 {
                f.write_str(":")?;
            }
            write!(f, "{:02x}", octet)?;
        }
        Ok(())
    }
}
56
/// The switch-side end of one port's socket pair; the descriptor is closed when
/// the port is dropped.
#[derive(Debug)]
struct SwitchPort {
    /// Owned descriptor the forwarding thread polls, receives frames from and
    /// sends frames to.
    fd: OwnedFd,
}

impl AsRawFd for SwitchPort {
    /// Exposes the underlying descriptor without giving up ownership.
    fn as_raw_fd(&self) -> RawFd {
        let SwitchPort { fd } = self;
        fd.as_raw_fd()
    }
}
69
/// Per-network MAC learning table: source MAC -> (index of the port, within the
/// network's port list, the address was last seen on; time it was last seen).
type MacTable = HashMap<MacAddress, (usize, Instant)>;
72
/// Userspace learning Ethernet switch joining ports that share a network id.
///
/// Each port is one end of a datagram socket pair. A background thread started
/// by `start` forwards frames between ports of the same network only.
pub struct NetworkSwitch {
    /// Network id -> switch-side ends of its ports. A port's index in the `Vec`
    /// is the port number stored in the MAC table.
    networks: Arc<Mutex<HashMap<String, Vec<SwitchPort>>>>,
    /// Network id -> learned MAC table for that network.
    mac_tables: Arc<RwLock<HashMap<String, MacTable>>>,
    /// `true` while the forwarding thread should keep running; shared with it.
    running: Arc<std::sync::atomic::AtomicBool>,
    /// Join handle of the forwarding thread, if one has been started.
    worker: Mutex<Option<std::thread::JoinHandle<()>>>,
}
80
81impl NetworkSwitch {
82 pub fn new() -> Self {
83 Self {
84 networks: Arc::new(Mutex::new(HashMap::new())),
85 mac_tables: Arc::new(RwLock::new(HashMap::new())),
86 running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
87 worker: Mutex::new(None),
88 }
89 }
90
91 pub fn add_port(&self, network_id: &str, _label: &str) -> io::Result<VmSocketEndpoint> {
97 let (switch_fd, vm_fd) = create_socketpair()?;
98
99 let port = SwitchPort { fd: switch_fd };
100
101 let mut networks = self
102 .networks
103 .lock()
104 .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
105 networks
106 .entry(network_id.to_string())
107 .or_default()
108 .push(port);
109
110 let mut mac_tables = self
111 .mac_tables
112 .write()
113 .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
114 mac_tables.entry(network_id.to_string()).or_default();
115
116 Ok(VmSocketEndpoint::new(vm_fd))
117 }
118
119 pub fn start(&self) -> io::Result<()> {
124 use std::sync::atomic::Ordering;
125
126 if self.running.load(Ordering::Relaxed) {
127 return Ok(());
128 }
129 self.running.store(true, Ordering::SeqCst);
130
131 let networks = Arc::clone(&self.networks);
132 let mac_tables = Arc::clone(&self.mac_tables);
133 let running = Arc::clone(&self.running);
134
135 let handle = std::thread::Builder::new()
136 .name("network-switch".to_string())
137 .spawn(move || {
138 forwarding_loop(&networks, &mac_tables, &running);
139 })?;
140
141 let mut worker = self
142 .worker
143 .lock()
144 .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
145 *worker = Some(handle);
146
147 Ok(())
148 }
149
150 pub fn stop(&self) {
152 self.running
153 .store(false, std::sync::atomic::Ordering::SeqCst);
154 match self.worker.lock() {
155 Ok(mut worker) => {
156 if let Some(handle) = worker.take() {
157 if let Err(e) = handle.join() {
158 tracing::error!("network switch thread panicked: {:?}", e);
159 }
160 }
161 }
162 Err(e) => {
163 tracing::error!("network switch worker lock poisoned during stop: {}", e);
164 }
165 }
166 }
167}
168
169impl Default for NetworkSwitch {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175impl Drop for NetworkSwitch {
176 fn drop(&mut self) {
177 self.stop();
178 }
180}
181
182fn create_socketpair() -> io::Result<(OwnedFd, OwnedFd)> {
184 let mut fds = [0i32; 2];
185 let ret = unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, fds.as_mut_ptr()) };
187 if ret != 0 {
188 return Err(io::Error::last_os_error());
189 }
190
191 let switch_fd = unsafe { OwnedFd::from_raw_fd(fds[0]) };
194 let vm_fd = unsafe { OwnedFd::from_raw_fd(fds[1]) };
195
196 unsafe {
199 let flags = libc::fcntl(switch_fd.as_raw_fd(), libc::F_GETFL);
200 if flags == -1 {
201 return Err(io::Error::last_os_error());
202 }
203 if libc::fcntl(
204 switch_fd.as_raw_fd(),
205 libc::F_SETFL,
206 flags | libc::O_NONBLOCK,
207 ) == -1
208 {
209 return Err(io::Error::last_os_error());
210 }
211 }
212
213 Ok((switch_fd, vm_fd))
214}
215
216fn forwarding_loop(
218 networks: &Mutex<HashMap<String, Vec<SwitchPort>>>,
219 mac_tables: &RwLock<HashMap<String, MacTable>>,
220 running: &std::sync::atomic::AtomicBool,
221) {
222 use std::sync::atomic::Ordering;
223
224 const MAC_AGE_INTERVAL_SECS: u64 = 30;
225 const MAC_ENTRY_LIFETIME_SECS: u64 = 120;
226
227 let mut buf = [0u8; MAX_FRAME_SIZE];
228 let mut last_aged = Instant::now();
229
230 while running.load(Ordering::Relaxed) {
231 if last_aged.elapsed().as_secs() >= MAC_AGE_INTERVAL_SECS {
233 if let Ok(mut tables) = mac_tables.write() {
234 for table in tables.values_mut() {
235 table.retain(|_mac, (_port, ts)| {
236 ts.elapsed().as_secs() < MAC_ENTRY_LIFETIME_SECS
237 });
238 }
239 }
240 last_aged = Instant::now();
241 }
242
243 let nets = match networks.lock() {
245 Ok(n) => n,
246 Err(_) => break, };
248 let mut pollfds: Vec<libc::pollfd> = Vec::new();
249 let mut fd_map: Vec<(String, usize)> = Vec::new();
250 let mut port_fds: HashMap<String, Vec<RawFd>> = HashMap::new();
251
252 for (net_id, ports) in nets.iter() {
253 let fds: Vec<RawFd> = ports.iter().map(|p| p.fd.as_raw_fd()).collect();
254 port_fds.insert(net_id.clone(), fds);
255 for (idx, port) in ports.iter().enumerate() {
256 pollfds.push(libc::pollfd {
257 fd: port.fd.as_raw_fd(),
258 events: libc::POLLIN,
259 revents: 0,
260 });
261 fd_map.push((net_id.clone(), idx));
262 }
263 }
264 drop(nets);
265
266 if pollfds.is_empty() {
267 std::thread::sleep(std::time::Duration::from_millis(50));
268 continue;
269 }
270
271 let ready = unsafe { libc::poll(pollfds.as_mut_ptr(), pollfds.len() as libc::nfds_t, 50) };
273
274 if ready <= 0 {
275 continue;
276 }
277
278 for (i, pfd) in pollfds.iter().enumerate() {
279 if pfd.revents & libc::POLLIN == 0 {
280 continue;
281 }
282
283 let (ref net_id, src_port_idx) = fd_map[i];
284
285 let n =
288 unsafe { libc::recv(pfd.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) };
289
290 if n < MIN_FRAME_SIZE as isize {
291 continue;
292 }
293 let frame = &buf[..n as usize];
294
295 let dst_mac = match MacAddress::from_bytes(&frame[0..6]) {
297 Some(m) => m,
298 None => continue,
299 };
300 let src_mac = match MacAddress::from_bytes(&frame[6..12]) {
301 Some(m) => m,
302 None => continue,
303 };
304
305 if let Ok(mut tables) = mac_tables.write() {
307 if let Some(table) = tables.get_mut(net_id.as_str()) {
308 table.insert(src_mac, (src_port_idx, Instant::now()));
309 }
310 }
311
312 let fds = match port_fds.get(net_id.as_str()) {
314 Some(f) => f,
315 None => continue,
316 };
317
318 if dst_mac.is_broadcast() || dst_mac.is_multicast() {
319 for (idx, &fd) in fds.iter().enumerate() {
321 if idx == src_port_idx {
322 continue;
323 }
324 send_frame(fd, frame);
325 }
326 } else {
327 let dst_port = match mac_tables.read() {
329 Ok(tables) => tables
330 .get(net_id.as_str())
331 .and_then(|t| t.get(&dst_mac))
332 .map(|(port_idx, _ts)| *port_idx),
333 Err(_) => {
334 tracing::error!("MAC table read lock poisoned, flooding frame");
337 None
338 }
339 };
340
341 if let Some(dst_idx) = dst_port {
342 if dst_idx != src_port_idx && dst_idx < fds.len() {
343 send_frame(fds[dst_idx], frame);
344 }
345 } else {
346 for (idx, &fd) in fds.iter().enumerate() {
348 if idx == src_port_idx {
349 continue;
350 }
351 send_frame(fd, frame);
352 }
353 }
354 }
355 }
356 }
357 running.store(false, Ordering::SeqCst);
358}
359
360fn send_frame(fd: RawFd, frame: &[u8]) {
362 let sent = unsafe {
364 libc::send(
365 fd,
366 frame.as_ptr() as *const libc::c_void,
367 frame.len(),
368 libc::MSG_DONTWAIT,
369 )
370 };
371
372 if sent < 0 {
373 let err = std::io::Error::last_os_error();
374 tracing::debug!(fd = fd, error = %err, "dropping frame because send failed");
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Source MAC used by single-sender tests (locally administered, unicast).
    const SENDER_MAC: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x01];

    /// Builds a header-only Ethernet frame from `src` to `dst` with EtherType IPv4 (0x0800).
    fn frame_to(dst: [u8; 6], src: [u8; 6]) -> [u8; 14] {
        let mut frame = [0u8; 14];
        frame[0..6].copy_from_slice(&dst);
        frame[6..12].copy_from_slice(&src);
        frame[12..14].copy_from_slice(&[0x08, 0x00]);
        frame
    }

    /// Sends `frame` on `fd` and returns the raw `send(2)` result.
    fn send_raw(fd: RawFd, frame: &[u8]) -> isize {
        // SAFETY: `frame` is valid for reads of `frame.len()` bytes during the call.
        unsafe { libc::send(fd, frame.as_ptr() as *const libc::c_void, frame.len(), 0) }
    }

    /// Receives one datagram from `fd` without blocking; returns the raw `recv(2)` result.
    fn try_recv_raw(fd: RawFd, buf: &mut [u8]) -> isize {
        // SAFETY: `buf` is valid for writes of `buf.len()` bytes during the call.
        unsafe {
            libc::recv(
                fd,
                buf.as_mut_ptr() as *mut libc::c_void,
                buf.len(),
                libc::MSG_DONTWAIT,
            )
        }
    }

    /// Gives the forwarding thread (50 ms poll timeout) ample time to move a frame.
    fn settle() {
        std::thread::sleep(std::time::Duration::from_millis(200));
    }

    #[test]
    fn mac_from_bytes_valid() {
        let mac = MacAddress::from_bytes(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]).expect("valid MAC");
        assert_eq!(mac.0, [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
    }

    #[test]
    fn mac_from_bytes_too_short() {
        assert!(MacAddress::from_bytes(&[0xaa, 0xbb]).is_none());
    }

    #[test]
    fn mac_from_bytes_empty() {
        assert!(MacAddress::from_bytes(&[]).is_none());
    }

    #[test]
    fn mac_from_bytes_extra_bytes_ignored() {
        let mac = MacAddress::from_bytes(&[1, 2, 3, 4, 5, 6, 7, 8]).expect("valid MAC");
        assert_eq!(mac.0, [1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn mac_broadcast() {
        let mac = MacAddress([0xff, 0xff, 0xff, 0xff, 0xff, 0xff]);
        assert!(mac.is_broadcast());
        // Broadcast is a group address, so it is also multicast.
        assert!(mac.is_multicast());
    }

    #[test]
    fn mac_not_broadcast() {
        let mac = MacAddress([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_multicast() {
        let mac = MacAddress([0x01, 0x00, 0x5e, 0x00, 0x00, 0x01]);
        assert!(mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_unicast() {
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert!(!mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_display_format() {
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert_eq!(format!("{}", mac), "02:42:ac:11:00:02");
    }

    #[test]
    fn mac_display_zero() {
        let mac = MacAddress([0, 0, 0, 0, 0, 0]);
        assert_eq!(format!("{}", mac), "00:00:00:00:00:00");
    }

    #[test]
    fn switch_add_port_returns_fd() {
        let switch = NetworkSwitch::new();
        let vm_fd = switch.add_port("net0", "web").expect("add port");
        assert!(vm_fd.as_raw_fd() >= 0);
    }

    #[test]
    fn switch_add_multiple_ports_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "web").expect("add web port");
        let fd2 = switch.add_port("net0", "db").expect("add db port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_add_ports_different_networks() {
        let switch = NetworkSwitch::new();
        let fd1 = switch
            .add_port("frontend", "web")
            .expect("add frontend port");
        let fd2 = switch.add_port("backend", "db").expect("add backend port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_stop_without_start_is_noop() {
        let switch = NetworkSwitch::new();
        switch.stop();
        switch.stop();
    }

    #[test]
    fn switch_start_twice_is_noop() {
        let switch = NetworkSwitch::new();
        switch.start().expect("first start");
        switch.start().expect("second start");
        switch.stop();
    }

    #[test]
    fn switch_frame_delivery_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net0", "receiver")
            .expect("add receiver port");
        switch.start().expect("start switch");

        let frame = frame_to([0xff; 6], SENDER_MAC);
        assert_eq!(send_raw(fd1.as_raw_fd(), &frame), 14);
        settle();

        let mut buf = vec![0u8; MAX_FRAME_SIZE];
        let recvd = try_recv_raw(fd2.as_raw_fd(), &mut buf);
        assert_eq!(
            recvd, 14,
            "broadcast frame should be forwarded to the other port"
        );
        assert_eq!(&buf[..14], &frame[..]);

        switch.stop();
    }

    #[test]
    fn switch_unicast_goes_only_to_learned_port() {
        const MAC_A: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x0a];
        const MAC_B: [u8; 6] = [0x02, 0x00, 0x00, 0x00, 0x00, 0x0b];

        let switch = NetworkSwitch::new();
        let a = switch.add_port("net0", "a").expect("add port a");
        let b = switch.add_port("net0", "b").expect("add port b");
        let c = switch.add_port("net0", "c").expect("add port c");
        switch.start().expect("start switch");

        // b announces itself with a broadcast so the switch learns MAC_B -> b.
        assert_eq!(send_raw(b.as_raw_fd(), &frame_to([0xff; 6], MAC_B)), 14);
        settle();
        let mut buf = vec![0u8; MAX_FRAME_SIZE];
        assert_eq!(try_recv_raw(a.as_raw_fd(), &mut buf), 14);
        assert_eq!(try_recv_raw(c.as_raw_fd(), &mut buf), 14);

        // A unicast to MAC_B must now reach b only.
        let unicast = frame_to(MAC_B, MAC_A);
        assert_eq!(send_raw(a.as_raw_fd(), &unicast), 14);
        settle();
        assert_eq!(try_recv_raw(b.as_raw_fd(), &mut buf), 14);
        assert_eq!(&buf[..14], &unicast[..]);
        assert!(
            try_recv_raw(c.as_raw_fd(), &mut buf) <= 0,
            "known unicast must not be flooded"
        );

        switch.stop();
    }

    #[test]
    fn switch_no_cross_network_forwarding() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net-a", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net-b", "isolated")
            .expect("add isolated port");
        switch.start().expect("start switch");

        let frame = frame_to([0xff; 6], SENDER_MAC);
        // Without this check the test would pass vacuously if the send failed.
        assert_eq!(send_raw(fd1.as_raw_fd(), &frame), 14);
        settle();

        let mut buf = vec![0u8; MAX_FRAME_SIZE];
        let recvd = try_recv_raw(fd2.as_raw_fd(), &mut buf);
        assert!(recvd <= 0, "frame should NOT cross network boundaries");

        switch.stop();
    }
}