1use std::{
22 collections::{HashMap, VecDeque},
23 net::SocketAddr,
24 sync::Arc,
25 time::Instant,
26};
27
28use ana_gotatun::{
29 noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
30 packet::{Packet, WgKind},
31 x25519,
32};
33
/// Server that terminates one [`Tunn`] per remote socket address.
///
/// Every incoming packet is first checked by the shared [`RateLimiter`]. It is then dispatched
/// to the tunnel registered for the sender's address. A new tunnel is only created from a
/// handshake initiation that parses with the server's static key and whose initiator identity
/// is accepted by the authorization policy `T`.
pub struct SnapTunNgServer<T> {
    /// The server's static X25519 secret, cloned into every tunnel the server creates.
    static_private: x25519::StaticSecret,
    /// Public key derived from `static_private`, used when parsing anonymous handshake inits.
    static_public: x25519::PublicKey,
    /// Tunnels keyed by remote socket address. Each tunnel is stored together with the peer's
    /// static public key, which is re-checked against `authz` on each incoming packet.
    /// Entries are dropped by `update_timers` once the tunnel reports itself expired.
    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
    /// Rate limiter used to verify incoming packets, also shared with every created tunnel.
    rate_limiter: Arc<RateLimiter>,
    /// Authorization policy consulted for handshake initiators and for established peers.
    authz: Arc<T>,
}
91
92impl<T: SnapTunAuthorization> SnapTunNgServer<T> {
93 pub fn new(
95 static_private: x25519::StaticSecret,
96 rate_limiter: Arc<RateLimiter>,
97 authz: Arc<T>,
98 ) -> Self {
99 let static_public = x25519::PublicKey::from(&static_private);
100 Self {
101 static_private,
102 static_public,
103 active_tunnels: Default::default(),
104 rate_limiter,
105 authz,
106 }
107 }
108
109 pub fn handle_incoming_packet(
120 &mut self,
121 packet: Packet,
122 from: SocketAddr,
123 send_to_network: &mut VecDeque<WgKind>,
124 ) -> TunnResult {
125 let now = Instant::now();
126
127 let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
128 Ok(p) => p,
129 Err(TunnResult::WriteToNetwork(c)) => {
130 send_to_network.push_back(c);
131 return TunnResult::Done;
132 }
133 Err(e) => return e,
134 };
135
136 use std::collections::hash_map::Entry;
137
138 use ana_gotatun::noise::errors::WireGuardError;
139 match (self.active_tunnels.entry(from), parsed_packet) {
140 (Entry::Occupied(mut occupied_entry), p) => {
141 let (peer_static, tunn) = occupied_entry.get_mut();
142 if !self.authz.is_authorized(now, peer_static.as_bytes()) {
149 return TunnResult::Err(WireGuardError::UnexpectedPacket);
150 }
151 Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
152 }
153 (e, WgKind::HandshakeInit(wg_init)) => {
154 let peer =
155 match parse_handshake_anon(&self.static_private, &self.static_public, &wg_init)
156 {
157 Ok(v) => v,
158 Err(e) => return TunnResult::from(e),
159 };
160
161 if !self.authz.is_authorized(now, &peer.peer_static_public) {
167 return TunnResult::Err(WireGuardError::UnexpectedPacket);
168 }
169 let peer_static = x25519::PublicKey::from(peer.peer_static_public);
170 let mut tunn = Tunn::new(
171 self.static_private.clone(),
172 peer_static,
173 None,
174 None,
175 0,
176 self.rate_limiter.clone(),
177 from,
178 );
179 let res = Self::handle_incoming_and_drain_queue(
180 send_to_network,
181 WgKind::HandshakeInit(wg_init),
182 &mut tunn,
183 );
184 e.insert_entry((peer_static, tunn));
185 res
186 }
187 (_, _p) => TunnResult::Err(WireGuardError::InvalidPacket),
188 }
189 }
190
191 pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
194 let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
195 tracing::error!(to=?to, "No tunnel for outgoing packet found.");
196 return None;
197 };
198 tunn.handle_outgoing_packet(packet.into_bytes())
199 }
200
201 pub fn update_timers(&mut self) -> Vec<(SocketAddr, WgKind)> {
207 let mut res = vec![];
208 self.active_tunnels.retain(|k, (_, tunn)| {
209 match tunn.update_timers() {
210 Ok(Some(wg)) => res.push((*k, wg)),
211 Ok(None) => {},
212 Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
213 }
214
215 !tunn.is_expired()
216 });
217 res
218 }
219
220 fn handle_incoming_and_drain_queue(
221 q: &mut VecDeque<WgKind>,
222 p: WgKind,
223 tunn: &mut Tunn,
224 ) -> TunnResult {
225 let r = match tunn.handle_incoming_packet(p) {
226 TunnResult::WriteToNetwork(p) => {
227 q.push_back(p);
228 TunnResult::Done
229 }
230 TunnResult::WriteToTunnel(p) if p.is_empty() => TunnResult::Done,
232 r => r,
233 };
234 for p in tunn.get_queued_packets() {
235 q.push_back(p);
236 }
237 r
238 }
239}
240
/// Policy deciding which peer identities may talk to a [`SnapTunNgServer`].
///
/// An identity is the peer's 32-byte static public key. The server asks the policy in two
/// cases:
/// - when a new handshake initiation arrives from an unknown address;
/// - for every packet arriving on an established tunnel.
///
/// Implementations must be `Send + Sync` so they can be shared behind an `Arc`.
pub trait SnapTunAuthorization: Sync + Send {
    /// Returns `true` if `identity` is authorized at time `now`.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
246
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, net::SocketAddr, sync::Arc};

    use ana_gotatun::{
        noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
        packet::{IpNextProtocol, Packet, WgKind},
        x25519,
    };
    use zerocopy::IntoBytes;

    use crate::{
        scion_packet::{Scion, ScionHeader},
        server::{SnapTunAuthorization, SnapTunNgServer},
    };

    type ResultT = Result<(), Box<dyn std::error::Error>>;

    /// Authorization policy that accepts every identity.
    struct TrivialAuthz;

    impl SnapTunAuthorization for TrivialAuthz {
        fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
            true
        }
    }

    /// Builds a client-side tunnel pointed at the server.
    fn client_tunnel(
        secret: x25519::StaticSecret,
        server_public: x25519::PublicKey,
        rate_limiter: Arc<RateLimiter>,
        server_addr: SocketAddr,
    ) -> Tunn {
        Tunn::new(secret, server_public, None, None, 0, rate_limiter, server_addr)
    }

    /// Runs one full handshake between `client` (at `client_addr`) and `server`.
    ///
    /// Sending `first_packet` makes the client emit a handshake init. The init is fed to the
    /// server, and the server's first queued reply is delivered back to the client.
    fn complete_handshake(
        server: &mut SnapTunNgServer<TrivialAuthz>,
        client: &mut Tunn,
        client_addr: SocketAddr,
        first_packet: &Packet,
        outbox: &mut VecDeque<WgKind>,
    ) {
        let Some(WgKind::HandshakeInit(hs_init)) =
            client.handle_outgoing_packet(Packet::copy_from(first_packet))
        else {
            panic!("expected handshake init")
        };

        server.handle_incoming_packet(Packet::copy_from(hs_init.as_bytes()), client_addr, outbox);

        dispatch_one(client, outbox);
        assert_eq!(client.get_initiator_remote_sockaddr(), Some(client_addr));
    }

    #[test]
    fn connect_with_multiple_clients() -> ResultT {
        let client0_addr: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        let client0_secret = x25519::StaticSecret::from([0u8; 32]);
        let client1_addr: SocketAddr = "192.168.1.2:4321".parse().unwrap();
        let client1_secret = x25519::StaticSecret::from([1u8; 32]);
        let server_addr: SocketAddr = "10.0.0.1:5001".parse().unwrap();
        let server_secret = x25519::StaticSecret::from([2u8; 32]);
        let server_public = x25519::PublicKey::from(&server_secret);

        let rate_limiter = Arc::new(RateLimiter::new(&server_public, 100));
        let mut server =
            SnapTunNgServer::new(server_secret, rate_limiter.clone(), Arc::new(TrivialAuthz));

        let mut outbox = VecDeque::<WgKind>::new();

        let payload0 = *b"TEST0";
        let payload1 = *b"TEST1";
        let scion0 = Scion {
            header: ScionHeader::new(
                0,
                0xAA,
                0xABCDE,
                payload0.len() as _,
                IpNextProtocol::Udp,
                7,
                0x0123_4567_89AB_CDEF,
                0xFEDC_BA98_7654_3210,
            ),
            payload: payload0,
        };
        let scion1 = Scion {
            header: scion0.header,
            payload: payload1,
        };
        let test_packet0 = Packet::copy_from(scion0.as_bytes());
        let test_packet1 = Packet::copy_from(scion1.as_bytes());

        let mut client0 =
            client_tunnel(client0_secret, server_public, rate_limiter.clone(), server_addr);
        let mut client1 = client_tunnel(client1_secret, server_public, rate_limiter, server_addr);

        complete_handshake(&mut server, &mut client0, client0_addr, &test_packet0, &mut outbox);
        complete_handshake(&mut server, &mut client1, client1_addr, &test_packet1, &mut outbox);

        // The data packet that triggered the handshake is queued on client0; deliver it.
        let Some(WgKind::Data(data0)) = client0.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };
        let TunnResult::WriteToTunnel(received0) = server.handle_incoming_packet(
            Packet::copy_from(data0.as_bytes()),
            client0_addr,
            &mut outbox,
        ) else {
            panic!("Expected packet to be processed")
        };
        assert_eq!(received0.as_bytes(), test_packet0.as_bytes());

        // Same for client1.
        let Some(WgKind::Data(data1)) = client1.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };
        let TunnResult::WriteToTunnel(received1) = server.handle_incoming_packet(
            Packet::copy_from(data1.as_bytes()),
            client1_addr,
            &mut outbox,
        ) else {
            panic!("expected packet to be received on server side");
        };
        assert_eq!(received1.as_bytes(), test_packet1.as_bytes());

        // Forward what client0 sent to client1 through the server's tunnel for client1.
        let Some(forwarded @ WgKind::Data(_)) =
            server.handle_outgoing_packet(received0, client1_addr)
        else {
            panic!("expected packet to be sent back to client")
        };
        let TunnResult::WriteToTunnel(delivered) = client1.handle_incoming_packet(forwarded) else {
            panic!("expected packet to be sent back to client")
        };
        assert_eq!(delivered.as_bytes(), test_packet0.as_bytes());

        Ok(())
    }

    /// Pops the oldest queued packet (if any) and feeds it to `tunn`.
    fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
        match packets.pop_front() {
            Some(packet) => tunn.handle_incoming_packet(packet),
            None => TunnResult::Done,
        }
    }
}