Skip to main content

snap_tun/
client.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP tunnel client.
15
16use std::{
17    io,
18    net::SocketAddr,
19    pin,
20    sync::{Arc, Mutex},
21    task::ready,
22    time::Duration,
23};
24
25use ana_gotatun::{
26    noise::{Tunn, TunnResult, errors::WireGuardError, rate_limiter::RateLimiter},
27    packet::{Packet, WgKind},
28    x25519::{self},
29};
30use bytes::{Bytes, BytesMut};
31use tokio::{
32    select,
33    sync::mpsc::{self, error::TrySendError},
34    task::JoinHandle,
35    time::Interval,
36};
37use tracing::instrument;
38use zerocopy::IntoBytes as _;
39
/// Size, in bytes, of the buffer allocated for each datagram read from the underlay socket
/// (65535, the largest value a 16-bit length field can hold).
const UDP_DATAGRAM_BUFFER_SIZE: usize = u16::MAX as usize;
/// Limit handed to the handshake `RateLimiter` of every tunnel.
/// NOTE(review): the unit (e.g. handshakes per second) is defined by `RateLimiter::new` — confirm.
const HANDSHAKE_RATE_LIMIT: u64 = 20;
42
43/// Error when sending or receiving packets on the SNAP tunnel.
44#[derive(Debug, thiserror::Error)]
45pub enum SnapTunNgSocketError {
46    /// I/O error.
47    #[error("i/o error: {0}")]
48    IoError(#[from] std::io::Error),
49    /// Receive queue closed.
50    #[error("receive queue closed")]
51    ReceiveQueueClosed,
52    /// Initial handshake timed out.
53    #[error("initial handshake timed out")]
54    InitialHandshakeTimeout,
55    /// Wireguard error.
56    #[error("wireguard error: {0:?}")]
57    WireguardError(WireGuardError),
58}
59
60struct SnapTunNgClientDriver {
61    pub tunn: Arc<Mutex<Tunn>>,
62    pub underlay_socket: Arc<tokio::net::UdpSocket>,
63    pub dataplane_address: SocketAddr,
64    pub update_timers_interval: Interval,
65    pub packet_sender: mpsc::Sender<BytesMut>,
66    pub local_sockaddr: Option<SocketAddr>,
67}
68
69fn to_bytes(wg: WgKind) -> Packet<[u8]> {
70    match wg {
71        WgKind::HandshakeInit(p) => p.into_bytes(),
72        WgKind::HandshakeResp(p) => p.into_bytes(),
73        WgKind::CookieReply(p) => p.into_bytes(),
74        WgKind::Data(p) => p.into_bytes(),
75    }
76}
77
78impl SnapTunNgClientDriver {
79    fn new(
80        tunn: Arc<Mutex<Tunn>>,
81        underlay_socket: Arc<tokio::net::UdpSocket>,
82        dataplane_address: SocketAddr,
83        packet_sender: mpsc::Sender<BytesMut>,
84    ) -> Self {
85        let update_timers_interval = tokio::time::interval_at(
86            tokio::time::Instant::now() + Duration::from_millis(250),
87            Duration::from_millis(250),
88        );
89        Self {
90            tunn,
91            underlay_socket,
92            dataplane_address,
93            update_timers_interval,
94            packet_sender,
95            local_sockaddr: None,
96        }
97    }
98
99    #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
100    async fn initial_connection(&mut self) -> Result<SocketAddr, SnapTunNgSocketError> {
101        let handshake_init = self.tunn.lock().unwrap().format_handshake_initiation(false);
102        if let Some(wg_init) = handshake_init
103            && let Err(e) = self
104                .underlay_socket
105                .send_to(
106                    to_bytes(WgKind::HandshakeInit(wg_init)).as_bytes(),
107                    self.dataplane_address,
108                )
109                .await
110        {
111            return Err(SnapTunNgSocketError::IoError(e));
112        }
113        // Drive the tunnel until any error occurs or the handshake is completed.
114        loop {
115            self.drive_once().await?;
116            if let Some(sockaddr) = self.tunn.lock().unwrap().get_initiator_remote_sockaddr() {
117                self.local_sockaddr = Some(sockaddr);
118                tracing::debug!(socket_addr=?sockaddr, "handshake completed, socket address assigned");
119                return Ok(sockaddr);
120            }
121        }
122    }
123
124    #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
125    async fn main_loop(mut self) {
126        loop {
127            let result = self.drive_once().await;
128            if let Err(ref e) = result {
129                tracing::error!(err=?e, "error driving tunnel");
130            }
131            if let Err(SnapTunNgSocketError::ReceiveQueueClosed) = result {
132                tracing::info!("receive queue closed, snap tunnel driver shutting down");
133                return;
134            }
135        }
136    }
137
138    async fn drive_once(&mut self) -> Result<(), SnapTunNgSocketError> {
139        let mut buf = BytesMut::zeroed(UDP_DATAGRAM_BUFFER_SIZE);
140        select! {
141            _ = self.update_timers_interval.tick() => {
142                let p = match self.tunn.lock().unwrap().update_timers() {
143                    Ok(Some(wg)) => { Some(wg) },
144                    Ok(None) => None,
145                    Err(WireGuardError::ConnectionExpired) => {
146                        return Err(SnapTunNgSocketError::InitialHandshakeTimeout);
147                    }
148                    Err(e) => {
149                        tracing::error!(err=?e, "error updating timers on tunnel");
150                        None
151                    }
152                };
153                if let Some(wg) = p && let Err(e) = self.underlay_socket.send_to(to_bytes(wg).as_bytes(), self.dataplane_address).await {
154                    return Err(SnapTunNgSocketError::IoError(e));
155                }
156            },
157            recv = self.underlay_socket.recv_from(&mut buf) => {
158                let (n, sender_addr) = match recv {
159                    Ok((n, sender_addr)) => (n, sender_addr),
160                    Err(e) => {
161                        return Err(SnapTunNgSocketError::IoError(e));
162                    }
163                };
164                if sender_addr != self.dataplane_address {
165                    // Ignore packets that are not from the dataplane address.
166                    return Ok(());
167                }
168                buf.truncate(n);
169                let packet: Packet<[u8]> = Packet::from_bytes(buf);
170                let wg = packet.try_into_wg().expect("this needs to be handled");
171                // Process the packet and release the lock before accessing it again
172                let result = self.tunn.lock().unwrap().handle_incoming_packet(wg);
173                let ps = match result {
174                    TunnResult::Done => None,
175                    TunnResult::Err(e) => {
176                        return Err(SnapTunNgSocketError::WireguardError(e));
177                    }
178                    TunnResult::WriteToNetwork(p) => {
179                        // Send all queued packets to the network.
180                        let queued_packets = self.tunn.lock().unwrap().get_queued_packets().collect::<Vec<_>>();
181                        let packets = std::iter::once(p).chain(queued_packets.into_iter());
182                        Some(packets)
183
184                    }
185                    TunnResult::WriteToTunnel(mut p) => {
186                        let buf = p.buf_mut().to_owned();
187
188                        // Ignore empty packets, they are keepalive packets.
189                        if !buf.is_empty() {
190                            match self.packet_sender.try_send(buf) {
191                                Ok(()) => {},
192                                Err(TrySendError::Full(_)) => {
193                                    tracing::error!("receive channel is full, dropping packet");
194                                }
195                                Err(_) => {
196                                    // The channel is closed. Stop the task.
197                                    return Err(SnapTunNgSocketError::ReceiveQueueClosed);
198                                }
199                            }
200                        }
201                        None
202                    }
203                };
204                if let Some(packets) = ps {
205                    for p in packets {
206                        if let Err(e) = self.underlay_socket.send_to(to_bytes(p).as_bytes(), self.dataplane_address).await {
207                            return Err(SnapTunNgSocketError::IoError(e));
208                        }
209                    }
210                }
211            }
212        }
213        Ok(())
214    }
215}
216
217/// A SNAP tun ng socket.
218pub struct SnapTunNgSocket {
219    tunn: Arc<Mutex<Tunn>>,
220    underlay_socket: Arc<tokio::net::UdpSocket>,
221    dataplane_address: SocketAddr,
222    local_sockaddr: SocketAddr,
223    receive_queue: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
224    /// Tasks that drives the SNAP tunnel.
225    /// Cancelled when the socket is dropped.
226    driver_task: JoinHandle<()>,
227}
228
229impl Drop for SnapTunNgSocket {
230    fn drop(&mut self) {
231        self.driver_task.abort();
232    }
233}
234
235impl SnapTunNgSocket {
236    /// Creates a new SNAP tunnel and waits for the handshake to complete.
237    ///
238    /// # Arguments
239    ///
240    /// * `static_private` - The client's static private key
241    /// * `peer_public` - The server's static public key (needed for handshake)
242    /// * `rate_limiter` - Rate limiter for the tunnel
243    /// * `underlay_socket` - UDP socket for sending/receiving packets
244    /// * `dataplane_address` - Address of the remote server
245    /// * `receive_queue_capacity` - Capacity of the receive queue
246    pub async fn new(
247        static_private: x25519::StaticSecret,
248        peer_public: x25519::PublicKey,
249        underlay_socket: Arc<tokio::net::UdpSocket>,
250        dataplane_address: SocketAddr,
251        receive_queue_capacity: usize,
252    ) -> Result<Self, SnapTunNgSocketError> {
253        let local_public = x25519::PublicKey::from(&static_private);
254        let tunn = Arc::new(Mutex::new(Tunn::new(
255            static_private,
256            peer_public,
257            None,
258            None,
259            0,
260            Arc::new(RateLimiter::new(&local_public, HANDSHAKE_RATE_LIMIT)),
261            dataplane_address,
262        )));
263        let (packet_sender, packet_receiver) = mpsc::channel(receive_queue_capacity);
264        let mut driver = SnapTunNgClientDriver::new(
265            tunn.clone(),
266            underlay_socket.clone(),
267            dataplane_address,
268            packet_sender,
269        );
270        let socket_addr = driver.initial_connection().await?;
271        Ok(Self {
272            tunn,
273            underlay_socket,
274            dataplane_address,
275            local_sockaddr: socket_addr,
276            // TODO(uniquefine): This should be refactored to use a more efficient receive queue.
277            // https://github.com/Anapaya/scion/issues/27487
278            receive_queue: tokio::sync::Mutex::new(packet_receiver),
279            driver_task: tokio::spawn(driver.main_loop()),
280        })
281    }
282
283    /// Send a packet to the remote server.
284    #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= payload.len()))]
285    pub async fn send(&self, payload: BytesMut) -> io::Result<()> {
286        let packet: Packet<[u8]> = Packet::from_bytes(payload);
287        let encapsulated_packet = self.tunn.lock().unwrap().handle_outgoing_packet(packet);
288        match encapsulated_packet {
289            Some(wg) => {
290                let bytes = match wg {
291                    WgKind::HandshakeInit(p) => p.into_bytes(),
292                    WgKind::HandshakeResp(p) => p.into_bytes(),
293                    WgKind::CookieReply(p) => p.into_bytes(),
294                    WgKind::Data(p) => p.into_bytes(),
295                };
296                tracing::trace!(dataplane_address=?self.dataplane_address, "sending packet");
297                self.underlay_socket
298                    .send_to(bytes.as_bytes(), self.dataplane_address)
299                    .await?;
300                Ok(())
301            }
302            None => {
303                // None is returned if a handshake is ongoing but not yet complete.
304                // In this case the packet is queued and will be sent when the handshake is
305                // complete.
306                tracing::trace!("handshake ongoing, queueing packet");
307                Ok(())
308            }
309        }
310    }
311
312    /// Try to send a packet to the remote server. Returns error of try_send_to.
313    #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= payload.len()))]
314    pub fn try_send(&self, payload: BytesMut) -> io::Result<()> {
315        let packet: Packet<[u8]> = Packet::from_bytes(payload);
316        match self.tunn.lock().unwrap().handle_outgoing_packet(packet) {
317            Some(wg) => {
318                let bytes = match wg {
319                    WgKind::HandshakeInit(p) => p.into_bytes(),
320                    WgKind::HandshakeResp(p) => p.into_bytes(),
321                    WgKind::CookieReply(p) => p.into_bytes(),
322                    WgKind::Data(p) => p.into_bytes(),
323                };
324                tracing::trace!(dataplane_address=?self.dataplane_address, "trying to send packet");
325                self.underlay_socket
326                    .try_send_to(bytes.as_bytes(), self.dataplane_address)?;
327                Ok(())
328            }
329            None => {
330                // None is returned if a handshake is ongoing but not yet complete.
331                // In this case the packet is queued and will be sent when the handshake is
332                // complete.
333                Ok(())
334            }
335        }
336    }
337
338    /// Receive a packet from the remote server.
339    pub async fn recv(&self) -> Result<Bytes, SnapTunNgSocketError> {
340        match self.receive_queue.lock().await.recv().await {
341            Some(packet) => Ok(packet.into()),
342            None => Err(SnapTunNgSocketError::ReceiveQueueClosed),
343        }
344    }
345
346    /// Poll for a packet from the remote server.
347    pub fn poll_recv(
348        &self,
349        cx: &mut std::task::Context<'_>,
350    ) -> std::task::Poll<Result<Bytes, SnapTunNgSocketError>> {
351        let mut receiver = ready!(pin::pin!(self.receive_queue.lock()).poll(cx));
352        match receiver.poll_recv(cx) {
353            std::task::Poll::Ready(Some(packet)) => std::task::Poll::Ready(Ok(packet.into())),
354            std::task::Poll::Ready(None) => {
355                tracing::trace!("receive queue closed, returning error");
356                std::task::Poll::Ready(Err(SnapTunNgSocketError::ReceiveQueueClosed))
357            }
358            std::task::Poll::Pending => std::task::Poll::Pending,
359        }
360    }
361
362    /// Get the local socket address. Assigned by the remote server.
363    pub fn local_addr(&self) -> SocketAddr {
364        self.local_sockaddr
365    }
366
367    /// Check if the socket is writable.
368    pub async fn writable(&self) -> io::Result<()> {
369        self.underlay_socket.writable().await
370    }
371}