1use std::{
22 collections::{HashMap, VecDeque},
23 net::SocketAddr,
24 sync::Arc,
25 time::Instant,
26};
27
28use ana_gotatun::{
29 noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
30 packet::{Packet, WgKind},
31 x25519,
32};
33
/// Server side of a WireGuard-based tunnel that serves many peers from a single
/// static key pair.
///
/// Peers are not configured ahead of time; tunnels are created on demand and
/// admission is decided by the authorization policy `T`.
pub struct SnapTunNgServer<T> {
    /// The server's static X25519 private key.
    static_private: x25519::StaticSecret,
    /// Public key derived from `static_private`.
    static_public: x25519::PublicKey,
    /// One tunnel per remote socket address, stored together with the static
    /// public key of the peer that established it.
    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
    /// Rate limiter applied to incoming packets.
    rate_limiter: Arc<RateLimiter>,
    /// Authorization policy deciding which peer static keys may use the server.
    authz: Arc<T>,
}
91
92impl<T: SnapTunAuthorization> SnapTunNgServer<T> {
93 pub fn new(
95 static_private: x25519::StaticSecret,
96 rate_limiter: Arc<RateLimiter>,
97 authz: Arc<T>,
98 ) -> Self {
99 let static_public = x25519::PublicKey::from(&static_private);
100 Self {
101 static_private,
102 static_public,
103 active_tunnels: Default::default(),
104 rate_limiter,
105 authz,
106 }
107 }
108
109 pub fn handle_incoming_packet(
120 &mut self,
121 packet: Packet,
122 from: SocketAddr,
123 send_to_network: &mut VecDeque<WgKind>,
124 ) -> TunnResult {
125 let now = Instant::now();
126
127 let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
128 Ok(p) => p,
129 Err(TunnResult::WriteToNetwork(c)) => {
130 send_to_network.push_back(c);
131 return TunnResult::Done;
132 }
133 Err(e) => return e,
134 };
135
136 use std::collections::hash_map::Entry;
137
138 use ana_gotatun::noise::errors::WireGuardError;
139 match (self.active_tunnels.entry(from), parsed_packet) {
140 (Entry::Occupied(mut occupied_entry), p) => {
141 let (peer_static, tunn) = occupied_entry.get_mut();
142 if !self.authz.is_authorized(now, peer_static.as_bytes()) {
149 return TunnResult::Err(WireGuardError::UnexpectedPacket);
150 }
151 Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
152 }
153 (e, WgKind::HandshakeInit(wg_init)) => {
154 let peer =
155 match parse_handshake_anon(&self.static_private, &self.static_public, &wg_init)
156 {
157 Ok(v) => v,
158 Err(e) => return TunnResult::from(e),
159 };
160
161 if !self.authz.is_authorized(now, &peer.peer_static_public) {
167 return TunnResult::Err(WireGuardError::UnexpectedPacket);
168 }
169 let peer_static = x25519::PublicKey::from(peer.peer_static_public);
170 let mut tunn = Tunn::new(
171 self.static_private.clone(),
172 peer_static,
173 None,
174 None,
175 0,
176 self.rate_limiter.clone(),
177 from,
178 );
179 let res = Self::handle_incoming_and_drain_queue(
180 send_to_network,
181 WgKind::HandshakeInit(wg_init),
182 &mut tunn,
183 );
184 e.insert_entry((peer_static, tunn));
185 res
186 }
187 (_, _p) => TunnResult::Err(WireGuardError::InvalidPacket),
188 }
189 }
190
191 pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
194 let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
195 tracing::error!(to=?to, "No tunnel for outgoing packet found.");
196 return None;
197 };
198 tunn.handle_outgoing_packet(packet.into_bytes())
199 }
200
201 pub fn update_timers(&mut self) -> Vec<WgKind> {
207 let mut res = vec![];
208 self.active_tunnels.retain(|k, (_, tunn)| {
209 match tunn.update_timers() {
210 Ok(Some(wg)) => res.push(wg),
211 Ok(None) => {},
212 Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
213 }
214
215 !tunn.is_expired()
216 });
217 res
218 }
219
220 fn handle_incoming_and_drain_queue(
221 q: &mut VecDeque<WgKind>,
222 p: WgKind,
223 tunn: &mut Tunn,
224 ) -> TunnResult {
225 let r = match tunn.handle_incoming_packet(p) {
226 TunnResult::WriteToNetwork(p) => {
227 q.push_back(p);
228 TunnResult::Done
229 }
230 r => r,
231 };
232 for p in tunn.get_queued_packets() {
233 q.push_back(p);
234 }
235 r
236 }
237}
238
/// Authorization policy consulted by [`SnapTunNgServer`] to decide whether a
/// peer may use the server.
///
/// The server asks this both before creating a tunnel for a new handshake
/// initiation and for every packet arriving on an existing tunnel.
pub trait SnapTunAuthorization {
    /// Returns `true` if the peer identified by its static X25519 public key
    /// `identity` is allowed to use the server at time `now`.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
244
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, net::SocketAddr, sync::Arc};

    use ana_gotatun::{
        noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
        packet::{IpNextProtocol, Packet, WgKind},
        x25519,
    };
    use zerocopy::IntoBytes;

    use crate::{
        scion_packet::{Scion, ScionHeader},
        server::{SnapTunAuthorization, SnapTunNgServer},
    };

    type ResultT = Result<(), Box<dyn std::error::Error>>;

    /// Authorization policy that accepts every peer, so the test exercises only
    /// the tunnel plumbing.
    struct TrivialAuthz;

    impl SnapTunAuthorization for TrivialAuthz {
        fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
            true
        }
    }

    /// Two clients with distinct keys and addresses handshake with one server.
    ///
    /// The server must keep a separate tunnel per client address and decrypt
    /// each client's data. It must also route an outgoing packet through the
    /// tunnel of the client it is addressed to.
    #[test]
    fn connect_with_multiple_clients() -> ResultT {
        let sockaddr_client0: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        let static_client0 = x25519::StaticSecret::from([0u8; 32]);
        let sockaddr_client1: SocketAddr = "192.168.1.2:4321".parse().unwrap();
        let static_client1 = x25519::StaticSecret::from([1u8; 32]);
        let sockaddr_server: SocketAddr = "10.0.0.1:5001".parse().unwrap();
        let static_server = x25519::StaticSecret::from([2u8; 32]);
        let static_server_public = x25519::PublicKey::from(&static_server);

        let rate_limiter = Arc::new(RateLimiter::new(&static_server_public, 100));
        let mut snaptun_server =
            SnapTunNgServer::new(static_server, rate_limiter.clone(), Arc::new(TrivialAuthz));

        // Packets the server wants to put on the wire; the test delivers them
        // to the clients by hand.
        let mut send_to_network = VecDeque::<WgKind>::new();

        // Two SCION packets sharing a header and differing only in payload, so
        // each client's traffic can be told apart after decryption.
        let test_payload0 = [b'T', b'E', b'S', b'T', b'0'];
        let test_payload1 = [b'T', b'E', b'S', b'T', b'1'];
        let test_packet0 = Scion {
            header: ScionHeader::new(
                0, 0xAA, 0xABCDE, test_payload0.len() as _, IpNextProtocol::Udp,
                7, 0x0123_4567_89AB_CDEF,
                0xFEDC_BA98_7654_3210,
            ),
            payload: test_payload0,
        };
        let test_packet1 = Scion {
            header: test_packet0.header,
            payload: test_payload1,
        };
        let test_packet0 = Packet::copy_from(test_packet0.as_bytes());
        let test_packet1 = Packet::copy_from(test_packet1.as_bytes());

        let mut tunn_client0 = Tunn::new(
            static_client0,
            static_server_public,
            None,
            None,
            0,
            rate_limiter.clone(),
            sockaddr_server,
        );

        let mut tunn_client1 = Tunn::new(
            static_client1,
            static_server_public,
            None,
            None,
            0,
            rate_limiter,
            sockaddr_server,
        );

        // Client 0 has no session yet, so sending data yields a handshake
        // initiation; the data packet stays queued on the client.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client0.handle_outgoing_packet(Packet::copy_from(&test_packet0))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        );

        // Deliver the server's first queued reply to client 0. Afterwards the
        // client must report the address the server saw it from.
        dispatch_one(&mut tunn_client0, &mut send_to_network);
        assert_eq!(
            tunn_client0.get_initiator_remote_sockaddr(),
            Some(sockaddr_client0)
        );

        // Same handshake for client 1 from a different address.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client1.handle_outgoing_packet(Packet::copy_from(&test_packet1))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        );

        dispatch_one(&mut tunn_client1, &mut send_to_network);
        assert_eq!(
            tunn_client1.get_initiator_remote_sockaddr(),
            Some(sockaddr_client1)
        );

        // With the handshake done, client 0's queued data packet is released
        // and must decrypt on the server to the original payload.
        let Some(WgKind::Data(p)) = tunn_client0.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        ) else {
            panic!("Expected packet to be processed")
        };
        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        // Same for client 1; the server must pick client 1's tunnel by address.
        let Some(WgKind::Data(p1)) = tunn_client1.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p1) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p1.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        ) else {
            panic!("expected packet to be received on server side");
        };
        assert_eq!(p1.as_bytes(), test_packet1.as_bytes());

        // Route client 0's decrypted packet out to client 1: it must be
        // encrypted for client 1's tunnel and decrypt there unchanged.
        let res = snaptun_server.handle_outgoing_packet(p, sockaddr_client1);
        let Some(p @ WgKind::Data(_)) = res else {
            panic!("expected packet to be sent back to client")
        };

        let TunnResult::WriteToTunnel(p) = tunn_client1.handle_incoming_packet(p) else {
            panic!("expected packet to be sent back to client")
        };

        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        Ok(())
    }

    /// Delivers the oldest packet in `packets` (if any) to `tunn` and returns
    /// the tunnel's result, or `TunnResult::Done` when the queue is empty.
    fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
        if let Some(p) = packets.pop_front() {
            return tunn.handle_incoming_packet(p);
        }
        TunnResult::Done
    }
}