Skip to main content

snap_tun/
server_deprecated.rs

1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
//! # The snaptun server.
//!
//! This module contains the snaptun-[Server]. The QUIC-connection handling is left to the caller.
//! That is, after accepting a QUIC-connection, [Server::accept_with_timeout] will establish a
//! snaptun with a client, provided the peer behaves as expected and sends the required control
//! requests.
//!
//! The [Server::accept_with_timeout] method produces three different objects: [Sender],
//! [Receiver], and [Control]. The first is used to send packets to the peer, the second to
//! receive packets from the peer. The third is used to _drive_ the control state of the
//! connection and must be polled for control requests to be processed.
//!
//! [Server::accept_with_timeout] expects the client to first send a token update request followed
//! by an address assignment request. If the client doesn't do so within [ACCEPT_TIMEOUT], a
//! [AcceptError::Timeout] error is returned and the connection is closed. The rationale behind
//! this is that bogus client connections should be closed as quickly as possible.
//!
//! ## Synopsis
//!
//! ```ignore
//! loop {
//!   let quic_conn = endpoint.accept().await?;
//!
//!   let (sender, receiver, control) = snaptun_server.accept_with_timeout(quic_conn).await?;
//!   let _ = tokio::spawn(control); // drive control state
//!
//!   let _ = tokio::spawn(async move {
//!     while let Ok(p) = receiver.receive().await {
//!       // process incoming packet
//!     }
//!   });
//!
//!   // send an outgoing packet
//!   sender.send(p);
//! }
//! ```
49
50use std::{
51    net::SocketAddr,
52    pin::Pin,
53    sync::{
54        Arc, RwLock,
55        atomic::{AtomicBool, AtomicUsize, Ordering},
56    },
57    time::SystemTime,
58    vec,
59};
60
61use bytes::Bytes;
62use chrono::{DateTime, Utc};
63use http::StatusCode;
64use prost::Message;
65use quinn::{RecvStream, SendStream, VarInt};
66use scion_proto::address::EndhostAddr;
67use scion_sdk_token_validator::validator::{Token, TokenValidator, TokenValidatorError};
68use serde::Deserialize;
69use tokio::sync::watch;
70
71use crate::{
72    AUTH_HEADER, PATH_SOCK_ADDR_ASSIGNMENT, PATH_UPDATE_TOKEN,
73    metrics::{Metrics, ReceiverMetrics, SenderMetrics},
74    requests::{SocketAddrAssignmentResponse, TokenUpdateResponse, unix_epoch_from_system_time},
75};
76
/// SNAP tunnel connection errors.
///
/// The discriminant of each variant is the QUIC application error code used when the server
/// closes a connection (see the `From<SnaptunConnErrors> for quinn::VarInt` conversion).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SnaptunConnErrors {
    /// Invalid control request error.
    InvalidRequest = 1,
    /// Timeout error.
    Timeout = 2,
    /// Unauthenticated error.
    Unauthenticated = 3,
    /// Token expired error.
    TokenExpired = 4,
    /// Internal error.
    InternalError = 5,
}
91
92impl From<SnaptunConnErrors> for quinn::VarInt {
93    fn from(e: SnaptunConnErrors) -> Self {
94        VarInt::from_u32(e as u32)
95    }
96}
97
98/// Deserializable SNAP token trait.
99pub trait SnapTunToken: for<'de> Deserialize<'de> + Token + Clone {}
100impl<T> SnapTunToken for T where T: for<'de> Deserialize<'de> + Token + Clone {}
101
/// Handshake deadline: within `ACCEPT_TIMEOUT` a client MUST send a token update request,
/// followed by an address assignment request.
pub const ACCEPT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(3_000);

/// Upper bound on how long sending a single control response to the client may take.
pub const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Maximum size, in bytes, of a control message (applies to requests and responses alike).
const MAX_CTRL_MESSAGE_SIZE: usize = 4 * 1024;
111
112/// The snaptun server accepts connections from clients and provides them with an address
113/// assignment.
114pub struct Server<T> {
115    metrics: Metrics,
116    validator: Arc<dyn TokenValidator<T>>,
117}
118
119/// Accept errors.
120#[derive(Debug, thiserror::Error)]
121pub enum AcceptError {
122    /// Timeout reached.
123    #[error("timeout reached.")]
124    Timeout,
125    /// QUIC connection error.
126    #[error("quinn connection error: {0}")]
127    ConnectionError(#[from] quinn::ConnectionError),
128    /// Parse control request error.
129    #[error("parse control request error: {0}")]
130    ParseControlRequestError(#[from] ParseControlRequestError),
131    /// Send control response error.
132    #[error("send control response error: {0}")]
133    SendControlResponseError(#[from] SendControlResponseError),
134    /// Unexpected control request.
135    #[error("unexpected control request")]
136    UnexpectedControlRequest,
137}
138
139impl<T: SnapTunToken> Server<T> {
140    /// Create a new server that can accept QUIC connections and turn them into
141    /// snap tunnels.
142    pub fn new(validator: Arc<dyn TokenValidator<T>>, metrics: Metrics) -> Self {
143        Self { validator, metrics }
144    }
145
146    /// Accept a connection and establish a tunnel.
147    ///
148    /// ## Tunnel initialization
149    ///
150    /// The client is expected to first send a token update request, followed by an address
151    /// assignment request. The connection is closed with a [SnaptunConnErrors::Timeout] if the
152    /// client does not send the requests within [ACCEPT_TIMEOUT].
153    pub async fn accept_with_timeout(
154        &self,
155        conn: quinn::Connection,
156    ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
157        match tokio::time::timeout(ACCEPT_TIMEOUT, self.accept(conn.clone())).await {
158            Ok(res) => res,
159            Err(_elapsed) => {
160                conn.close(
161                    SnaptunConnErrors::Timeout.into(),
162                    b"timeout establishing snaptun",
163                );
164                Err(AcceptError::Timeout)
165            }
166        }
167    }
168
169    /// Accept a connection and establish a snaptun.
170    ///
171    /// ## Tunnel initialization
172    ///
173    /// The client is expected to first send a token update request, followed by an address
174    /// assignment request.
175    async fn accept(
176        &self,
177        conn: quinn::Connection,
178    ) -> Result<(Sender<T>, Receiver<T>, Control), AcceptError> {
179        let state_machine = Arc::new(TunnelStateMachine::new(
180            conn.remote_address(),
181            self.validator.clone(),
182        ));
183
184        //
185        // First request MUST be a token update request.
186        let (token_update_req, mut snd, _rcv) = receive_expected_control_request(
187            &conn,
188            |r| matches!(r, ControlRequest::TokenUpdate(_)),
189            b"expected token update request",
190        )
191        .await?;
192
193        let now = SystemTime::now();
194        tracing::debug!(?now, request=?token_update_req, "Got token update request");
195
196        let (code, body) = state_machine.process_control_request(now, token_update_req);
197        let send_res = send_http_response(&mut snd, code, &body).await;
198        if !code.is_success() {
199            conn.close(SnaptunConnErrors::InvalidRequest.into(), &body);
200            return Err(AcceptError::UnexpectedControlRequest);
201        }
202        if let Err(e) = send_res {
203            conn.close(
204                SnaptunConnErrors::InternalError.into(),
205                b"failed to send control response",
206            );
207            return Err(AcceptError::SendControlResponseError(e));
208        }
209
210        // Second request MUST be a socket address assignment request.
211        let (address_assign_request, mut snd, _rcv) = receive_expected_control_request(
212            &conn,
213            |r| matches!(r, ControlRequest::SocketAddrAssignment { .. }),
214            b"expected socket addr assignment request",
215        )
216        .await?;
217
218        let now = SystemTime::now();
219
220        tracing::debug!(?now, request=?address_assign_request, "Got address assignment request");
221
222        let (code, body) = state_machine.process_control_request(now, address_assign_request);
223        let send_res = send_http_response(&mut snd, code, &body).await;
224        if !code.is_success() {
225            conn.close(SnaptunConnErrors::InvalidRequest.into(), &body);
226            return Err(AcceptError::UnexpectedControlRequest);
227        }
228        if let Err(e) = send_res {
229            conn.close(
230                SnaptunConnErrors::InternalError.into(),
231                b"failed to send control response",
232            );
233            return Err(AcceptError::SendControlResponseError(e));
234        }
235
236        let initial_state_version = state_machine.state_version();
237        Ok((
238            Sender::new(
239                state_machine.get_socket_addr(),
240                state_machine.get_addresses().expect("assigned state"),
241                conn.clone(),
242                state_machine.clone(),
243                initial_state_version,
244                self.metrics.sender_metrics.clone(),
245            ),
246            Receiver::new(
247                conn.clone(),
248                state_machine.clone(),
249                initial_state_version,
250                self.metrics.receiver_metrics.clone(),
251            ),
252            Control::new(conn, state_machine.clone()),
253        ))
254    }
255}
256
257async fn receive_expected_control_request(
258    conn: &quinn::Connection,
259    expected: fn(&ControlRequest) -> bool,
260    wrong_request_conn_close_reason: &'static [u8],
261) -> Result<(ControlRequest, SendStream, RecvStream), AcceptError> {
262    let (snd, mut rcv) = conn
263        .accept_bi()
264        .await
265        .map_err(AcceptError::ConnectionError)?;
266    let mut buf = vec![0u8; MAX_CTRL_MESSAGE_SIZE];
267    let req = match recv_request(&mut buf, &mut rcv).await {
268        Ok(req) if expected(&req) => req,
269        Ok(_) => {
270            conn.close(
271                SnaptunConnErrors::InvalidRequest.into(),
272                wrong_request_conn_close_reason,
273            );
274            return Err(AcceptError::UnexpectedControlRequest);
275        }
276        Err(err) => {
277            handle_invalid_request(conn, &err);
278            return Err(err.into());
279        }
280    };
281    Ok((req, snd, rcv))
282}
283
284/// Sender can be used to send packets to the client. It is returned by
285/// [Server::accept_with_timeout].
286///
287/// Sender offers a synchronous and an asychronous API to send packets to the client.
288pub struct Sender<T: SnapTunToken> {
289    assigned_socket_addr: Option<SocketAddr>,
290    metrics: SenderMetrics,
291    addresses: Vec<EndhostAddr>,
292    conn: quinn::Connection,
293    state_machine: Arc<TunnelStateMachine<T>>,
294    last_state_version: AtomicUsize,
295    is_closed: AtomicBool,
296}
297
298impl<T: SnapTunToken> Sender<T> {
299    fn new(
300        assigned_socket_addr: Option<SocketAddr>,
301        addresses: Vec<EndhostAddr>,
302        conn: quinn::Connection,
303        state_machine: Arc<TunnelStateMachine<T>>,
304        initial_state_version: usize,
305        metrics: SenderMetrics,
306    ) -> Self {
307        Self {
308            assigned_socket_addr,
309            addresses,
310            conn,
311            state_machine,
312            last_state_version: AtomicUsize::new(initial_state_version),
313            is_closed: AtomicBool::new(false),
314            metrics,
315        }
316    }
317
318    /// Returns the addresses assigned to this sender.
319    pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
320        self.addresses.clone()
321    }
322
323    /// Returns the endhost socket address assigned to the endhost.
324    pub fn assigned_socket_addr(&self) -> Option<SocketAddr> {
325        self.assigned_socket_addr
326    }
327
328    /// Returns the remote address of the underlying QUIC connection.
329    pub fn remote_underlay_address(&self) -> SocketAddr {
330        self.conn.remote_address()
331    }
332
333    /// Send a packet to the client. The packet needs to fit entirely into a QUIC datagram.
334    ///
335    /// ## Errors
336    ///
337    /// The function returns an error if either the connection is in an
338    /// erroneous state (non-recoverable), or the address assignment has
339    /// changed. In the latter case, [SendPacketError::NewAssignedAddress] is
340    /// returned with a new [Sender] object that is assigned the new address.
341    /// The old object will return a [SendPacketError::ConnectionClosed] error.
342    pub fn send(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
343        let pkt = self.validate_tun(pkt)?;
344        self.conn.send_datagram(pkt)?;
345        self.metrics.datagrams_sent_total.inc();
346        Ok(())
347    }
348
349    /// Send a packet to the client. The packet needs to fit entirely into a QUIC datagram.
350    ///
351    /// Unlike [Self::send], this method will wait for buffer space during congestion
352    /// conditions, which effectively prioritizes old datagrams over new datagrams.
353    pub async fn send_wait(&self, pkt: Bytes) -> Result<(), SendPacketError<T>> {
354        let pkt = self.validate_tun(pkt)?;
355        self.conn.send_datagram_wait(pkt).await?;
356        Ok(())
357    }
358
359    /// Immediately closes the underlying connection with the given code and reason.
360    ///
361    /// All other methods on this Sender will return ConnectionClosed after this is called.
362    pub fn close(&self, error_code: SnaptunConnErrors, reason: &[u8]) {
363        self.conn.close(error_code.into(), reason)
364    }
365
366    fn validate_tun(&self, pkt: Bytes) -> Result<Bytes, SendPacketError<T>> {
367        // if the connection is closed, immediately return an error
368        if self.is_closed.load(Ordering::Acquire) {
369            return Err(SendPacketError::ConnectionClosed);
370        }
371        // check if something changed in the state machine
372        let current_state_version = self.state_machine.state_version();
373        if self
374            .last_state_version
375            .compare_exchange(
376                current_state_version - 1,
377                current_state_version,
378                Ordering::AcqRel,
379                Ordering::Acquire,
380            )
381            .is_ok()
382        {
383            // state has been updated
384            // check if the state machine is closed
385            if self.state_machine.is_closed() {
386                self.is_closed.store(true, Ordering::Release);
387                return Err(SendPacketError::ConnectionClosed);
388            }
389            // if the state machine has changed, we need to re-fetch the addresses from it
390            let addresses = self.state_machine.get_addresses()?;
391
392            // Return the new sender with the updated addresses
393            return Err(SendPacketError::NewAssignedAddress((
394                Box::new(Sender::new(
395                    self.state_machine.get_socket_addr(),
396                    addresses,
397                    self.conn.clone(),
398                    self.state_machine.clone(),
399                    current_state_version,
400                    self.metrics.clone(),
401                )),
402                pkt,
403            )));
404        }
405
406        Ok(pkt)
407    }
408}
409
410impl<T: SnapTunToken> std::fmt::Debug for Sender<T> {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        f.debug_struct("Sender")
413            .field("addresses", &self.addresses)
414            .field("conn", &self.conn.stable_id())
415            .field("last_state_version", &self.last_state_version)
416            .finish()
417    }
418}
419
420/// Send packet error.
421#[derive(Debug, thiserror::Error)]
422pub enum SendPacketError<T: SnapTunToken> {
423    /// Connection closed.
424    #[error("connection closed")]
425    ConnectionClosed,
426    /// New address assigned.
427    #[error("address was re-assigned")]
428    NewAssignedAddress((Box<Sender<T>>, Bytes)),
429    /// Address assignment error.
430    #[error("address assignment error: {0}")]
431    AddressAssignmentError(#[from] AddressAssignmentError),
432    /// QUIC send data gram error.
433    #[error("underlying send error")]
434    SendDatagramError(#[from] quinn::SendDatagramError),
435}
436
437/// Receiver can be used to receive packets from the client. It is returned by
438/// [Server::accept_with_timeout].
439pub struct Receiver<T: SnapTunToken> {
440    metrics: ReceiverMetrics,
441    conn: quinn::Connection,
442    state_machine: Arc<TunnelStateMachine<T>>,
443    last_state_version: AtomicUsize,
444    is_closed: AtomicBool,
445}
446
447/// Packet receive error.
448#[derive(Debug, thiserror::Error)]
449pub enum ReceivePacketError {
450    /// QUIC connection error.
451    #[error("quinn error: {0}")]
452    ConnectionError(#[from] quinn::ConnectionError),
453    /// Connection closed.
454    #[error("connection closed")]
455    ConnectionClosed,
456}
457
458impl<T: SnapTunToken> Receiver<T> {
459    fn new(
460        conn: quinn::Connection,
461        state_machine: Arc<TunnelStateMachine<T>>,
462        initial_state_version: usize,
463        metrics: ReceiverMetrics,
464    ) -> Self {
465        Self {
466            conn,
467            state_machine,
468            last_state_version: AtomicUsize::new(initial_state_version),
469            is_closed: AtomicBool::new(false),
470            metrics,
471        }
472    }
473
474    /// Receive a packet from the client.
475    pub async fn receive(&self) -> Result<Bytes, ReceivePacketError> {
476        // if the state machine changed, check whether the connection is still valid
477        let current_state_version = self.state_machine.state_version();
478        if self
479            .last_state_version
480            .compare_exchange(
481                current_state_version - 1,
482                current_state_version,
483                Ordering::AcqRel,
484                Ordering::Acquire,
485            )
486            .is_ok()
487        {
488            // state has been updated, check if the state machine is closed
489            if self.state_machine.is_closed() {
490                self.is_closed.store(true, Ordering::Release);
491            }
492        }
493        if self.is_closed.load(Ordering::Acquire) {
494            return Err(ReceivePacketError::ConnectionClosed);
495        }
496        let p = self.conn.read_datagram().await?;
497        self.metrics.datagrams_received_total.inc();
498        Ok(p)
499    }
500}
501
502/// Control errors.
503#[derive(Debug, thiserror::Error)]
504pub enum ControlError {
505    /// Parse control request error.
506    #[error("parse control request error: {0}")]
507    ParseError(#[from] ParseControlRequestError),
508    /// Send control response error.
509    #[error("send control response error: {0}")]
510    SendError(#[from] SendControlResponseError),
511    /// QUIC stopped error.
512    #[error("wait for completion error: {0}")]
513    StoppedError(#[from] quinn::StoppedError),
514    /// Token expired.
515    #[error("token expired")]
516    TokenExpired,
517    /// Connection closed prematurely.
518    #[error("connection closed prematurely")]
519    ClosedPrematurely,
520}
521
522/// Control is used to handle control requests from the client. It is returned by
523/// [Server::accept_with_timeout] and must be polled to process control requests.
524pub struct Control {
525    driver_fut: Pin<Box<dyn Future<Output = Result<(), ControlError>> + Send>>,
526}
527
528impl Control {
529    fn new<T>(conn: quinn::Connection, tunnel_state: Arc<TunnelStateMachine<T>>) -> Self
530    where
531        T: for<'de> Deserialize<'de> + Token + Clone,
532    {
533        let fut = async move {
534            loop {
535                tokio::select! {
536                    _ = tunnel_state.await_token_expiry() => {
537                        // token expired, close the connection
538                        tunnel_state.shutdown();
539                        conn.close(SnaptunConnErrors::TokenExpired.into(), b"token expired");
540                        return Err(ControlError::TokenExpired)
541                    }
542                    res = conn.accept_bi() => {
543                        let (mut snd, mut rcv) = match res {
544                            Ok(v) => v,
545                            Err(quinn::ConnectionError::ApplicationClosed(_)) => {
546                                tunnel_state.shutdown();
547                                return Ok(());
548                            }
549                            Err(_) => {
550                                tunnel_state.shutdown();
551                                return Err(ControlError::ClosedPrematurely);
552                            }
553                        };
554
555                        let mut buf = vec![0u8; MAX_CTRL_MESSAGE_SIZE];
556                        let control_request  = recv_request(&mut buf, &mut rcv).await.inspect_err(|err| {
557                            handle_invalid_request(&conn, err);
558                            tunnel_state.shutdown();
559                        })?;
560
561                        let (code, body) = tunnel_state.process_control_request(SystemTime::now(), control_request);
562                        send_http_response(&mut snd, code, &body).await
563                            .inspect_err(|_| {
564                                tunnel_state.shutdown();
565                                conn.close(SnaptunConnErrors::InternalError.into(), b"send control response error");
566                            })?;
567
568                        snd.stopped().await?;
569                    }
570                }
571            }
572        };
573        let driver_fut = Box::pin(fut);
574        Self { driver_fut }
575    }
576}
577
578impl Future for Control {
579    type Output = Result<(), ControlError>;
580
581    fn poll(
582        mut self: std::pin::Pin<&mut Self>,
583        cx: &mut std::task::Context<'_>,
584    ) -> std::task::Poll<Self::Output> {
585        self.driver_fut.as_mut().poll(cx)
586    }
587}
588
589/// Address assignment error.
590#[derive(Debug, thiserror::Error)]
591pub enum AddressAssignmentError {
592    /// No address assigned.
593    #[error("no address assigned")]
594    NoAddressAssigned,
595}
596
597/// The state transitions of an snap-tun connection.
598///
599/// ```text
600/// Unassigned --> Assigend --> Closed
601/// ```
602///
603/// Once the connection is closed, it remains closed.
604/// The state machine has an internal state version that is incremented whenever the state changes.
605/// This can be used to cheaply detect changes in the state machine from the outside.
606pub struct TunnelStateMachine<T: SnapTunToken> {
607    remote_sock_addr: SocketAddr,
608    validator: Arc<dyn TokenValidator<T>>,
609    inner_state: RwLock<TunnelState>,
610    state_version: AtomicUsize,
611    // channel to notify the token termination about token expiry updates
612    sender: watch::Sender<()>,
613    receiver: watch::Receiver<()>,
614}
615
616impl<T: SnapTunToken> Drop for TunnelStateMachine<T> {
617    fn drop(&mut self) {
618        // Make sure that the state is closed and address is released
619        self.shutdown();
620    }
621}
622
623impl<T: SnapTunToken> TunnelStateMachine<T> {
624    pub(crate) fn new(remote_sock_addr: SocketAddr, validator: Arc<dyn TokenValidator<T>>) -> Self {
625        let (sender, receiver) = watch::channel(());
626
627        Self {
628            remote_sock_addr,
629            validator,
630            inner_state: Default::default(),
631            state_version: AtomicUsize::new(0),
632            sender,
633            receiver,
634        }
635    }
636
637    /// Processes an address assignment request, updates the internal protocol
638    /// state and returns the response that should be sent back to the client.
639    fn process_control_request(
640        &self,
641        now: SystemTime,
642        control_request: ControlRequest,
643    ) -> (http::StatusCode, Vec<u8>) {
644        let mut inner_state = self.inner_state.write().expect("no fail");
645
646        if let TunnelState::Closed = *inner_state {
647            return (http::StatusCode::BAD_REQUEST, "tunnel is closed".into());
648        }
649        match control_request {
650            ControlRequest::SocketAddrAssignment(token) => {
651                self.locked_process_socket_addr_assignment_request(&mut inner_state, now, token)
652            }
653            ControlRequest::TokenUpdate(token) => {
654                self.locked_process_token_update(&mut inner_state, now, token)
655            }
656        }
657    }
658
659    fn locked_process_token_update(
660        &self,
661        inner_state: &mut TunnelState,
662        now: SystemTime,
663        token: String,
664    ) -> (http::StatusCode, Vec<u8>) {
665        match self.validator.validate(now, &token) {
666            Ok(claims) => {
667                let token_expiry = claims.exp_time();
668
669                // update internal state
670                self.locked_update_tunnel_expiry(inner_state, token_expiry);
671
672                let resp = TokenUpdateResponse {
673                    valid_until: unix_epoch_from_system_time(token_expiry),
674                };
675
676                let mut resp_body = vec![];
677                resp.encode(&mut resp_body).expect("no fail");
678                (StatusCode::OK, resp_body)
679            }
680            Err(e) => map_token_validation_err_to_response(e),
681        }
682    }
683
684    fn locked_process_socket_addr_assignment_request(
685        &self,
686        inner_state: &mut TunnelState,
687        now: SystemTime,
688        token: String,
689    ) -> (http::StatusCode, Vec<u8>) {
690        // XXX: assuming well-behaved clients, we should never encounter
691        // a situation where a client did not authenticate before requesting a
692        // socket addr.
693        let token_expiry = match inner_state.token_validity() {
694            Ok(v) => v,
695            Err(err) => {
696                tracing::error!(
697                    ?err,
698                    "Failed to get token validity when processing address assignment request"
699                );
700                // this should, in principle, never happen assuming well-behaved
701                // clients.
702                return (
703                    StatusCode::INTERNAL_SERVER_ERROR,
704                    "invalid state transition".into(),
705                );
706            }
707        };
708        match self.validator.validate(now, &token) {
709            Ok(_claims) => {
710                self.locked_update_state(
711                    inner_state,
712                    TunnelState::SockAddrAssigned { token_expiry },
713                );
714                let resp = SocketAddrAssignmentResponse::from(self.remote_sock_addr);
715
716                let mut resp_body = vec![];
717                resp.encode(&mut resp_body).expect("no fail");
718                (StatusCode::OK, resp_body)
719            }
720            Err(e) => map_token_validation_err_to_response(e),
721        }
722    }
723
724    fn locked_update_tunnel_expiry(&self, inner_state: &mut TunnelState, token_expiry: SystemTime) {
725        match inner_state {
726            TunnelState::Unassigned => {
727                *inner_state = TunnelState::SessionEstablished { token_expiry };
728            }
729            TunnelState::SessionEstablished { .. } => {
730                *inner_state = TunnelState::SessionEstablished { token_expiry };
731            }
732            TunnelState::SockAddrAssigned { .. } => {
733                *inner_state = TunnelState::SockAddrAssigned { token_expiry }
734            }
735            TunnelState::Closed => {
736                tracing::error!("Updating tunnel token expiry but in closed state")
737            }
738        };
739    }
740
741    fn locked_update_state(&self, inner_state: &mut TunnelState, new_state: TunnelState) {
742        tracing::debug!(%new_state, "Updating tunnel state");
743        *inner_state = new_state;
744
745        self.state_version.fetch_add(1, Ordering::AcqRel);
746
747        if self.sender.send(()).is_err() {
748            // This happens only if the channel is closed, which means that the token has
749            // expired and the receiver is no longer interested in updates.
750            tracing::debug!("Failed to notify token expiry update");
751        }
752    }
753
754    fn get_addresses(&self) -> Result<Vec<EndhostAddr>, AddressAssignmentError> {
755        let guard = self.inner_state.read().expect("no fail");
756
757        match &*guard {
758            TunnelState::SockAddrAssigned { .. } => Ok(vec![]),
759            _ => Err(AddressAssignmentError::NoAddressAssigned),
760        }
761    }
762
763    fn get_socket_addr(&self) -> Option<SocketAddr> {
764        let guard = self.inner_state.read().expect("no fail");
765        if let TunnelState::SockAddrAssigned { .. } = &*guard {
766            return Some(self.remote_sock_addr);
767        }
768        None
769    }
770
771    async fn await_token_expiry(&self) {
772        let mut expiry_notifier = self.receiver.clone();
773        loop {
774            let valid_duration = {
775                let res = {
776                    let guard = self.inner_state.read().expect("no fail");
777                    guard.token_validity()
778                };
779                match res {
780                    Ok(token_validity) => {
781                        match token_validity.duration_since(SystemTime::now()) {
782                            Ok(dur) => dur,
783                            Err(_) => return, // token already expired
784                        }
785                    }
786                    Err(err) => {
787                        // Tunnel in an invalid state, should only happen if the tunnel is closed
788                        // (e.g. token already expired).
789                        tracing::warn!(%err, "Tunnel in an invalid state");
790                        return;
791                    }
792                }
793            };
794
795            tokio::select! {
796                _ = expiry_notifier.changed() => {
797                    // token expiry updated
798                    continue;
799                }
800                _ = tokio::time::sleep(valid_duration) => {
801                    // Sleep until the token expires
802                    return;
803                }
804            }
805        }
806    }
807
808    fn state_version(&self) -> usize {
809        self.state_version.load(Ordering::Acquire)
810    }
811
812    fn is_closed(&self) -> bool {
813        if let TunnelState::Closed = *self.inner_state.read().expect("no fail") {
814            return true;
815        }
816        false
817    }
818
819    fn shutdown(&self) {
820        let mut inner_state = self.inner_state.write().expect("no fail");
821        self.locked_update_state(&mut inner_state, TunnelState::Closed);
822    }
823}
824
825fn map_token_validation_err_to_response(value: TokenValidatorError) -> (StatusCode, Vec<u8>) {
826    match value {
827        TokenValidatorError::JwtSignatureInvalid() => {
828            tracing::info!("Invalid JWT Signature");
829            (StatusCode::UNAUTHORIZED, "unauthorized".into())
830        }
831        TokenValidatorError::JwtError(err) => {
832            tracing::info!(?err, "Token validation failed");
833            (StatusCode::UNAUTHORIZED, "unauthorized".into())
834        }
835        TokenValidatorError::TokenExpired(err) => {
836            tracing::info!(?err, "Token validation failed: token expired");
837            (StatusCode::UNAUTHORIZED, "unauthorized".into())
838        }
839    }
840}
841
/// Error returned when an operation is not supported in the tunnel's current
/// state.
#[derive(Debug, thiserror::Error)]
enum TunnelStateError {
    /// The tunnel is in a state that does not support the requested operation;
    /// carries a copy of that state.
    #[error("invalid state: {0}")]
    InvalidState(TunnelState),
}
847
/// State of a tunnel as tracked by the server.
// NOTE(review): the expected progression is presumably
// Unassigned -> SessionEstablished -> SockAddrAssigned -> Closed, matching the
// token-update-then-address-assignment handshake in the module docs; the
// transitions are performed elsewhere in this file — confirm there.
#[derive(Debug, Clone, Default)]
enum TunnelState {
    /// Default (initial) state; no token expiry is associated with it.
    #[default]
    Unassigned,
    /// A session has been established.
    SessionEstablished {
        /// Point in time until which the session's token is valid.
        token_expiry: SystemTime,
    },
    /// A remote socket address has been assigned.
    SockAddrAssigned {
        /// Point in time until which the session's token is valid.
        token_expiry: SystemTime,
    },
    /// The tunnel has been shut down.
    Closed,
}
860
861impl TunnelState {
862    fn token_validity(&self) -> Result<SystemTime, TunnelStateError> {
863        match self {
864            TunnelState::SessionEstablished { token_expiry } => Ok(*token_expiry),
865            TunnelState::SockAddrAssigned { token_expiry, .. } => Ok(*token_expiry),
866            _ => Err(TunnelStateError::InvalidState(self.clone())),
867        }
868    }
869}
870
871impl std::fmt::Display for TunnelState {
872    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
873        match self {
874            TunnelState::Unassigned => write!(f, "Unassigned"),
875            TunnelState::SessionEstablished { token_expiry } => {
876                write!(
877                    f,
878                    "SessionEstablished ({})",
879                    DateTime::<Utc>::from(*token_expiry)
880                )
881            }
882            TunnelState::Closed => write!(f, "Closed"),
883            TunnelState::SockAddrAssigned { token_expiry } => {
884                write!(
885                    f,
886                    "Remote socket address assigned (valid until: {}).",
887                    DateTime::<Utc>::from(*token_expiry),
888                )
889            }
890        }
891    }
892}
893
/// A parsed control request. Each variant carries the bearer token taken from
/// the request's Authorization header.
#[derive(Debug)]
enum ControlRequest {
    /// Request sent to the socket-address-assignment path.
    SocketAddrAssignment(String),
    /// Request sent to the token-update path.
    TokenUpdate(String),
}
899
900fn handle_invalid_request(conn: &quinn::Connection, err: &ParseControlRequestError) {
901    match err {
902        ParseControlRequestError::ClosedPrematurely => {
903            conn.close(
904                SnaptunConnErrors::InternalError.into(),
905                b"closed prematurely",
906            );
907        }
908        ParseControlRequestError::ReadError(_) => {
909            conn.close(SnaptunConnErrors::InternalError.into(), b"read error");
910        }
911        ParseControlRequestError::InvalidRequest(reason) => {
912            conn.close(SnaptunConnErrors::InvalidRequest.into(), reason.as_bytes());
913        }
914        ParseControlRequestError::Unauthenticated(reason) => {
915            conn.close(SnaptunConnErrors::Unauthenticated.into(), reason.as_bytes());
916        }
917    }
918}
919
/// Error parsing a control request received on a QUIC stream.
#[derive(Debug, thiserror::Error)]
pub enum ParseControlRequestError {
    /// The request is not an acceptable control request (e.g. it is too large,
    /// uses a method other than `POST`, or targets an unknown path).
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// Failed to read from the QUIC stream.
    #[error("read error: {0}")]
    ReadError(#[from] quinn::ReadError),
    /// The request lacks an Authorization header carrying a bearer token.
    #[error("unauthenticated: {0}")]
    Unauthenticated(String),
    /// The stream ended before a complete request was received.
    #[error("closed prematurely")]
    ClosedPrematurely,
}
936
937// We serialize the request/responses as actual http/1.1 requests. This is an
938// arbitrary choice, as what matters is the semantics. However, we require so
939// little flexibility in this matter that this is actually simpler than
940// specifying a (protobuf) encoding for http-headers.
941//
942// We are liberal in what we accept:
943// * The request MUST be a POST request.
944// * The request MUST specify an Authorization-header of Bearer-type.
945// * The request MUST have a correct path.
946//
947// All other headers are ignored.
948async fn recv_request(
949    buf: &mut [u8],
950    rcv: &mut RecvStream,
951) -> Result<ControlRequest, ParseControlRequestError> {
952    use ParseControlRequestError::*;
953    let mut cursor = 0;
954
955    // Keep reading into the buffer
956    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
957        cursor += n;
958        let mut headers = [httparse::EMPTY_HEADER; 16];
959        let mut req = httparse::Request::new(&mut headers);
960
961        // Try to parse the request
962        let Ok(httparse::Status::Complete(_body_offset)) = req.parse(&buf[..cursor]) else {
963            // Check if we can keep reading
964            if cursor >= buf.len() {
965                return Err(InvalidRequest("request too big".into()));
966            }
967            continue;
968        };
969
970        // Parsed full request
971        if !matches!(req.method, Some("POST")) {
972            return Err(InvalidRequest("invalid method".into()));
973        }
974
975        // A first defensive check that the path is correct before we
976        // actually act on it. (1)
977        match req.path {
978            Some(PATH_SOCK_ADDR_ASSIGNMENT) => {}
979            Some(PATH_UPDATE_TOKEN) => {}
980            Some(_) | None => return Err(InvalidRequest("invalid path".into())),
981        }
982
983        // Expect auth header
984        let Some(auth_header) = req.headers.iter().find(|h| h.name == AUTH_HEADER) else {
985            return Err(Unauthenticated("no auth header".into()));
986        };
987        let bearer_token = auth_header
988            .value
989            .strip_prefix(b"Bearer ")
990            .ok_or(Unauthenticated(
991                "bearer not found in authorization header".into(),
992            ))
993            .map(|x| String::from_utf8_lossy(x).to_string())?;
994
995        // assert: req.path.is_some() and is valid, see (1)
996        let path = req.path.unwrap();
997        match path {
998            PATH_SOCK_ADDR_ASSIGNMENT => {
999                return Ok(ControlRequest::SocketAddrAssignment(bearer_token));
1000            }
1001            PATH_UPDATE_TOKEN => return Ok(ControlRequest::TokenUpdate(bearer_token)),
1002            path => unreachable!("invalid path: {path}"),
1003        }
1004    }
1005
1006    Err(ClosedPrematurely)
1007}
1008
/// Error when sending a control response.
#[derive(Debug, thiserror::Error)]
pub enum SendControlResponseError {
    /// Writing the response to the stream failed.
    #[error("i/o error: {0}")]
    IoError(#[from] std::io::Error),
    /// The stream was already closed (e.g. when trying to finish it).
    #[error("stream closed: {0}")]
    ClosedStream(#[from] quinn::ClosedStream),
}
1019
1020// todo: refine these response headers to be in line with the spec.
1021async fn send_http_response(
1022    stream: &mut SendStream,
1023    code: http::StatusCode,
1024    body: &[u8],
1025) -> Result<(), SendControlResponseError> {
1026    // write_all is not cancel-safe, so we use loops instead.
1027    async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
1028        let mut cursor = 0;
1029        while cursor < data.len() {
1030            cursor += stream.write(&data[cursor..]).await?;
1031        }
1032        Ok(())
1033    }
1034
1035    write_all(
1036        stream,
1037        format!(
1038            "HTTP/1.1 {} {}\r\nContent-Length: {}\r\n\r\n",
1039            code.as_str(),
1040            code.canonical_reason().unwrap_or(""),
1041            body.len(),
1042        )
1043        .as_bytes(),
1044    )
1045    .await?;
1046    write_all(stream, body).await?;
1047
1048    // Gracefully terminate the stream.
1049    stream.finish()?;
1050    Ok(())
1051}