snap_tun/
server.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! # The snaptun server.
15//!
16//! This module contains the snaptun-[Server]. The QUIC-connection handling is left to the caller.
//! That is, after accepting a QUIC-connection, [Server::accept_with_timeout] will establish a
//! snaptun with a client, provided the peer behaves as expected and sends the required control
19//! requests.
20//!
21//! The [Server::accept_with_timeout] method produces three different objects: [Receiver], [Sender],
22//! and [Control]. The first is used to receive packets from the peer, the second to send packets to
23//! the peer. The third is used to _drive_ the control state of the connection.
24//!
25//! [Server::accept_with_timeout] expects the client to first send a session renew request followed
26//! by an address assignment request. If the client doesn't do so within [ACCEPT_TIMEOUT], a
27//! [AcceptError::Timeout] error is returned and the connection closed. The rationale behind this is
28//! that bogus client connections should be closed as quickly as possible.
29//!
30//! ## Synopsis
31//!
32//! ```no_exec
33//! loop {
34//!   let quic_conn = endpoint.accept().await?;
35//!
//!   let (sender, receiver, control) = snaptun_server.accept_with_timeout(quic_conn).await?;
37//!   let _ = tokio::spawn(control); // drive control state
38//!
39//!   let _ = tokio::spawn(async move {
//!     while let Ok(p) = receiver.receive().await {
41//!       // process incoming packet
42//!     }
43//!   });
44//!
45//!   // send an outgoing packet
46//!   sender.send(p);
47//! }
48//! ```
49
50use std::{
51    net::SocketAddr,
52    pin::Pin,
53    sync::{
54        Arc, RwLock,
55        atomic::{AtomicBool, AtomicUsize, Ordering},
56    },
57    time::SystemTime,
58    vec,
59};
60
61use bytes::Bytes;
62use chrono::{DateTime, Utc};
63use http::StatusCode;
64use ipnet::IpNet;
65use prost::Message;
66use quinn::{RecvStream, SendStream, VarInt};
67use scion_proto::address::{EndhostAddr, IsdAsn};
68use scion_sdk_token_validator::validator::{Token, TokenValidator, TokenValidatorError};
69use serde::Deserialize;
70use tokio::sync::watch;
71use tracing::{debug, error, info, instrument, warn};
72
73use crate::{
74    AUTH_HEADER, AddressAllocation, AddressAllocator, IPV4_WILDCARD, IPV6_WILDCARD,
75    PATH_ADDR_ASSIGNMENT, PATH_SESSION_RENEWAL,
76    metrics::{Metrics, ReceiverMetrics, SenderMetrics},
77    requests::{
78        AddrError, AddressAssignRequest, AddressAssignResponse, SessionRenewalResponse,
79        unix_epoch_from_system_time,
80    },
81};
82
/// Error codes used when a SNAP tunnel connection is closed by the server.
///
/// Each variant carries an explicit numeric discriminant; that number is the
/// value handed to the QUIC layer as the application close code.
#[derive(Clone, Copy)]
pub enum SnaptunConnErrors {
    /// The peer sent a control request that was invalid.
    InvalidRequest = 1,
    /// The peer did not complete a required step in time.
    Timeout = 2,
    /// The peer could not be authenticated.
    Unauthenticated = 3,
    /// The session of the peer has expired.
    SessionExpired = 4,
    /// The server hit an internal error.
    InternalError = 5,
}
97
98impl From<SnaptunConnErrors> for quinn::VarInt {
99    fn from(e: SnaptunConnErrors) -> Self {
100        VarInt::from_u32(e as u32)
101    }
102}
103
104/// Deserializable SNAP token trait.
105pub trait SnapTunToken: for<'de> Deserialize<'de> + Token + Clone {}
106impl<T> SnapTunToken for T where T: for<'de> Deserialize<'de> + Token + Clone {}
107
/// Time budget for the tunnel handshake: within `ACCEPT_TIMEOUT` a client MUST send a session
/// renew request followed by an address assignment request.
pub const ACCEPT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);

/// Upper bound on how long sending a control response to the client may take.
pub const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
/// Size in bytes of the buffer used to read a control request; requests issued by the client
/// MUST NOT exceed it.
const CTRL_REQUEST_BUF_SIZE: usize = 4096;
118
119/// The snaptun server accepts connections from clients and provides them with an address
120/// assignment.
121pub struct Server<T> {
122    metrics: Metrics,
123    validator: Arc<dyn TokenValidator<T>>,
124    allocator: Arc<dyn AddressAllocator<T>>,
125}
126
127/// Accept errors.
128#[derive(Debug, thiserror::Error)]
129pub enum AcceptError {
130    /// Timeout reached.
131    #[error("timeout reached.")]
132    Timeout,
133    /// QUIC connection error.
134    #[error("quinn connection error: {0}")]
135    ConnectionError(#[from] quinn::ConnectionError),
136    /// Parse control request error.
137    #[error("parse control request error: {0}")]
138    ParseControlRequestError(#[from] ParseControlRequestError),
139    /// Send control response error.
140    #[error("send control response error: {0}")]
141    SendControlResponseError(#[from] SendControlResponseError),
142    /// Unexpected control request.
143    #[error("unexpected control request")]
144    UnexpectedControlRequest,
145}
146
147impl<T: SnapTunToken> Server<T> {
148    /// Create a new server that can accept QUIC connections and turn them into
149    /// snap tunnels.
150    pub fn new(
151        allocator: Arc<dyn AddressAllocator<T>>,
152        validator: Arc<dyn TokenValidator<T>>,
153        metrics: Metrics,
154    ) -> Self {
155        Self {
156            allocator,
157            validator,
158            metrics,
159        }
160    }
161
162    /// Accept a connection and establish a tunnel.
163    ///
164    /// ## Tunnel initialization
165    ///
166    /// The client is expected to first send a session renew request, followed by an address
167    /// assignment request. The connection is closed with a [SnaptunConnErrors::Timeout] if the
168    /// client does not send the requests within [ACCEPT_TIMEOUT].
169    pub async fn accept_with_timeout(
170        &self,
171        conn: quinn::Connection,
172    ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
173        match tokio::time::timeout(ACCEPT_TIMEOUT, self.accept(conn.clone())).await {
174            Ok(res) => res,
175            Err(_elapsed) => {
176                conn.close(
177                    SnaptunConnErrors::Timeout.into(),
178                    b"timeout establishing snaptun",
179                );
180                Err(AcceptError::Timeout)
181            }
182        }
183    }
184
185    /// Accept a connection and establish a snaptun.
186    ///
187    /// ## Tunnel initialization
188    ///
189    /// The client is expected to first send a session renew request, followed by an address
190    /// assignment request.
191    #[instrument(name = "SnapTunServer::accept", skip_all, fields(conn_id = conn.stable_id()))]
192    async fn accept(
193        &self,
194        conn: quinn::Connection,
195    ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
196        let state_machine = Arc::new(TunnelStateMachine::new(
197            self.validator.clone(),
198            self.allocator.clone(),
199        ));
200
201        //
202        // First request MUST be a session renew request.
203        let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
204            &conn,
205            |r| matches!(r, ControlRequest::SessionRenewal(_)),
206            b"expected session renewal request",
207        )
208        .await?;
209
210        let now = SystemTime::now();
211        debug!(?now, request=?address_assign_request, "Process expected session renewal request");
212        let (code, body) = state_machine.process_control_request(now, address_assign_request);
213        let send_res = send_http_response(&mut snd, code, &body).await;
214        if !code.is_success() {
215            conn.close(
216                SnaptunConnErrors::InvalidRequest.into(),
217                b"handling session renewal request",
218            );
219            return Err(AcceptError::UnexpectedControlRequest);
220        }
221        if let Err(e) = send_res {
222            conn.close(
223                SnaptunConnErrors::InternalError.into(),
224                b"send control response error",
225            );
226            return Err(AcceptError::SendControlResponseError(e));
227        }
228
229        //
230        // Second request MUST be an address assignment request.
231        let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
232            &conn,
233            |r| matches!(r, ControlRequest::AddressAssignment { .. }),
234            b"expected address assignment request",
235        )
236        .await?;
237
238        let now = SystemTime::now();
239        debug!(?now, request=?address_assign_request, "Process expected address assignment request");
240        let (code, body) = state_machine.process_control_request(now, address_assign_request);
241        let send_res = send_http_response(&mut snd, code, &body).await;
242        if !code.is_success() {
243            conn.close(
244                SnaptunConnErrors::InvalidRequest.into(),
245                b"handling address assignment request",
246            );
247            return Err(AcceptError::UnexpectedControlRequest);
248        }
249        if let Err(e) = send_res {
250            conn.close(
251                SnaptunConnErrors::InternalError.into(),
252                b"send control response error",
253            );
254            return Err(AcceptError::SendControlResponseError(e));
255        }
256
257        let initial_state_version = state_machine.state_version();
258        Ok((
259            Sender::new(
260                state_machine.get_addresses().expect("assigned state"),
261                conn.clone(),
262                state_machine.clone(),
263                initial_state_version,
264                self.metrics.sender_metrics.clone(),
265            ),
266            Receiver::new(
267                conn.clone(),
268                state_machine.clone(),
269                initial_state_version,
270                self.metrics.receiver_metrics.clone(),
271            ),
272            Control::new(conn, state_machine.clone()),
273        ))
274    }
275}
276
277async fn receive_expected_control_request(
278    conn: &quinn::Connection,
279    expected: fn(&ControlRequest) -> bool,
280    wrong_request_conn_close_reason: &'static [u8],
281) -> Result<(ControlRequest, SendStream, RecvStream), AcceptError> {
282    let (snd, mut rcv) = conn
283        .accept_bi()
284        .await
285        .map_err(AcceptError::ConnectionError)?;
286    let mut buf = vec![0u8; CTRL_REQUEST_BUF_SIZE];
287    let req = match parse_http_request(&mut buf, &mut rcv).await {
288        Ok(req) if expected(&req) => req,
289        Ok(_) => {
290            conn.close(
291                SnaptunConnErrors::InvalidRequest.into(),
292                wrong_request_conn_close_reason,
293            );
294            return Err(AcceptError::UnexpectedControlRequest);
295        }
296        Err(err) => {
297            handle_invalid_request(conn, &err);
298            return Err(err.into());
299        }
300    };
301    Ok((req, snd, rcv))
302}
303
304/// Sender can be used to send packets to the client. It is returned by
305/// [Server::accept_with_timeout].
306///
307/// Sender offers a synchronous and an asychronous API to send packets to the client.
308pub struct Sender<T: SnapTunToken> {
309    metrics: SenderMetrics,
310    addresses: Vec<EndhostAddr>,
311    conn: quinn::Connection,
312    state_machine: Arc<TunnelStateMachine<T>>,
313    last_state_version: AtomicUsize,
314    is_closed: AtomicBool,
315}
316
317impl<T: SnapTunToken> Sender<T> {
318    fn new(
319        addresses: Vec<EndhostAddr>,
320        conn: quinn::Connection,
321        state_machine: Arc<TunnelStateMachine<T>>,
322        initial_state_version: usize,
323        metrics: SenderMetrics,
324    ) -> Self {
325        Self {
326            addresses,
327            conn,
328            state_machine,
329            last_state_version: AtomicUsize::new(initial_state_version),
330            is_closed: AtomicBool::new(false),
331            metrics,
332        }
333    }
334
335    /// Returns the addresses assigned to this sender.
336    pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
337        self.addresses.clone()
338    }
339
340    /// Returns the remote address of the underling QUIC connection.
341    pub fn remote_underlay_address(&self) -> SocketAddr {
342        self.conn.remote_address()
343    }
344
345    /// Send a packet to the client. The packet needs to fit entirely into a QUIC datagram.
346    ///
347    /// ## Errors
348    ///
349    /// The function returns an error if either the connection is in an
350    /// erroneous state (non-recoverable), or the address assignment has
351    /// changed. In the latter case, [SendPacketError::NewAssignedAddress] is
352    /// returned with a new [Sender] object that is assigned the new address.
353    /// The old object will return a [SendPacketError::ConnectionClosed] error.
354    pub fn send(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
355        let pkt = self.validate_tun(pkt)?;
356        self.conn.send_datagram(pkt)?;
357        self.metrics.datagrams_sent_total.inc();
358        Ok(())
359    }
360
361    /// Send a packet to the client. The packet needs to fit entirely into a QUIC datagram.
362    ///
363    /// Unlike [Self::send], this method will wait for buffer space during congestion
364    /// conditions, which effectively prioritizes old datagrams over new datagrams.
365    pub async fn send_wait(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
366        let pkt = self.validate_tun(pkt)?;
367        self.conn.send_datagram_wait(pkt).await?;
368        Ok(())
369    }
370
371    fn validate_tun(&self, pkt: Bytes) -> Result<Bytes, SendPacketError<T>> {
372        // if the connection is closed, immediately return an error
373        if self.is_closed.load(Ordering::Acquire) {
374            return Err(SendPacketError::ConnectionClosed);
375        }
376        // check if something changed in the state machine
377        let current_state_version = self.state_machine.state_version();
378        if self
379            .last_state_version
380            .compare_exchange(
381                current_state_version - 1,
382                current_state_version,
383                Ordering::AcqRel,
384                Ordering::Acquire,
385            )
386            .is_ok()
387        {
388            // state has been updated
389            // check if the state machine is closed
390            if self.state_machine.is_closed() {
391                self.is_closed.store(true, Ordering::Release);
392                return Err(SendPacketError::ConnectionClosed);
393            }
394            // if the state machine has changed, we need to re-fetch the addresses from it
395            let addresses = self.state_machine.get_addresses()?;
396
397            // Return the new sender with the updated addresses
398            return Err(SendPacketError::NewAssignedAddress((
399                Box::new(Sender::new(
400                    addresses,
401                    self.conn.clone(),
402                    self.state_machine.clone(),
403                    current_state_version,
404                    self.metrics.clone(),
405                )),
406                pkt,
407            )));
408        }
409
410        Ok(pkt)
411    }
412}
413
414impl<T: SnapTunToken> std::fmt::Debug for Sender<T> {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        f.debug_struct("Sender")
417            .field("addresses", &self.addresses)
418            .field("conn", &self.conn.stable_id())
419            .field("last_state_version", &self.last_state_version)
420            .finish()
421    }
422}
423
424/// Send packet error.
425#[derive(Debug, thiserror::Error)]
426pub enum SendPacketError<T: SnapTunToken> {
427    /// Connection closed.
428    #[error("connection closed")]
429    ConnectionClosed,
430    /// New address assigned.
431    #[error("address was re-assigned")]
432    NewAssignedAddress((Box<Sender<T>>, Bytes)),
433    /// Address assignment error.
434    #[error("address assignment error: {0}")]
435    AddressAssignmentError(#[from] AddressAssignmentError),
436    /// QUIC send data gram error.
437    #[error("underlying send error")]
438    SendDatagramError(#[from] quinn::SendDatagramError),
439}
440
441/// Receiver can be used to receive packets from the client. It is returned by
442/// [Server::accept_with_timeout].
443pub struct Receiver<T: SnapTunToken> {
444    metrics: ReceiverMetrics,
445    conn: quinn::Connection,
446    state_machine: Arc<TunnelStateMachine<T>>,
447    last_state_version: AtomicUsize,
448    is_closed: AtomicBool,
449}
450
451/// Packet receive error.
452#[derive(Debug, thiserror::Error)]
453pub enum ReceivePacketError {
454    /// QUIC connection error.
455    #[error("quinn error: {0}")]
456    ConnectionError(#[from] quinn::ConnectionError),
457    /// Connection closed.
458    #[error("connection closed")]
459    ConnectionClosed,
460}
461
462impl<T: SnapTunToken> Receiver<T> {
463    fn new(
464        conn: quinn::Connection,
465        state_machine: Arc<TunnelStateMachine<T>>,
466        initial_state_version: usize,
467        metrics: ReceiverMetrics,
468    ) -> Self {
469        Self {
470            conn,
471            state_machine,
472            last_state_version: AtomicUsize::new(initial_state_version),
473            is_closed: AtomicBool::new(false),
474            metrics,
475        }
476    }
477
478    /// Receive a packet from the client.
479    pub async fn receive(&self) -> Result<Bytes, ReceivePacketError> {
480        // if the state machine changed, check whether the connection is still valid
481        let current_state_version = self.state_machine.state_version();
482        if self
483            .last_state_version
484            .compare_exchange(
485                current_state_version - 1,
486                current_state_version,
487                Ordering::AcqRel,
488                Ordering::Acquire,
489            )
490            .is_ok()
491        {
492            // state has been updated, check if the state machine is closed
493            if self.state_machine.is_closed() {
494                self.is_closed.store(true, Ordering::Release);
495            }
496        }
497        if self.is_closed.load(Ordering::Acquire) {
498            return Err(ReceivePacketError::ConnectionClosed);
499        }
500        let p = self.conn.read_datagram().await?;
501        self.metrics.datagrams_received_total.inc();
502        Ok(p)
503    }
504}
505
506/// Control errors.
507#[derive(Debug, thiserror::Error)]
508pub enum ControlError {
509    /// Parse control request error.
510    #[error("parse control request error: {0}")]
511    ParseError(#[from] ParseControlRequestError),
512    /// Send control response error.
513    #[error("send control response error: {0}")]
514    SendError(#[from] SendControlResponseError),
515    /// QUIC stopped error.
516    #[error("wait for completion error: {0}")]
517    StoppedError(#[from] quinn::StoppedError),
518    /// Session expired.
519    #[error("session expired")]
520    SessionExpired,
521    /// Connection closed prematurely.
522    #[error("connection closed prematurely")]
523    ClosedPrematurely,
524}
525
526/// Control is used to handle control requests from the client. It is returned by
527/// [Server::accept_with_timeout] and must be polled to process control requests.
528pub struct Control {
529    driver_fut: Pin<Box<dyn Future<Output = Result<(), ControlError>> + Send>>,
530}
531
532impl Control {
533    fn new<T>(conn: quinn::Connection, tunnel_state: Arc<TunnelStateMachine<T>>) -> Self
534    where
535        T: for<'de> Deserialize<'de> + Token + Clone,
536    {
537        let fut = async move {
538            loop {
539                tokio::select! {
540                    _ = tunnel_state.await_session_expiry() => {
541                        // session expired, close the connection
542                        tunnel_state.shutdown();
543                        conn.close(SnaptunConnErrors::SessionExpired.into(), b"session expired");
544                        return Err(ControlError::SessionExpired)
545                    }
546                    res = conn.accept_bi() => {
547                        let (mut snd, mut rcv) = match res {
548                            Ok(v) => v,
549                            Err(quinn::ConnectionError::ApplicationClosed(_)) => {
550                                tunnel_state.shutdown();
551                                return Ok(());
552                            }
553                            Err(_) => {
554                                tunnel_state.shutdown();
555                                return Err(ControlError::ClosedPrematurely);
556                            }
557                        };
558
559                        let mut buf = vec![0u8; CTRL_REQUEST_BUF_SIZE];
560                        let control_request  = parse_http_request(&mut buf, &mut rcv).await.inspect_err(|err| {
561                            handle_invalid_request(&conn, err);
562                            tunnel_state.shutdown();
563                        })?;
564
565                        let (code, body) = tunnel_state.process_control_request(SystemTime::now(), control_request);
566                        send_http_response(&mut snd, code, &body).await
567                            .inspect_err(|_| {
568                                tunnel_state.shutdown();
569                                conn.close(SnaptunConnErrors::InternalError.into(), b"send control response error");
570                            })?;
571
572                        snd.stopped().await?;
573                    }
574                }
575            }
576        };
577        let driver_fut = Box::pin(fut);
578        Self { driver_fut }
579    }
580}
581
582impl Future for Control {
583    type Output = Result<(), ControlError>;
584
585    fn poll(
586        mut self: std::pin::Pin<&mut Self>,
587        cx: &mut std::task::Context<'_>,
588    ) -> std::task::Poll<Self::Output> {
589        self.driver_fut.as_mut().poll(cx)
590    }
591}
592
593/// Address assignment error.
594#[derive(Debug, thiserror::Error)]
595pub enum AddressAssignmentError {
596    /// No address assigned.
597    #[error("no address assigned")]
598    NoAddressAssigned,
599}
600
601/// The state transitions of an edgetun connection.
602///
603/// ```text
604/// Unassigned --> Assigend --> Closed
605/// ```
606///
607/// Once the connection is closed, it remains closed.
608/// The state machine has an internal state version that is incremented whenever the state changes.
609/// This can be used to cheaply detect changes in the state machine from the outside.
610pub struct TunnelStateMachine<T: SnapTunToken> {
611    validator: Arc<dyn TokenValidator<T>>,
612    allocator: Arc<dyn AddressAllocator<T>>,
613    inner_state: RwLock<TunnelState>,
614    state_version: AtomicUsize,
615    // channel to notify the session termination about session expiry updates
616    sender: watch::Sender<()>,
617    receiver: watch::Receiver<()>,
618}
619
620impl<T: SnapTunToken> Drop for TunnelStateMachine<T> {
621    fn drop(&mut self) {
622        // Make sure that the state is closed and address is released
623        self.shutdown();
624    }
625}
626
627impl<T: SnapTunToken> TunnelStateMachine<T> {
628    pub(crate) fn new(
629        validator: Arc<dyn TokenValidator<T>>,
630        allocator: Arc<dyn AddressAllocator<T>>,
631    ) -> Self {
632        let (sender, receiver) = watch::channel(());
633
634        Self {
635            validator,
636            allocator,
637            inner_state: Default::default(),
638            state_version: AtomicUsize::new(0),
639            sender,
640            receiver,
641        }
642    }
643
644    /// Processes an address assignment request, updates the internal protocol
645    /// state and returns the response that should be sent back to the client.
646    fn process_control_request(
647        &self,
648        now: SystemTime,
649        control_request: ControlRequest,
650    ) -> (http::StatusCode, Vec<u8>) {
651        let mut inner_state = self.inner_state.write().expect("no fail");
652        if let TunnelState::Closed = *inner_state {
653            return (http::StatusCode::BAD_REQUEST, vec![]);
654        }
655        match control_request {
656            ControlRequest::AddressAssignment(token, address_assign_request) => {
657                self.locked_process_addr_assignment_request(
658                    &mut inner_state,
659                    now,
660                    token,
661                    address_assign_request,
662                )
663            }
664            ControlRequest::SessionRenewal(token) => {
665                self.locked_process_session_renewal(&mut inner_state, now, token)
666            }
667        }
668    }
669
670    fn locked_process_session_renewal(
671        &self,
672        inner_state: &mut TunnelState,
673        now: SystemTime,
674        token: String,
675    ) -> (http::StatusCode, Vec<u8>) {
676        let mut resp_body = vec![];
677        let resp_code = match self.validator.validate(now, &token) {
678            Ok(claims) => {
679                let token_expiry = claims.exp_time();
680
681                // update internal state
682                self.locked_update_tunnel_session(inner_state, token_expiry);
683
684                let resp = SessionRenewalResponse {
685                    valid_until: unix_epoch_from_system_time(token_expiry),
686                };
687                resp.encode(&mut resp_body).expect("no fail");
688                StatusCode::OK
689            }
690            Err(TokenValidatorError::JwtSignatureInvalid()) => {
691                info!("Invalid signature");
692                StatusCode::UNAUTHORIZED
693            }
694            Err(TokenValidatorError::JwtError(err)) => {
695                info!(?err, "Token validation failed");
696                StatusCode::BAD_REQUEST
697            }
698            Err(TokenValidatorError::TokenExpired(err)) => {
699                info!(?err, "Token validation failed: token expired");
700                StatusCode::UNAUTHORIZED
701            }
702        };
703        (resp_code, resp_body)
704    }
705
706    /// Processes an address assignment request, updates the internal protocol
707    /// state and returns the response that should be sent back to the client.
708    fn locked_process_addr_assignment_request(
709        &self,
710        inner_state: &mut TunnelState,
711        now: SystemTime,
712        token: String,
713        addr_assignments: AddressAssignRequest,
714    ) -> (http::StatusCode, Vec<u8>) {
715        let mut resp_body = vec![];
716        let resp_code = match self.validator.validate(now, &token) {
717            Ok(claims) => {
718                if addr_assignments.requested_addresses.len() > 1 {
719                    // We only implement single address assignments at the moment
720                    warn!("Address assignment failed, multiple address assignments not supported");
721                    return (StatusCode::NOT_IMPLEMENTED, resp_body);
722                }
723
724                let mut requests: Vec<(IsdAsn, IpNet)> = match addr_assignments
725                    .requested_addresses
726                    .iter()
727                    .map(|range| range.try_into())
728                    .collect::<Result<Vec<_>, AddrError>>()
729                {
730                    Ok(reqs) => reqs,
731                    Err(_) => return (StatusCode::BAD_REQUEST, vec![]),
732                };
733
734                // We only implement single address assignments at the moment
735                if requests
736                    .iter()
737                    .any(|(_, net)| net.prefix_len() != net.max_prefix_len())
738                {
739                    warn!("Address assignment failed, prefix assignments are not supported");
740                    return (StatusCode::NOT_IMPLEMENTED, resp_body);
741                }
742
743                // If no addresses are requested, try allocating either a IPv4 or IPv6 address.
744                if requests.is_empty() {
745                    requests.push((IsdAsn::WILDCARD, IPV4_WILDCARD));
746                    requests.push((IsdAsn::WILDCARD, IPV6_WILDCARD));
747                }
748
749                // check that our current state is valid
750                let session_expiry = match inner_state.session_validity() {
751                    Ok(v) => v,
752                    Err(err) => {
753                        error!(
754                            ?err,
755                            "Failed to get session validity when processing address assignment request"
756                        );
757                        return (StatusCode::INTERNAL_SERVER_ERROR, vec![]);
758                    }
759                };
760
761                // We return the first successfully allocated address.
762                let mut assigned_address: Option<AddressAllocation> = None;
763                for (requested_isd_as, requested_net) in requests.iter() {
764                    match self
765                        .allocator
766                        .allocate(*requested_isd_as, *requested_net, claims.clone())
767                    {
768                        Ok(allocation) => {
769                            assigned_address = Some(allocation);
770                            break;
771                        }
772                        Err(err) => {
773                            debug!(
774                                ?err,
775                                "Address allocation failed for ISD-AS {requested_isd_as} and net {requested_net}"
776                            );
777                        }
778                    }
779                }
780
781                // Only return an error if no addresses were assigned.
782                let Some(assigned_address) = assigned_address else {
783                    warn!("Address assignment failed - no available addresses for: {requests:?}",);
784                    return (StatusCode::BAD_REQUEST, vec![]);
785                };
786
787                self.locked_update_state(
788                    inner_state,
789                    TunnelState::Assigned {
790                        session_expiry,
791                        address: assigned_address.clone(),
792                    },
793                );
794
795                let resp = AddressAssignResponse {
796                    assigned_addresses: vec![(&assigned_address.address).into()],
797                };
798
799                resp.encode(&mut resp_body).expect("no fail");
800                StatusCode::OK
801            }
802            Err(TokenValidatorError::JwtSignatureInvalid()) => {
803                info!("Invalid JWT Signature");
804                StatusCode::UNAUTHORIZED
805            }
806            Err(TokenValidatorError::JwtError(err)) => {
807                info!(?err, "Token validation failed");
808                StatusCode::BAD_REQUEST
809            }
810            Err(TokenValidatorError::TokenExpired(err)) => {
811                info!(?err, "Token validation failed: token expired");
812                StatusCode::UNAUTHORIZED
813            }
814        };
815        (resp_code, resp_body)
816    }
817
818    fn locked_update_tunnel_session(
819        &self,
820        inner_state: &mut TunnelState,
821        session_expiry: SystemTime,
822    ) {
823        match inner_state {
824            TunnelState::Unassigned => {
825                *inner_state = TunnelState::SessionEstablished { session_expiry };
826            }
827            TunnelState::SessionEstablished { .. } => {
828                *inner_state = TunnelState::SessionEstablished { session_expiry };
829            }
830            TunnelState::Assigned { address, .. } => {
831                *inner_state = TunnelState::Assigned {
832                    session_expiry,
833                    address: address.clone(),
834                };
835            }
836            // XXX(bunert): Should not happen as we error out before updating the state.
837            TunnelState::Closed => tracing::error!("Updating tunnel session but in closed state"),
838        };
839    }
840
841    fn locked_update_state(&self, inner_state: &mut TunnelState, new_state: TunnelState) {
842        tracing::debug!(%new_state, "Updating tunnel state");
843        *inner_state = new_state;
844
845        self.state_version.fetch_add(1, Ordering::AcqRel);
846
847        if self.sender.send(()).is_err() {
848            // This happens only if the channel is closed, which means that the session has
849            // expired and the receiver is no longer interested in updates.
850            debug!("Failed to notify session expiry update");
851        }
852    }
853
854    fn get_addresses(&self) -> Result<Vec<EndhostAddr>, AddressAssignmentError> {
855        let guard = self.inner_state.read().expect("no fail");
856        if let TunnelState::Assigned {
857            address,
858            session_expiry: _,
859        } = &*guard
860        {
861            return Ok(vec![address.address]);
862        }
863        Err(AddressAssignmentError::NoAddressAssigned)
864    }
865
866    async fn await_session_expiry(&self) {
867        let mut expiry_notifier = self.receiver.clone();
868        loop {
869            let valid_duration = {
870                let res = {
871                    let guard = self.inner_state.read().expect("no fail");
872                    guard.session_validity()
873                };
874                match res {
875                    Ok(session_validity) => {
876                        match session_validity.duration_since(SystemTime::now()) {
877                            Ok(dur) => dur,
878                            Err(_) => return, // session already expired
879                        }
880                    }
881                    Err(err) => {
882                        // Tunnel in an invalid state, should only happen if the tunnel is closed
883                        // (e.g. session already expired).
884                        tracing::warn!(%err, "Tunnel in an invalid state");
885                        return;
886                    }
887                }
888            };
889
890            tokio::select! {
891                _ = expiry_notifier.changed() => {
892                    // session expiry updated
893                    continue;
894                }
895                _ = tokio::time::sleep(valid_duration) => {
896                    // Sleep until the session expires
897                    return;
898                }
899            }
900        }
901    }
902
    /// Returns the current value of the state version counter.
    ///
    /// The counter is incremented by `locked_update_state` each time the tunnel state is
    /// replaced, so two equal readings mean no such update happened in between.
    fn state_version(&self) -> usize {
        self.state_version.load(Ordering::Acquire)
    }
906
907    fn is_closed(&self) -> bool {
908        if let TunnelState::Closed = *self.inner_state.read().expect("no fail") {
909            return true;
910        }
911        false
912    }
913
914    fn shutdown(&self) {
915        let mut inner_state = self.inner_state.write().expect("no fail");
916
917        // Put address grant on hold
918        if let TunnelState::Assigned {
919            session_expiry: _,
920            address,
921        } = &*inner_state
922        {
923            if !self.allocator.put_on_hold(address.id.clone()) {
924                error!(addr=?address.address, "Could not set address to hold during shutdown - address was released while tunnel was still assigned");
925            }
926        }
927
928        self.locked_update_state(&mut inner_state, TunnelState::Closed);
929    }
930}
931
/// Error raised when an operation needs a tunnel state that the tunnel is not in.
#[derive(Debug, thiserror::Error)]
enum TunnelStateError {
    /// The tunnel was in a state that does not support the operation; carries a copy of that
    /// state for diagnostics.
    #[error("invalid state: {0}")]
    InvalidState(TunnelState),
}
937
938#[derive(Debug, Clone)]
939enum TunnelState {
940    Unassigned,
941    SessionEstablished {
942        session_expiry: SystemTime,
943    },
944    Assigned {
945        session_expiry: SystemTime,
946        address: AddressAllocation,
947    },
948    Closed,
949}
950
951impl TunnelState {
952    fn session_validity(&self) -> Result<SystemTime, TunnelStateError> {
953        match self {
954            TunnelState::SessionEstablished { session_expiry } => Ok(*session_expiry),
955            TunnelState::Assigned { session_expiry, .. } => Ok(*session_expiry),
956            _ => Err(TunnelStateError::InvalidState(self.clone())),
957        }
958    }
959}
960
961impl Default for TunnelState {
962    fn default() -> Self {
963        Self::Unassigned
964    }
965}
966
967impl std::fmt::Display for TunnelState {
968    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
969        match self {
970            TunnelState::Unassigned => write!(f, "Unassigned"),
971            TunnelState::SessionEstablished { session_expiry } => {
972                write!(
973                    f,
974                    "SessionEstablished ({})",
975                    DateTime::<Utc>::from(*session_expiry)
976                )
977            }
978            TunnelState::Assigned {
979                session_expiry,
980                address,
981            } => {
982                write!(
983                    f,
984                    "Assigned (valid until: {}, addresses: [{}])",
985                    DateTime::<Utc>::from(*session_expiry),
986                    address.address
987                )
988            }
989            TunnelState::Closed => write!(f, "Closed"),
990        }
991    }
992}
993
/// A control request as produced by `parse_http_request`.
///
/// In both variants the `String` is the bearer token taken from the request's authorization
/// header.
#[derive(Debug)]
enum ControlRequest {
    /// Address assignment request: bearer token plus the decoded request body.
    AddressAssignment(String, AddressAssignRequest),
    /// Session renewal request: bearer token only.
    SessionRenewal(String),
}
999
1000fn handle_invalid_request(conn: &quinn::Connection, err: &ParseControlRequestError) {
1001    match err {
1002        ParseControlRequestError::ClosedPrematurely => {
1003            conn.close(
1004                SnaptunConnErrors::InternalError.into(),
1005                b"closed prematurely",
1006            );
1007        }
1008        ParseControlRequestError::ReadError(_) => {
1009            conn.close(SnaptunConnErrors::InternalError.into(), b"read error");
1010        }
1011        ParseControlRequestError::InvalidRequest(reason) => {
1012            conn.close(SnaptunConnErrors::InvalidRequest.into(), reason.as_bytes());
1013        }
1014        ParseControlRequestError::Unauthenticated(reason) => {
1015            conn.close(SnaptunConnErrors::Unauthenticated.into(), reason.as_bytes());
1016        }
1017    }
1018}
1019
/// Error parsing control request.
///
/// The string payloads are human-readable reasons; `handle_invalid_request` forwards them to
/// the peer as the connection close reason.
#[derive(Debug, thiserror::Error)]
pub enum ParseControlRequestError {
    /// Invalid request, e.g. wrong method or path, a request that does not fit the read
    /// buffer, or an undecodable address assignment body.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// Failed to read from QUIC stream.
    #[error("read error: {0}")]
    ReadError(#[from] quinn::ReadError),
    /// Unauthenticated request: the authorization header is missing or is not of Bearer type.
    #[error("unauthenticated: {0}")]
    Unauthenticated(String),
    /// Connection closed prematurely: the stream ended before a complete request was parsed.
    #[error("closed prematurely")]
    ClosedPrematurely,
}
1036
1037// We serialize the request/responses as actual http/1.1 requests. This is an
1038// arbitrary choice, as what matters is the semantics. However, we require so
1039// little flexibility in this matter that this is actually simpler than
1040// specifying a (protobuf) encoding for http-headers.
1041//
1042// We are liberal in what we accept:
1043// * The request MUST be a POST request.
1044// * The request MUST specify an Authorization-header of Bearer-type.
1045// * The request MUST have a correct path.
1046//
1047// All other headers are ignored.
1048async fn parse_http_request(
1049    buf: &mut [u8],
1050    rcv: &mut RecvStream,
1051) -> Result<ControlRequest, ParseControlRequestError> {
1052    use ParseControlRequestError::*;
1053    let mut cursor = 0;
1054    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
1055        cursor += n;
1056        let mut headers = [httparse::EMPTY_HEADER; 16];
1057        let mut req = httparse::Request::new(&mut headers);
1058        if let Ok(httparse::Status::Complete(body_offset)) = req.parse(&buf[..cursor]) {
1059            if !matches!(req.method, Some("POST")) {
1060                return Err(InvalidRequest("invalid method".into()));
1061            }
1062            // A first defensive check that the path is correct before we
1063            // actually act on it. (1)
1064            match req.path {
1065                Some(PATH_ADDR_ASSIGNMENT) => {}
1066                Some(PATH_SESSION_RENEWAL) => {}
1067                Some(_) | None => return Err(InvalidRequest("invalid path".into())),
1068            }
1069            let Some(h) = req.headers.iter().find(|h| h.name == AUTH_HEADER) else {
1070                return Err(Unauthenticated("no auth header".into()));
1071            };
1072            let t = h
1073                .value
1074                .strip_prefix(b"Bearer ")
1075                .ok_or(Unauthenticated(
1076                    "bearer not found in authorization header".into(),
1077                ))
1078                .map(|x| String::from_utf8_lossy(x).to_string())?;
1079            // assert: req.path.is_some() and is valid, see (1)
1080            let path = req.path.unwrap();
1081            match path {
1082                PATH_ADDR_ASSIGNMENT => {
1083                    // We want to keep this method cancel-safe, therefore we are
1084                    // _not_ using read_to_end().
1085                    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
1086                        cursor += n;
1087                    }
1088                    // parse address assignment request
1089                    let Ok(addr_req) = AddressAssignRequest::decode(&buf[body_offset..cursor])
1090                    else {
1091                        return Err(InvalidRequest(
1092                            "error when parsing address assignment request".into(),
1093                        ));
1094                    };
1095                    return Ok(ControlRequest::AddressAssignment(t, addr_req));
1096                }
1097                PATH_SESSION_RENEWAL => return Ok(ControlRequest::SessionRenewal(t)),
1098                path => unreachable!("invalid path: {path}"),
1099            }
1100        }
1101        // Reached size of buffer w/o success in parsing.
1102        if cursor == buf.len() {
1103            return Err(InvalidRequest("request too big".into()));
1104        }
1105    }
1106    Err(ClosedPrematurely)
1107}
1108
/// Error when sending a control response.
#[derive(Debug, thiserror::Error)]
pub enum SendControlResponseError {
    /// I/O error while writing the response head or body to the stream.
    #[error("i/o error: {0}")]
    IoError(#[from] std::io::Error),
    /// Stream was closed, so it could not be finished gracefully.
    #[error("stream closed: {0}")]
    ClosedStream(#[from] quinn::ClosedStream),
}
1119
1120// todo: refine these response headers to be in line with the spec.
1121async fn send_http_response(
1122    stream: &mut SendStream,
1123    code: http::StatusCode,
1124    body: &[u8],
1125) -> Result<(), SendControlResponseError> {
1126    // write_all is not cancel-safe, so we use loops instead.
1127    async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
1128        let mut cursor = 0;
1129        while cursor < data.len() {
1130            cursor += stream.write(&data[cursor..]).await?;
1131        }
1132        Ok(())
1133    }
1134
1135    write_all(
1136        stream,
1137        format!(
1138            "HTTP/1.1 {} {}\r\nContent-Length: {}\r\n\r\n",
1139            code.as_str(),
1140            code.canonical_reason().unwrap_or(""),
1141            body.len(),
1142        )
1143        .as_bytes(),
1144    )
1145    .await?;
1146    write_all(stream, body).await?;
1147
1148    // Gracefully terminate the stream.
1149    stream.finish()?;
1150    Ok(())
1151}
1152
#[cfg(test)]
mod tests {
    use std::time::{Duration, UNIX_EPOCH};

    use snap_tokens::{Pssid, snap_token::SnapTokenClaims};

    use super::*;

    /// Tests covering what happens to an address allocation when the tunnel goes away.
    mod address_allocation {
        use snap_tokens::snap_token::SnapTokenClaims;

        use super::*;

        /// Builds a state machine backed by the mocks and establishes a session on it.
        fn setup() -> (TunnelStateMachine<SnapTokenClaims>, Arc<MockAllocator>) {
            let alloc = Arc::new(MockAllocator {
                is_allocated: AtomicBool::new(false),
                is_on_hold: AtomicBool::new(false),
            });
            let tun = TunnelStateMachine::new(Arc::new(MockValidator), alloc.clone());

            // Prepare the state machine by doing a session renewal first
            let renewal = ControlRequest::SessionRenewal("valid_token".into());
            let (status, body) = tun.process_control_request(SystemTime::now(), renewal);
            assert_eq!(
                status,
                http::StatusCode::OK,
                "failed to renew session - body: {body:?}"
            );

            (tun, alloc)
        }

        /// Sends an address assignment request and asserts that it succeeded and that the
        /// allocator handed out an address.
        fn assign_address(tun: &TunnelStateMachine<SnapTokenClaims>, alloc: &MockAllocator) {
            let request = ControlRequest::AddressAssignment(
                "valid_token".into(),
                AddressAssignRequest {
                    requested_addresses: vec![],
                },
            );
            let (status, body) = tun.process_control_request(SystemTime::now(), request);
            assert_eq!(status, http::StatusCode::OK, "failed - body: {body:?}");
            assert!(alloc.is_allocated.load(Ordering::Acquire));
        }

        #[test]
        fn should_put_on_hold_after_shutdown() {
            let (tun, alloc) = setup();
            assign_address(&tun, &alloc);

            tun.shutdown();

            assert!(alloc.is_on_hold.load(Ordering::Acquire));
        }

        #[test]
        fn should_put_on_hold_after_drop() {
            let (tun, alloc) = setup();
            assign_address(&tun, &alloc);

            // Dropping the state machine is expected to put the address on hold as well.
            drop(tun);

            assert!(alloc.is_on_hold.load(Ordering::Acquire));
        }
    }

    /// Validator that accepts every token and issues claims expiring one hour after `now`.
    struct MockValidator;
    impl TokenValidator<SnapTokenClaims> for MockValidator {
        fn validate(
            &self,
            now: SystemTime,
            _: &str,
        ) -> Result<SnapTokenClaims, TokenValidatorError> {
            let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
            let expires = since_epoch + Duration::from_secs(3600);
            Ok(SnapTokenClaims {
                pssid: Pssid::new(),
                exp: expires.as_secs(),
            })
        }
    }

    /// Allocator that hands out a single allocation and records allocate/hold calls in flags.
    struct MockAllocator {
        is_allocated: AtomicBool,
        is_on_hold: AtomicBool,
    }
    impl AddressAllocator<SnapTokenClaims> for MockAllocator {
        fn allocate(
            &self,
            isd_as: IsdAsn,
            prefix: IpNet,
            claims: SnapTokenClaims,
        ) -> Result<AddressAllocation, crate::AddressAllocationError> {
            // Set the flag and learn in one step whether an allocation was already made.
            let already_allocated = self.is_allocated.swap(true, Ordering::AcqRel);
            if already_allocated {
                return Err(crate::AddressAllocationError::NoAddressesAvailable);
            }

            let id = crate::AddressAllocationId {
                isd_as,
                id: claims.id(),
            };
            Ok(AddressAllocation {
                id,
                address: EndhostAddr::new(isd_as, prefix.addr()),
            })
        }

        fn put_on_hold(&self, _id: crate::AddressAllocationId) -> bool {
            self.is_on_hold.store(true, Ordering::Release);
            true
        }

        fn deallocate(&self, _id: crate::AddressAllocationId) -> bool {
            false
        }
    }
}