Skip to main content

vm_rs/network/
switch.rs

1//! Userspace L2 Ethernet switch.
2//!
3//! Each "network" is a virtual broadcast domain. The switch reads raw Ethernet
4//! frames from VM socketpairs and forwards them using learning-bridge logic:
5//!
6//! 1. Learn source MAC -> port mapping
7//! 2. If destination MAC is known -> unicast to that port
//! 3. If unknown, broadcast, or multicast -> flood to all other ports on the same network
9//! 4. Never forward between different networks
10
11use std::collections::HashMap;
12use std::io;
13use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::Instant;
16
17use crate::config::VmSocketEndpoint;
18
/// Length of an Ethernet header: destination MAC (6) + source MAC (6) +
/// EtherType (2). The forwarding loop drops anything shorter than this.
const MIN_FRAME_SIZE: usize = 14;
/// Largest frame the switch handles (jumbo frames are not supported). The
/// forwarding loop's receive buffer is exactly this many bytes.
const MAX_FRAME_SIZE: usize = 1518;
23
/// A 48-bit Ethernet hardware (MAC) address, stored in wire order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MacAddress([u8; 6]);

impl MacAddress {
    /// Builds an address from the first six bytes of `bytes`.
    ///
    /// Returns `None` when fewer than six bytes are supplied; any bytes past
    /// the sixth are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<MacAddress> {
        let head = bytes.get(..6)?;
        let mut octets = [0u8; 6];
        octets.copy_from_slice(head);
        Some(MacAddress(octets))
    }

    /// True only for the all-ones address `ff:ff:ff:ff:ff:ff`.
    pub fn is_broadcast(&self) -> bool {
        self.0.iter().all(|&octet| octet == 0xff)
    }

    /// True when the least-significant bit of the first octet is set. The
    /// broadcast address also satisfies this.
    pub fn is_multicast(&self) -> bool {
        (self.0[0] & 0x01) == 0x01
    }
}

impl std::fmt::Display for MacAddress {
    /// Formats the address as six lowercase, zero-padded hex octets separated
    /// by colons, e.g. `02:42:ac:11:00:02`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let [o0, o1, o2, o3, o4, o5] = self.0;
        write!(
            f,
            "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
            o0, o1, o2, o3, o4, o5
        )
    }
}
56
/// One attachment point on the switch: the switch-side descriptor of a Unix
/// datagram socketpair whose peer descriptor is given to a VM's NIC.
#[derive(Debug)]
struct SwitchPort {
    /// Descriptor owned by the switch; it is closed when the port is dropped.
    fd: OwnedFd,
}

impl AsRawFd for SwitchPort {
    /// Borrows the raw descriptor of the switch-side socket without giving up
    /// ownership of it.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.fd)
    }
}
69
/// MAC address table for a single network: maps a learned source MAC to
/// `(port index, last-seen timestamp)`.
///
/// The port index is the position of the port within that network's
/// `Vec<SwitchPort>`; the timestamp is refreshed every time a frame from that
/// MAC is seen and is used by the forwarding loop to age out stale entries.
type MacTable = HashMap<MacAddress, (usize, Instant)>;
72
/// The L2 switch. Owns ports grouped by network, runs a forwarding loop.
///
/// The port map, MAC tables, and run flag are reference-counted because the
/// forwarding thread holds its own clones of them.
pub struct NetworkSwitch {
    /// Switch-side ports, keyed by network id. A port's index in the `Vec` is
    /// the "port index" stored in the MAC tables.
    networks: Arc<Mutex<HashMap<String, Vec<SwitchPort>>>>,
    /// One learned MAC table per network id.
    mac_tables: Arc<RwLock<HashMap<String, MacTable>>>,
    /// Run flag polled by the forwarding loop; clearing it makes the loop exit.
    running: Arc<std::sync::atomic::AtomicBool>,
    /// Join handle of the forwarding thread, if one has been started.
    worker: Mutex<Option<std::thread::JoinHandle<()>>>,
}
80
81impl NetworkSwitch {
82    pub fn new() -> Self {
83        Self {
84            networks: Arc::new(Mutex::new(HashMap::new())),
85            mac_tables: Arc::new(RwLock::new(HashMap::new())),
86            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
87            worker: Mutex::new(None),
88        }
89    }
90
91    /// Add a port to a network. Returns the VM's end of the socketpair fd.
92    ///
93    /// Creates a Unix SOCK_DGRAM socketpair. One end is kept by the switch (for
94    /// reading/writing Ethernet frames), the other is returned so the caller can
95    /// pass it to VZFileHandleNetworkDeviceAttachment.
96    pub fn add_port(&self, network_id: &str, _label: &str) -> io::Result<VmSocketEndpoint> {
97        let (switch_fd, vm_fd) = create_socketpair()?;
98
99        let port = SwitchPort { fd: switch_fd };
100
101        let mut networks = self
102            .networks
103            .lock()
104            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
105        networks
106            .entry(network_id.to_string())
107            .or_default()
108            .push(port);
109
110        let mut mac_tables = self
111            .mac_tables
112            .write()
113            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
114        mac_tables.entry(network_id.to_string()).or_default();
115
116        Ok(VmSocketEndpoint::new(vm_fd))
117    }
118
119    /// Start the forwarding loop on a background thread.
120    ///
121    /// Uses poll(2) to watch all switch-side fds. When a frame arrives on any port,
122    /// it's forwarded according to learning-bridge rules.
123    pub fn start(&self) -> io::Result<()> {
124        use std::sync::atomic::Ordering;
125
126        if self.running.load(Ordering::Relaxed) {
127            return Ok(());
128        }
129        self.running.store(true, Ordering::SeqCst);
130
131        let networks = Arc::clone(&self.networks);
132        let mac_tables = Arc::clone(&self.mac_tables);
133        let running = Arc::clone(&self.running);
134
135        let handle = std::thread::Builder::new()
136            .name("network-switch".to_string())
137            .spawn(move || {
138                forwarding_loop(&networks, &mac_tables, &running);
139            })?;
140
141        let mut worker = self
142            .worker
143            .lock()
144            .map_err(|e| io::Error::other(format!("lock poisoned: {}", e)))?;
145        *worker = Some(handle);
146
147        Ok(())
148    }
149
150    /// Stop the forwarding loop.
151    pub fn stop(&self) {
152        self.running
153            .store(false, std::sync::atomic::Ordering::SeqCst);
154        match self.worker.lock() {
155            Ok(mut worker) => {
156                if let Some(handle) = worker.take() {
157                    if let Err(e) = handle.join() {
158                        tracing::error!("network switch thread panicked: {:?}", e);
159                    }
160                }
161            }
162            Err(e) => {
163                tracing::error!("network switch worker lock poisoned during stop: {}", e);
164            }
165        }
166    }
167}
168
169impl Default for NetworkSwitch {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175impl Drop for NetworkSwitch {
176    fn drop(&mut self) {
177        self.stop();
178        // OwnedFd in SwitchPort handles closing file descriptors automatically.
179    }
180}
181
182/// Create a Unix SOCK_DGRAM socketpair. Returns (switch_fd, vm_fd).
183fn create_socketpair() -> io::Result<(OwnedFd, OwnedFd)> {
184    let mut fds = [0i32; 2];
185    // SAFETY: Standard POSIX socketpair call with valid fd array.
186    let ret = unsafe { libc::socketpair(libc::AF_UNIX, libc::SOCK_DGRAM, 0, fds.as_mut_ptr()) };
187    if ret != 0 {
188        return Err(io::Error::last_os_error());
189    }
190
191    // SAFETY: socketpair just created these fresh file descriptors; wrapping
192    // them in OwnedFd transfers ownership so they are closed on drop.
193    let switch_fd = unsafe { OwnedFd::from_raw_fd(fds[0]) };
194    let vm_fd = unsafe { OwnedFd::from_raw_fd(fds[1]) };
195
196    // Set switch side non-blocking for the poll loop.
197    // SAFETY: fcntl on file descriptors we just created in the socketpair above.
198    unsafe {
199        let flags = libc::fcntl(switch_fd.as_raw_fd(), libc::F_GETFL);
200        if flags == -1 {
201            return Err(io::Error::last_os_error());
202        }
203        if libc::fcntl(
204            switch_fd.as_raw_fd(),
205            libc::F_SETFL,
206            flags | libc::O_NONBLOCK,
207        ) == -1
208        {
209            return Err(io::Error::last_os_error());
210        }
211    }
212
213    Ok((switch_fd, vm_fd))
214}
215
216/// The main forwarding loop. Runs until `running` is set to false.
217fn forwarding_loop(
218    networks: &Mutex<HashMap<String, Vec<SwitchPort>>>,
219    mac_tables: &RwLock<HashMap<String, MacTable>>,
220    running: &std::sync::atomic::AtomicBool,
221) {
222    use std::sync::atomic::Ordering;
223
224    const MAC_AGE_INTERVAL_SECS: u64 = 30;
225    const MAC_ENTRY_LIFETIME_SECS: u64 = 120;
226
227    let mut buf = [0u8; MAX_FRAME_SIZE];
228    let mut last_aged = Instant::now();
229
230    while running.load(Ordering::Relaxed) {
231        // Periodic MAC aging: sweep entries older than 120s every 30s
232        if last_aged.elapsed().as_secs() >= MAC_AGE_INTERVAL_SECS {
233            if let Ok(mut tables) = mac_tables.write() {
234                for table in tables.values_mut() {
235                    table.retain(|_mac, (_port, ts)| {
236                        ts.elapsed().as_secs() < MAC_ENTRY_LIFETIME_SECS
237                    });
238                }
239            }
240            last_aged = Instant::now();
241        }
242
243        // Build poll fds and snapshot port fds
244        let nets = match networks.lock() {
245            Ok(n) => n,
246            Err(_) => break, // Lock poisoned, bail out
247        };
248        let mut pollfds: Vec<libc::pollfd> = Vec::new();
249        let mut fd_map: Vec<(String, usize)> = Vec::new();
250        let mut port_fds: HashMap<String, Vec<RawFd>> = HashMap::new();
251
252        for (net_id, ports) in nets.iter() {
253            let fds: Vec<RawFd> = ports.iter().map(|p| p.fd.as_raw_fd()).collect();
254            port_fds.insert(net_id.clone(), fds);
255            for (idx, port) in ports.iter().enumerate() {
256                pollfds.push(libc::pollfd {
257                    fd: port.fd.as_raw_fd(),
258                    events: libc::POLLIN,
259                    revents: 0,
260                });
261                fd_map.push((net_id.clone(), idx));
262            }
263        }
264        drop(nets);
265
266        if pollfds.is_empty() {
267            std::thread::sleep(std::time::Duration::from_millis(50));
268            continue;
269        }
270
271        // SAFETY: poll(2) on fds we own. pollfds array is valid and properly sized.
272        let ready = unsafe { libc::poll(pollfds.as_mut_ptr(), pollfds.len() as libc::nfds_t, 50) };
273
274        if ready <= 0 {
275            continue;
276        }
277
278        for (i, pfd) in pollfds.iter().enumerate() {
279            if pfd.revents & libc::POLLIN == 0 {
280                continue;
281            }
282
283            let (ref net_id, src_port_idx) = fd_map[i];
284
285            // Read one Ethernet frame
286            // SAFETY: Reading from our own fd into a stack buffer of MAX_FRAME_SIZE.
287            let n =
288                unsafe { libc::recv(pfd.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) };
289
290            if n < MIN_FRAME_SIZE as isize {
291                continue;
292            }
293            let frame = &buf[..n as usize];
294
295            // Parse Ethernet header: dst MAC (6 bytes) + src MAC (6 bytes)
296            let dst_mac = match MacAddress::from_bytes(&frame[0..6]) {
297                Some(m) => m,
298                None => continue,
299            };
300            let src_mac = match MacAddress::from_bytes(&frame[6..12]) {
301                Some(m) => m,
302                None => continue,
303            };
304
305            // Learn: src MAC -> src port
306            if let Ok(mut tables) = mac_tables.write() {
307                if let Some(table) = tables.get_mut(net_id.as_str()) {
308                    table.insert(src_mac, (src_port_idx, Instant::now()));
309                }
310            }
311
312            // Forward
313            let fds = match port_fds.get(net_id.as_str()) {
314                Some(f) => f,
315                None => continue,
316            };
317
318            if dst_mac.is_broadcast() || dst_mac.is_multicast() {
319                // Flood to all ports on same network except source
320                for (idx, &fd) in fds.iter().enumerate() {
321                    if idx == src_port_idx {
322                        continue;
323                    }
324                    send_frame(fd, frame);
325                }
326            } else {
327                // Unicast: check MAC table
328                let dst_port = match mac_tables.read() {
329                    Ok(tables) => tables
330                        .get(net_id.as_str())
331                        .and_then(|t| t.get(&dst_mac))
332                        .map(|(port_idx, _ts)| *port_idx),
333                    Err(_) => {
334                        // Lock poisoned in the forwarding hot path — flood as fallback
335                        // rather than crashing the switch thread.
336                        tracing::error!("MAC table read lock poisoned, flooding frame");
337                        None
338                    }
339                };
340
341                if let Some(dst_idx) = dst_port {
342                    if dst_idx != src_port_idx && dst_idx < fds.len() {
343                        send_frame(fds[dst_idx], frame);
344                    }
345                } else {
346                    // Unknown destination -- flood
347                    for (idx, &fd) in fds.iter().enumerate() {
348                        if idx == src_port_idx {
349                            continue;
350                        }
351                        send_frame(fd, frame);
352                    }
353                }
354            }
355        }
356    }
357    running.store(false, Ordering::SeqCst);
358}
359
360/// Send a single Ethernet frame to a socket fd. Best-effort (drops if buffer full).
361fn send_frame(fd: RawFd, frame: &[u8]) {
362    // SAFETY: Sending to our own fd. MSG_DONTWAIT prevents blocking on full buffer.
363    let sent = unsafe {
364        libc::send(
365            fd,
366            frame.as_ptr() as *const libc::c_void,
367            frame.len(),
368            libc::MSG_DONTWAIT,
369        )
370    };
371
372    if sent < 0 {
373        let err = std::io::Error::last_os_error();
374        tracing::debug!(fd = fd, error = %err, "dropping frame because send failed");
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on how long a test waits for the switch to forward a frame.
    const FORWARD_TIMEOUT_MS: libc::c_int = 2000;
    /// How long a test listens before concluding that no frame will arrive.
    const QUIET_PERIOD_MS: libc::c_int = 200;

    /// Minimal Ethernet frame: broadcast dst(6) + src(6) + ethertype(2) = 14 bytes.
    fn broadcast_frame() -> [u8; 14] {
        let mut frame = [0u8; 14];
        // Broadcast destination
        frame[0..6].copy_from_slice(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff]);
        // Source MAC
        frame[6..12].copy_from_slice(&[0x02, 0x00, 0x00, 0x00, 0x00, 0x01]);
        // EtherType (arbitrary)
        frame[12..14].copy_from_slice(&[0x08, 0x00]);
        frame
    }

    /// Blocking send of `frame` on `fd`; returns the raw send(2) result.
    fn send_on(fd: RawFd, frame: &[u8]) -> isize {
        // SAFETY: Writing `frame.len()` readable bytes to our own socketpair fd.
        unsafe { libc::send(fd, frame.as_ptr() as *const libc::c_void, frame.len(), 0) }
    }

    /// Waits up to `timeout_ms` for `fd` to become readable.
    fn wait_readable(fd: RawFd, timeout_ms: libc::c_int) -> bool {
        let mut pfd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: poll(2) on a single, valid pollfd that lives for the whole call.
        let ready = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
        ready > 0 && pfd.revents & libc::POLLIN != 0
    }

    /// Non-blocking receive into `buf`; returns the raw recv(2) result.
    fn try_recv(fd: RawFd, buf: &mut [u8]) -> isize {
        // SAFETY: Reading from our own socketpair fd into a buffer of `buf.len()` bytes.
        unsafe {
            libc::recv(
                fd,
                buf.as_mut_ptr() as *mut libc::c_void,
                buf.len(),
                libc::MSG_DONTWAIT,
            )
        }
    }

    // ── MacAddress ───────────────────────────────────────────────────────

    #[test]
    fn mac_from_bytes_valid() {
        let mac = MacAddress::from_bytes(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]).expect("valid MAC");
        assert_eq!(mac.0, [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
    }

    #[test]
    fn mac_from_bytes_too_short() {
        assert!(MacAddress::from_bytes(&[0xaa, 0xbb]).is_none());
    }

    #[test]
    fn mac_from_bytes_empty() {
        assert!(MacAddress::from_bytes(&[]).is_none());
    }

    #[test]
    fn mac_from_bytes_extra_bytes_ignored() {
        let mac = MacAddress::from_bytes(&[1, 2, 3, 4, 5, 6, 7, 8])
            .expect("input longer than 6 bytes should still parse");
        assert_eq!(mac.0, [1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn mac_broadcast() {
        let mac = MacAddress([0xff, 0xff, 0xff, 0xff, 0xff, 0xff]);
        assert!(mac.is_broadcast());
        assert!(mac.is_multicast()); // broadcast is also multicast
    }

    #[test]
    fn mac_not_broadcast() {
        let mac = MacAddress([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_multicast() {
        // LSB of first octet set = multicast
        let mac = MacAddress([0x01, 0x00, 0x5e, 0x00, 0x00, 0x01]);
        assert!(mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_unicast() {
        // LSB of first octet clear = unicast
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert!(!mac.is_multicast());
        assert!(!mac.is_broadcast());
    }

    #[test]
    fn mac_display_format() {
        let mac = MacAddress([0x02, 0x42, 0xac, 0x11, 0x00, 0x02]);
        assert_eq!(format!("{}", mac), "02:42:ac:11:00:02");
    }

    #[test]
    fn mac_display_zero() {
        let mac = MacAddress([0, 0, 0, 0, 0, 0]);
        assert_eq!(format!("{}", mac), "00:00:00:00:00:00");
    }

    // ── NetworkSwitch ────────────────────────────────────────────────────

    #[test]
    fn switch_add_port_returns_fd() {
        let switch = NetworkSwitch::new();
        let vm_fd = switch.add_port("net0", "web").expect("add port");
        assert!(vm_fd.as_raw_fd() >= 0);
    }

    #[test]
    fn switch_add_multiple_ports_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "web").expect("add web port");
        let fd2 = switch.add_port("net0", "db").expect("add db port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_add_ports_different_networks() {
        let switch = NetworkSwitch::new();
        let fd1 = switch
            .add_port("frontend", "web")
            .expect("add frontend port");
        let fd2 = switch.add_port("backend", "db").expect("add backend port");
        assert_ne!(fd1.as_raw_fd(), fd2.as_raw_fd());
    }

    #[test]
    fn switch_frame_delivery_same_network() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net0", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net0", "receiver")
            .expect("add receiver port");
        switch.start().expect("start switch");

        // Send from fd1 (the VM's end of the socketpair)
        let frame = broadcast_frame();
        assert_eq!(send_on(fd1.as_raw_fd(), &frame), 14);

        // Wait (bounded) for the switch to forward instead of sleeping a
        // fixed amount: faster on an idle machine, less flaky on a busy one.
        assert!(
            wait_readable(fd2.as_raw_fd(), FORWARD_TIMEOUT_MS),
            "receiver port never became readable"
        );

        // Read from fd2 (the VM's end of the receiver socketpair)
        let mut buf = vec![0u8; 1518];
        let recvd = try_recv(fd2.as_raw_fd(), &mut buf);
        assert_eq!(
            recvd, 14,
            "broadcast frame should be forwarded to the other port"
        );
        assert_eq!(&buf[..14], &frame[..14]);

        switch.stop();
    }

    #[test]
    fn switch_no_cross_network_forwarding() {
        let switch = NetworkSwitch::new();
        let fd1 = switch.add_port("net-a", "sender").expect("add sender port");
        let fd2 = switch
            .add_port("net-b", "isolated")
            .expect("add isolated port");
        switch.start().expect("start switch");

        // Broadcast frame from net-a. The send must succeed, otherwise the
        // "nothing arrived" assertion below would pass vacuously.
        let frame = broadcast_frame();
        assert_eq!(send_on(fd1.as_raw_fd(), &frame), 14);

        // net-b should NOT receive the frame
        assert!(
            !wait_readable(fd2.as_raw_fd(), QUIET_PERIOD_MS),
            "frame should NOT cross network boundaries"
        );
        let mut buf = vec![0u8; 1518];
        let recvd = try_recv(fd2.as_raw_fd(), &mut buf);
        assert!(recvd <= 0, "frame should NOT cross network boundaries");

        switch.stop();
    }
}