1use crate::device::VirtioNetworkDevice;
54use crate::queues::NetworkFrameQueues;
55use crate::tcp_listeners::AcceptedTcpConnection;
56use crate::tcp_relay::{spawn_tcp_relay, TcpRelayTable};
57use crate::{virtio_net_log, DEFAULT_DNS_ADDR};
58use smoltcp::iface::{
59 Config, Interface, PollIngressSingleResult, PollResult, SocketHandle, SocketSet,
60};
61use smoltcp::socket::udp::{PacketBuffer, PacketMetadata, Socket as UdpSocket, UdpMetadata};
62use smoltcp::time::Instant;
63use smoltcp::wire::{
64 EthernetAddress, EthernetFrame, EthernetProtocol, HardwareAddress, IpAddress, IpCidr,
65 Ipv4Packet, TcpPacket, UdpPacket,
66};
67use std::net::{Ipv4Addr, SocketAddr, UdpSocket as HostUdpSocket};
68use std::sync::atomic::Ordering;
69use std::sync::mpsc::{Receiver, TryRecvError};
70use std::sync::Arc;
71use std::thread::{self, JoinHandle};
72use std::time::{Duration, Instant as StdInstant};
73
/// UDP port used for DNS: the gateway DNS socket binds to it, guest datagrams addressed to it
/// are classified as DNS queries, and host-side queries are sent to it on the upstream resolver.
const DNS_SOCKET_PORT: u16 = 53;
/// Number of packet-metadata slots in each of the DNS socket's receive and transmit buffers.
const DNS_PACKET_SLOTS: usize = 8;
/// Payload capacity, in bytes, of each DNS socket buffer; also the size of the host-side buffer
/// used to receive the upstream resolver's response.
const DNS_BUFFER_BYTES: usize = 2048;
/// Poll timeout, in milliseconds, used when the interface reports no pending deadline.
const DEFAULT_IDLE_TIMEOUT_MS: i32 = 100;
78
/// Static configuration handed to the network poll thread.
#[derive(Debug, Clone, Copy)]
pub struct VirtioPollConfig {
    /// Ethernet address assigned to the smoltcp interface (the gateway side of the link).
    pub gateway_mac: [u8; 6],
    /// Ethernet address of the guest NIC.
    // NOTE(review): not read anywhere in this part of the file — confirm where it is consumed.
    pub guest_mac: [u8; 6],
    /// IPv4 address of the gateway: configured on the interface, used as the default route and
    /// as the bind address of the gateway DNS socket.
    pub gateway_ipv4: Ipv4Addr,
    /// IPv4 address of the guest; published TCP connections are relayed to this address.
    pub guest_ipv4: Ipv4Addr,
    /// MTU passed to `VirtioNetworkDevice::new`.
    pub mtu: usize,
}
97
/// How the poll loop should treat a frame received from the guest, as decided by
/// `classify_guest_frame`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FrameAction {
    /// IPv4 TCP segment with SYN set and ACK clear, i.e. the guest is opening a connection.
    TcpSyn {
        /// Source address and port taken from the IPv4 and TCP headers.
        source: SocketAddr,
        /// Destination address and port taken from the IPv4 and TCP headers.
        destination: SocketAddr,
    },
    /// IPv4 UDP datagram whose destination port is `DNS_SOCKET_PORT`.
    DnsQuery,
    /// Any other IPv4 UDP datagram; the poll loop drops these.
    UnsupportedUdp,
    /// Everything else (including frames that fail to parse); handed to smoltcp unchanged.
    Passthrough,
}
108
109pub fn start_network_stack(
119 queues: Arc<NetworkFrameQueues>,
120 config: VirtioPollConfig,
121 tcp_receiver: Option<Receiver<AcceptedTcpConnection>>,
122) -> std::io::Result<JoinHandle<()>> {
123 virtio_net_log!(
124 "virtio-net: spawning poll thread guest_ip={} gateway_ip={} mtu={}",
125 config.guest_ipv4,
126 config.gateway_ipv4,
127 config.mtu
128 );
129 thread::Builder::new()
130 .name("smolvm-net-poll".into())
131 .spawn(move || run_network_stack(queues, config, tcp_receiver))
132}
133
134fn run_network_stack(
135 queues: Arc<NetworkFrameQueues>,
136 config: VirtioPollConfig,
137 mut tcp_receiver: Option<Receiver<AcceptedTcpConnection>>,
138) {
139 virtio_net_log!(
153 "virtio-net: poll loop started guest_ip={} gateway_ip={}",
154 config.guest_ipv4,
155 config.gateway_ipv4
156 );
157 let clock = StdInstant::now();
158 let mut device = VirtioNetworkDevice::new(queues.clone(), config.mtu);
159 let mut interface = create_interface(&mut device, &config);
160 let mut sockets = SocketSet::new(vec![]);
161 let dns_socket_handle = add_dns_socket(&mut sockets, config.gateway_ipv4);
162 let relay_wake = Arc::new(queues.relay_wake.clone());
163 let mut relays = TcpRelayTable::new(None);
164
165 let mut poll_fds = [
169 libc::pollfd {
170 fd: queues.guest_wake.as_raw_fd(),
171 events: libc::POLLIN,
172 revents: 0,
173 },
174 libc::pollfd {
175 fd: queues.relay_wake.as_raw_fd(),
176 events: libc::POLLIN,
177 revents: 0,
178 },
179 ];
180
181 loop {
182 if queues.is_shutting_down() {
183 return;
184 }
185 let now = smoltcp_now(clock);
186
187 while let Some(frame) = device.stage_next_frame() {
188 match classify_guest_frame(frame) {
194 FrameAction::TcpSyn {
195 source,
196 destination,
197 } => {
198 virtio_net_log!(
199 "virtio-net: guest TCP SYN source={} destination={}",
200 source,
201 destination
202 );
203 if !relays.has_socket_for(&source, &destination) {
204 relays.create_tcp_socket(source, destination, &mut sockets);
205 }
206 if matches!(
207 interface.poll_ingress_single(now, &mut device, &mut sockets),
208 PollIngressSingleResult::None
209 ) {
210 device.drop_staged_frame();
211 }
212 }
213 FrameAction::DnsQuery | FrameAction::Passthrough => {
214 if matches!(
215 interface.poll_ingress_single(now, &mut device, &mut sockets),
216 PollIngressSingleResult::None
217 ) {
218 device.drop_staged_frame();
219 }
220 }
221 FrameAction::UnsupportedUdp => {
222 virtio_net_log!("virtio-net: dropping unsupported guest UDP datagram");
225 device.drop_staged_frame();
226 }
227 }
228 }
229
230 relay_accepted_tcp_connection(
231 &mut tcp_receiver,
232 &mut relays,
233 &mut interface,
234 &mut sockets,
235 config.gateway_ipv4,
236 config.guest_ipv4,
237 );
238
239 flush_interface_egress(&mut interface, &mut device, &mut sockets, now);
242 interface.poll_maintenance(now);
243 wake_guest_if_needed(&queues, &device);
244
245 relays.relay_data(&mut sockets);
248 process_dns_queries(dns_socket_handle, &mut sockets);
249
250 for connection in relays.take_new_connections(&mut sockets) {
253 spawn_tcp_relay(
254 connection.destination,
255 connection.relay_target,
256 connection.from_smoltcp,
257 connection.to_smoltcp,
258 relay_wake.clone(),
259 connection.exit_state,
260 );
261 }
262
263 relays.cleanup_closed(&mut sockets);
264
265 flush_interface_egress(&mut interface, &mut device, &mut sockets, now);
268 wake_guest_if_needed(&queues, &device);
269
270 let timeout_ms = interface
271 .poll_delay(now, &sockets)
272 .map(|duration| duration.total_millis().min(i32::MAX as u64) as i32)
273 .unwrap_or(DEFAULT_IDLE_TIMEOUT_MS);
274
275 unsafe {
277 libc::poll(
278 poll_fds.as_mut_ptr(),
279 poll_fds.len() as libc::nfds_t,
280 timeout_ms,
281 );
282 }
283
284 if poll_fds[0].revents & libc::POLLIN != 0 {
285 queues.guest_wake.drain();
286 }
287 if poll_fds[1].revents & libc::POLLIN != 0 {
288 queues.relay_wake.drain();
289 }
290 }
291}
292
293fn create_interface(device: &mut VirtioNetworkDevice, config: &VirtioPollConfig) -> Interface {
294 let mut interface = Interface::new(
303 Config::new(HardwareAddress::Ethernet(EthernetAddress(
304 config.gateway_mac,
305 ))),
306 device,
307 Instant::ZERO,
308 );
309 interface.update_ip_addrs(|addresses| {
310 addresses
311 .push(IpCidr::new(IpAddress::Ipv4(config.gateway_ipv4), 30))
312 .expect("failed to add gateway IPv4 address");
313 });
314 interface
318 .routes_mut()
319 .add_default_ipv4_route(config.gateway_ipv4)
320 .expect("failed to add default IPv4 route");
321 interface.set_any_ip(true);
322 interface
323}
324
325fn add_dns_socket(sockets: &mut SocketSet<'_>, gateway_ipv4: Ipv4Addr) -> SocketHandle {
330 let rx_meta = vec![PacketMetadata::EMPTY; DNS_PACKET_SLOTS];
331 let tx_meta = vec![PacketMetadata::EMPTY; DNS_PACKET_SLOTS];
332 let rx_buffer = PacketBuffer::new(rx_meta, vec![0u8; DNS_BUFFER_BYTES]);
333 let tx_buffer = PacketBuffer::new(tx_meta, vec![0u8; DNS_BUFFER_BYTES]);
334 let mut socket = UdpSocket::new(rx_buffer, tx_buffer);
335 socket
336 .bind(smoltcp::wire::IpListenEndpoint {
337 addr: Some(gateway_ipv4.into()),
338 port: DNS_SOCKET_PORT,
339 })
340 .expect("failed to bind gateway DNS socket");
341 sockets.add(socket)
342}
343
344fn relay_accepted_tcp_connection(
347 tcp_receiver: &mut Option<Receiver<AcceptedTcpConnection>>,
348 relays: &mut TcpRelayTable,
349 interface: &mut Interface,
350 sockets: &mut SocketSet<'_>,
351 gateway_ipv4: Ipv4Addr,
352 guest_ipv4: Ipv4Addr,
353) {
354 let mut disconnected = false;
364
365 if let Some(receiver) = tcp_receiver.as_mut() {
366 loop {
367 match receiver.try_recv() {
368 Ok(connection) => {
369 let guest_destination =
370 SocketAddr::new(std::net::IpAddr::V4(guest_ipv4), connection.guest_port);
371 virtio_net_log!(
372 "virtio-net: accepted published TCP connection peer={} host_port={} guest_destination={}",
373 connection.peer_addr,
374 connection.host_port,
375 guest_destination
376 );
377 if !relays.create_published_socket(
378 interface,
379 gateway_ipv4,
380 guest_destination,
381 connection.stream,
382 sockets,
383 ) {
384 tracing::warn!(
385 host_port = connection.host_port,
386 guest_port = connection.guest_port,
387 peer_addr = %connection.peer_addr,
388 "dropping published TCP connection because the guest relay path could not be created"
389 );
390 }
391 }
392 Err(TryRecvError::Empty) => break,
393 Err(TryRecvError::Disconnected) => {
394 disconnected = true;
395 break;
396 }
397 }
398 }
399 }
400
401 if disconnected {
402 *tcp_receiver = None;
403 }
404}
405
406fn process_dns_queries(dns_socket_handle: SocketHandle, sockets: &mut SocketSet<'_>) {
407 let upstream_dns = match DEFAULT_DNS_ADDR {
411 std::net::IpAddr::V4(ip) => ip,
412 std::net::IpAddr::V6(_) => return,
413 };
414
415 let socket = sockets.get_mut::<UdpSocket>(dns_socket_handle);
416 while socket.can_recv() {
417 let (query, metadata) = match socket.recv() {
418 Ok(result) => result,
419 Err(_) => break,
420 };
421 virtio_net_log!(
422 "virtio-net: forwarding guest DNS query guest={} local_address={:?} query_len={} upstream_dns={}",
423 metadata.endpoint,
424 metadata.local_address,
425 query.len(),
426 upstream_dns
427 );
428 let response = match forward_dns_query(upstream_dns, query) {
429 Ok(response) => response,
430 Err(err) => {
431 virtio_net_log!("virtio-net: host DNS forwarding failed error={}", err);
432 continue;
433 }
434 };
435 virtio_net_log!(
436 "virtio-net: forwarded DNS response back to guest guest={} response_len={}",
437 metadata.endpoint,
438 response.len()
439 );
440
441 let response_meta = UdpMetadata {
442 endpoint: metadata.endpoint,
443 local_address: metadata.local_address,
444 meta: Default::default(),
445 };
446 let _ = socket.send_slice(&response, response_meta);
447 }
448}
449
450fn forward_dns_query(upstream_dns: Ipv4Addr, query: &[u8]) -> std::io::Result<Vec<u8>> {
451 let socket = HostUdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))?;
459 socket.set_read_timeout(Some(Duration::from_secs(2)))?;
460 let local_addr = socket.local_addr()?;
461 virtio_net_log!(
462 "virtio-net: sending DNS query to upstream resolver local_addr={} upstream_dns={} query_len={}",
463 local_addr,
464 upstream_dns,
465 query.len()
466 );
467 socket.send_to(query, (upstream_dns, DNS_SOCKET_PORT))?;
468
469 let mut buffer = vec![0u8; DNS_BUFFER_BYTES];
470 let (bytes_read, _) = socket.recv_from(&mut buffer)?;
471 buffer.truncate(bytes_read);
472 virtio_net_log!(
473 "virtio-net: received DNS response from upstream resolver upstream_dns={} response_len={}",
474 upstream_dns,
475 buffer.len()
476 );
477 Ok(buffer)
478}
479
480fn flush_interface_egress(
481 interface: &mut Interface,
482 device: &mut VirtioNetworkDevice,
483 sockets: &mut SocketSet<'_>,
484 now: Instant,
485) {
486 loop {
490 let result = interface.poll_egress(now, device, sockets);
491 if matches!(result, PollResult::None) {
492 break;
493 }
494 }
495}
496
497fn wake_guest_if_needed(queues: &NetworkFrameQueues, device: &VirtioNetworkDevice) {
498 if device.frames_emitted.swap(false, Ordering::Relaxed) {
502 queues.host_wake.wake();
503 }
504}
505
506fn smoltcp_now(clock: StdInstant) -> Instant {
507 let elapsed = clock.elapsed();
508 Instant::from_millis(elapsed.as_millis() as i64)
509}
510
511fn classify_guest_frame(frame: &[u8]) -> FrameAction {
512 let ethernet = match EthernetFrame::new_checked(frame) {
513 Ok(frame) => frame,
514 Err(_) => return FrameAction::Passthrough,
515 };
516
517 if ethernet.ethertype() != EthernetProtocol::Ipv4 {
518 return FrameAction::Passthrough;
519 }
520
521 let ipv4 = match Ipv4Packet::new_checked(ethernet.payload()) {
522 Ok(packet) => packet,
523 Err(_) => return FrameAction::Passthrough,
524 };
525
526 match ipv4.next_header() {
527 smoltcp::wire::IpProtocol::Tcp => {
528 let tcp = match TcpPacket::new_checked(ipv4.payload()) {
529 Ok(packet) => packet,
530 Err(_) => return FrameAction::Passthrough,
531 };
532
533 if tcp.syn() && !tcp.ack() {
534 FrameAction::TcpSyn {
535 source: SocketAddr::new(std::net::IpAddr::V4(ipv4.src_addr()), tcp.src_port()),
536 destination: SocketAddr::new(
537 std::net::IpAddr::V4(ipv4.dst_addr()),
538 tcp.dst_port(),
539 ),
540 }
541 } else {
542 FrameAction::Passthrough
543 }
544 }
545 smoltcp::wire::IpProtocol::Udp => {
546 let udp = match UdpPacket::new_checked(ipv4.payload()) {
547 Ok(packet) => packet,
548 Err(_) => return FrameAction::Passthrough,
549 };
550
551 if udp.dst_port() == DNS_SOCKET_PORT {
552 FrameAction::DnsQuery
553 } else {
554 FrameAction::UnsupportedUdp
555 }
556 }
557 _ => FrameAction::Passthrough,
558 }
559}