Skip to main content

snap_tun/
server.rs

1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
//! The server of the SNAPtun protocol.
//!
//! As the underlying protocol is symmetric (both peers can act as
//! initiators or responders that establish a session), technically, there is
//! no server. The term "server" here just refers to an endpoint that manages
//! multiple peers.
20
21use std::{
22    collections::{HashMap, VecDeque},
23    net::SocketAddr,
24    sync::Arc,
25    time::Instant,
26};
27
28use ana_gotatun::{
29    noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
30    packet::{Packet, WgKind},
31    x25519,
32};
33
34/// The [SnapTunNgServer] manages one [Tunn] per remote socket address.
35///
36/// The main structural difference between WireGuard (R) and snaptun-ng is that
37/// there is a one-to-one relation between a remote socket address (of the
38/// initiator) and a tunnel. The [SnapTunNgServer] manages that relation.
39///
40/// ## Scaling
41///
42/// The main methods [SnapTunNgServer::handle_incoming_packet],
43/// [SnapTunNgServer::handle_outgoing_packet], and
44/// [SnapTunNgServer::update_timers] all require an exclusive reference to the
45/// internal state. The reason is that processing both, incoming and outgoing
46/// packets requires access to the session state.
47///
48/// One simple way to achieve load distribution across different cores/threads
49/// is to shard over multiple [SnapTunNgServer]-instances based on a hash of the
50/// remote socket address.
51///
52/// ## Future improvements
53///
54/// * Separate incoming and outgoing code paths and optimistically lock the session state.
55///
56/// ## How to use
57///
58/// The [SnapTunNgServer] is i/o-free; i.e. it only manages state. The following
59/// is a pseudo-code like description of the simplest i/o-layer integration:
60///
61/// ```text
62/// let mut server = SnapTunNgServer::new(/*...*/);
63/// let mut send_to_network = VecDequeue::new();
64/// let mut current_sockaddr = ;
65/// loop {
66///   switch {
67///     (network_packet, sockaddr) = network_socket => {
68///       server.handle_incoming_packet(/*...*/);
69///       /* dispatch packets to tunnel if necessary */
70///     }
71///     tunnel_packet = tunnel_socket => {
72///       server.handle_outgoing_packet(/*...*/);
73///     }
74///     timer = tick(250ms) => {
75///       server.update_timers();
76///     }
77///   }
78///   // dispatch packets to network
79///   for p in send_to_network {
80///     network_socket.send(sockaddr, p);
81///   }
82/// }
83/// ```
84pub struct SnapTunNgServer<T> {
85    static_private: x25519::StaticSecret,
86    static_public: x25519::PublicKey,
87    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
88    rate_limiter: Arc<RateLimiter>,
89    authz: Arc<T>,
90}
91
92impl<T: SnapTunAuthorization> SnapTunNgServer<T> {
93    /// Creates a new [SnapTunNgServer] instance.
94    pub fn new(
95        static_private: x25519::StaticSecret,
96        rate_limiter: Arc<RateLimiter>,
97        authz: Arc<T>,
98    ) -> Self {
99        let static_public = x25519::PublicKey::from(&static_private);
100        Self {
101            static_private,
102            static_public,
103            active_tunnels: Default::default(),
104            rate_limiter,
105            authz,
106        }
107    }
108
109    /// Handle incoming packet for a tunnel assocated with remote socket address
110    /// `from`.
111    ///
112    /// This method _never_ returns [TunnResult::WriteToNetwork]. Instead,
113    /// it codifies the expected protocol behavior which is that, upon receiving
114    /// a packet from the remote, the queue of outgoing packets is completely
115    /// drained.
116    ///
117    /// If the rate limiter signals that the server is under load, at most one
118    /// packet is added to the queue.
119    pub fn handle_incoming_packet(
120        &mut self,
121        packet: Packet,
122        from: SocketAddr,
123        send_to_network: &mut VecDeque<WgKind>,
124    ) -> TunnResult {
125        let now = Instant::now();
126
127        let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
128            Ok(p) => p,
129            Err(TunnResult::WriteToNetwork(c)) => {
130                send_to_network.push_back(c);
131                return TunnResult::Done;
132            }
133            Err(e) => return e,
134        };
135
136        use std::collections::hash_map::Entry;
137
138        use ana_gotatun::noise::errors::WireGuardError;
139        match (self.active_tunnels.entry(from), parsed_packet) {
140            (Entry::Occupied(mut occupied_entry), p) => {
141                let (peer_static, tunn) = occupied_entry.get_mut();
142                // TODO(dsd): At the moment, this keeps a tunnel alive even
143                // though the processing might fail, but gives the authorization
144                // layer a chance to block incomding packets in case an identity
145                // is unauthorized.
146                //
147                // Will fix later.
148                if !self.authz.is_authorized(now, peer_static.as_bytes()) {
149                    return TunnResult::Err(WireGuardError::UnexpectedPacket);
150                }
151                Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
152            }
153            (e, WgKind::HandshakeInit(wg_init)) => {
154                let peer =
155                    match parse_handshake_anon(&self.static_private, &self.static_public, &wg_init)
156                    {
157                        Ok(v) => v,
158                        Err(e) => return TunnResult::from(e),
159                    };
160
161                // TODO(dsd): if the socket is occupied, and tunnel.identity !=
162                // peer.identity, then send a cookie and abort
163
164                // TODO(dsd): extend ana-gotatun::Tunn such that peer static
165                // identity can be retrieved
166                if !self.authz.is_authorized(now, &peer.peer_static_public) {
167                    return TunnResult::Err(WireGuardError::UnexpectedPacket);
168                }
169                let peer_static = x25519::PublicKey::from(peer.peer_static_public);
170                let mut tunn = Tunn::new(
171                    self.static_private.clone(),
172                    peer_static,
173                    None,
174                    None,
175                    0,
176                    self.rate_limiter.clone(),
177                    from,
178                );
179                let res = Self::handle_incoming_and_drain_queue(
180                    send_to_network,
181                    WgKind::HandshakeInit(wg_init),
182                    &mut tunn,
183                );
184                e.insert_entry((peer_static, tunn));
185                res
186            }
187            (_, _p) => TunnResult::Err(WireGuardError::InvalidPacket),
188        }
189    }
190
191    /// Handles an outgoing packet sent through the tunnel identified by the
192    /// remote socket address `to`.
193    pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
194        let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
195            tracing::error!(to=?to, "No tunnel for outgoing packet found.");
196            return None;
197        };
198        tunn.handle_outgoing_packet(packet.into_bytes())
199    }
200
201    /// Update timers of all tunnels. Generate corresponding keepalive or
202    /// session handshake initializations.
203    ///
204    /// As a result of this call, all expired tunnels are removed. Note that
205    /// this is not the same as unauthorized tunnels.
206    pub fn update_timers(&mut self) -> Vec<WgKind> {
207        let mut res = vec![];
208        self.active_tunnels.retain(|k, (_, tunn)| {
209            match tunn.update_timers() {
210                Ok(Some(wg)) => res.push(wg),
211                Ok(None) => {},
212                Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
213            }
214
215            !tunn.is_expired()
216        });
217        res
218    }
219
220    fn handle_incoming_and_drain_queue(
221        q: &mut VecDeque<WgKind>,
222        p: WgKind,
223        tunn: &mut Tunn,
224    ) -> TunnResult {
225        let r = match tunn.handle_incoming_packet(p) {
226            TunnResult::WriteToNetwork(p) => {
227                q.push_back(p);
228                TunnResult::Done
229            }
230            r => r,
231        };
232        for p in tunn.get_queued_packets() {
233            q.push_back(p);
234        }
235        r
236    }
237}
238
/// Authorization layer for the snaptun server.
///
/// The server queries this layer with the 32-byte static public key of a peer
/// to decide whether that peer's traffic may be processed.
pub trait SnapTunAuthorization {
    /// Decides whether the peer with static public key `identity` may send
    /// traffic to the server at time `now`.
    ///
    /// Returns `true` iff the peer is authorized.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
244
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, net::SocketAddr, sync::Arc};

    use ana_gotatun::{
        noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
        packet::{IpNextProtocol, Packet, WgKind},
        x25519,
    };
    use zerocopy::IntoBytes;

    use crate::{
        scion_packet::{Scion, ScionHeader},
        server::{SnapTunAuthorization, SnapTunNgServer},
    };

    type ResultT = Result<(), Box<dyn std::error::Error>>;

    /// Authorization layer that admits every identity.
    struct TrivialAuthz;

    impl SnapTunAuthorization for TrivialAuthz {
        fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
            true
        }
    }

    /// Builds the client side of a tunnel towards the server at `server_addr`.
    fn client_tunnel(
        secret: x25519::StaticSecret,
        server_public: x25519::PublicKey,
        rate_limiter: Arc<RateLimiter>,
        server_addr: SocketAddr,
    ) -> Tunn {
        Tunn::new(secret, server_public, None, None, 0, rate_limiter, server_addr)
    }

    /// Runs a complete handshake between `client` (seen by the server at
    /// `client_addr`) and `server`, triggered by the client wanting to send
    /// `payload`. The payload itself stays queued on the client.
    fn handshake(
        server: &mut SnapTunNgServer<TrivialAuthz>,
        client: &mut Tunn,
        client_addr: SocketAddr,
        payload: &Packet,
        to_network: &mut VecDeque<WgKind>,
    ) {
        let init = match client.handle_outgoing_packet(Packet::copy_from(payload)) {
            Some(WgKind::HandshakeInit(init)) => init,
            _ => panic!("expected handshake init"),
        };

        let _ = server.handle_incoming_packet(
            Packet::copy_from(init.as_bytes()),
            client_addr,
            to_network,
        );

        // The server's handshake response is the first packet on the wire.
        let _ = dispatch_one(client, to_network);
        assert_eq!(client.get_initiator_remote_sockaddr(), Some(client_addr));
    }

    /// Hands the first data packet queued on `client` to `server` and returns
    /// what the server forwards to the tunnel side; panics with `failure` if
    /// the server does not produce a tunnel packet.
    fn deliver_queued(
        server: &mut SnapTunNgServer<TrivialAuthz>,
        client: &mut Tunn,
        client_addr: SocketAddr,
        to_network: &mut VecDeque<WgKind>,
        failure: &str,
    ) -> Packet {
        let data = match client.get_queued_packets().next() {
            Some(WgKind::Data(data)) => data,
            _ => panic!("expected packet to be queued"),
        };

        match server.handle_incoming_packet(
            Packet::copy_from(data.as_bytes()),
            client_addr,
            to_network,
        ) {
            TunnResult::WriteToTunnel(decrypted) => decrypted,
            _ => panic!("{failure}"),
        }
    }

    #[test]
    fn connect_with_multiple_clients() -> ResultT {
        let client0_addr: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        let client1_addr: SocketAddr = "192.168.1.2:4321".parse().unwrap();
        let server_addr: SocketAddr = "10.0.0.1:5001".parse().unwrap();

        let client0_secret = x25519::StaticSecret::from([0u8; 32]);
        let client1_secret = x25519::StaticSecret::from([1u8; 32]);
        let server_secret = x25519::StaticSecret::from([2u8; 32]);
        let server_public = x25519::PublicKey::from(&server_secret);

        let rate_limiter = Arc::new(RateLimiter::new(&server_public, 100));
        let mut server =
            SnapTunNgServer::new(server_secret, rate_limiter.clone(), Arc::new(TrivialAuthz));
        let mut to_network = VecDeque::<WgKind>::new();

        // Both test packets share one SCION header and differ only in payload.
        let payload0 = *b"TEST0";
        let payload1 = *b"TEST1";
        let header = ScionHeader::new(
            0,                   // version
            0xAA,                // traffic_class
            0xABCDE,             // flow_id (20 bits)
            payload0.len() as _, // payload_len
            IpNextProtocol::Udp,
            7, // hop_count
            0x0123_4567_89AB_CDEF,
            0xFEDC_BA98_7654_3210,
        );
        let packet0 = Packet::copy_from(
            Scion {
                header,
                payload: payload0,
            }
            .as_bytes(),
        );
        let packet1 = Packet::copy_from(
            Scion {
                header,
                payload: payload1,
            }
            .as_bytes(),
        );

        let mut client0 =
            client_tunnel(client0_secret, server_public, rate_limiter.clone(), server_addr);
        let mut client1 = client_tunnel(client1_secret, server_public, rate_limiter, server_addr);

        handshake(&mut server, &mut client0, client0_addr, &packet0, &mut to_network);
        handshake(&mut server, &mut client1, client1_addr, &packet1, &mut to_network);

        // client0 -> server
        let from_client0 = deliver_queued(
            &mut server,
            &mut client0,
            client0_addr,
            &mut to_network,
            "Expected packet to be processed",
        );
        assert_eq!(from_client0.as_bytes(), packet0.as_bytes());

        // client1 -> server: the server only starts using client1's session
        // after it has received data from client1, so client1 talks first.
        let from_client1 = deliver_queued(
            &mut server,
            &mut client1,
            client1_addr,
            &mut to_network,
            "expected packet to be received on server side",
        );
        assert_eq!(from_client1.as_bytes(), packet1.as_bytes());

        // server -> client1: forward client0's payload to client1.
        let Some(encrypted @ WgKind::Data(_)) =
            server.handle_outgoing_packet(from_client0, client1_addr)
        else {
            panic!("expected packet to be sent back to client")
        };
        let TunnResult::WriteToTunnel(received) = client1.handle_incoming_packet(encrypted) else {
            panic!("expected packet to be sent back to client")
        };
        assert_eq!(received.as_bytes(), packet0.as_bytes());

        Ok(())
    }

    /// Delivers the oldest packet in `packets` (if any) to `tunn`.
    fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
        if let Some(p) = packets.pop_front() {
            return tunn.handle_incoming_packet(p);
        }
        TunnResult::Done
    }
}