snap_tun/
client.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP tunnel client.
15
16use std::{
17    net::IpAddr,
18    ops::Deref,
19    pin::Pin,
20    sync::{Arc, RwLock},
21    time::{Duration, SystemTime},
22};
23
24use bytes::Bytes;
25use prost::Message;
26use quinn::{RecvStream, SendStream};
27use scion_proto::address::EndhostAddr;
28use tokio::{sync::watch, task::JoinHandle};
29use tracing::debug;
30
31use crate::requests::{
32    AddrError, AddressAssignRequest, AddressAssignResponse, AddressRange, SessionRenewalResponse,
33    system_time_from_unix_epoch_secs,
34};
35
/// Size in bytes of the buffer into which control responses are read.
///
/// Control responses received from the server (HTTP header plus protobuf body) MUST NOT exceed
/// `CTRL_RESPONSE_BUF_SIZE` bytes, since they are read into a fixed-size buffer of this size.
pub const CTRL_RESPONSE_BUF_SIZE: usize = 4096;
39
/// Default lead time (five minutes) for session renewal.
///
/// Renewal is due once the current time is past the session expiry minus this lead time.
pub const DEFAULT_RENEWAL_WAIT_THRESHOLD: Duration = Duration::from_secs(5 * 60);
43
/// Error reported by a [`TokenRenewFn`] when no fresh token could be obtained.
pub type TokenRenewError = Box<dyn std::error::Error + Send + Sync>;

/// Callback that asynchronously produces a fresh session token.
///
/// Each call returns a boxed, `Send` future resolving to the new token or a
/// [`TokenRenewError`].
pub type TokenRenewFn = Box<
    dyn Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenRenewError>> + Send>>
        + Send
        + Sync,
>;
51
52/// Automatic session renewal configuration.
53pub struct AutoSessionRenewal {
54    token_renewer: TokenRenewFn,
55    renew_wait_threshold: Duration,
56}
57
58impl AutoSessionRenewal {
59    /// Create a new automatic session renewal configuration.
60    ///
61    /// # Arguments
62    /// * `renew_wait_threshold` - Duration before session expiry to wait before attempting renewal.
63    /// * `token_renewer` - Function to renew the session token.
64    pub fn new(renew_wait_threshold: Duration, token_renewer: TokenRenewFn) -> Self {
65        AutoSessionRenewal {
66            token_renewer,
67            renew_wait_threshold,
68        }
69    }
70}
71
72/// SNAP tunnel client builder.
73pub struct ClientBuilder {
74    desired_addresses: Vec<EndhostAddr>,
75    initial_session_token: String,
76    auto_session_renewal: Option<AutoSessionRenewal>,
77}
78
79impl ClientBuilder {
80    /// Client builder with an initial session token to be used to authenticate requests.
81    pub fn new<S: AsRef<str>>(initial_session_token: S) -> Self {
82        ClientBuilder {
83            desired_addresses: Vec::new(),
84            initial_session_token: initial_session_token.as_ref().into(),
85            auto_session_renewal: None,
86        }
87    }
88
89    /// Set the desired addresses to be requested from the SNAP. If empty, the SNAP server will
90    /// assign an address.
91    pub fn with_desired_addresses(mut self, desired_addresses: Vec<EndhostAddr>) -> Self {
92        self.desired_addresses = desired_addresses;
93        self
94    }
95
96    /// Enable automatic session renewal.
97    pub fn with_auto_session_renewal(mut self, session_renewal: AutoSessionRenewal) -> Self {
98        self.auto_session_renewal = Some(session_renewal);
99        self
100    }
101
102    /// Establish a SNAP tunnel using the provided QUIC connection using the builder's settings.
103    pub async fn connect(
104        self,
105        conn: quinn::Connection,
106    ) -> Result<(Sender, Receiver, Control), SnapTunError> {
107        let (expiry_sender, expiry_receiver) = watch::channel(());
108        let conn_state = SharedConnState::new(ConnState::new(expiry_sender.clone()));
109        let mut ctrl = Control {
110            conn: conn.clone(),
111            state: conn_state.clone(),
112            session_renewal_task: None,
113        };
114
115        ctrl.state.write().expect("no fail").session_token = self.initial_session_token;
116        ctrl.renew_session().await?;
117        ctrl.request_address(self.desired_addresses).await?;
118
119        if let Some(auto_session_renewal) = self.auto_session_renewal {
120            ctrl.start_auto_session_renewal(auto_session_renewal, expiry_receiver);
121        }
122
123        Ok((Sender::new(conn.clone()), Receiver { conn }, ctrl))
124    }
125}
126
127/// Control can be used to send control messages to the server
128pub struct Control {
129    conn: quinn::Connection,
130    state: SharedConnState,
131    session_renewal_task: Option<JoinHandle<Result<(), RenewTaskError>>>,
132}
133
134impl Control {
135    /// Returns the currently assigned addresses.
136    pub fn assigned_addresses(&self) -> Vec<EndhostAddr> {
137        self.state
138            .read()
139            .expect("no fail")
140            .assigned_addresses
141            .clone()
142    }
143
144    /// Returns the session expiry time.
145    pub fn session_expiry(&self) -> SystemTime {
146        self.state.read().expect("no fail").session_expiry
147    }
148
149    /// Sends an address assign request to the snaptun server.
150    ///
151    /// In addition, this also extends the session validity based on the token validity.
152    ///
153    /// # Arguments
154    /// * `desired_addresses` - Client can request specific [EndhostAddr] from the server.
155    async fn request_address(
156        &mut self,
157        desired_addresses: Vec<EndhostAddr>,
158    ) -> Result<(), ControlError> {
159        debug!(?desired_addresses, "Requesting address assignment");
160        let (mut snd, mut rcv) = self.conn.open_bi().await?;
161        let request = AddressAssignRequest {
162            requested_addresses: desired_addresses
163                .into_iter()
164                .map(|addr| {
165                    let (version, prefix_length, octets) = match addr.local_address() {
166                        IpAddr::V4(a) => (4, 32, a.octets().to_vec()),
167                        IpAddr::V6(a) => (6, 128, a.octets().to_vec()),
168                    };
169                    AddressRange {
170                        isd_as: addr.isd_asn().into(),
171                        ip_version: version as u32,
172                        prefix_length: prefix_length as u32,
173                        address: octets,
174                    }
175                })
176                .collect::<Vec<_>>(),
177        };
178        let body = request.encode_to_vec();
179        let token = self.state.read().expect("no fail").session_token.clone();
180        send_control_request(&mut snd, crate::PATH_ADDR_ASSIGNMENT, body.as_ref(), &token).await?;
181        let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
182        let response: AddressAssignResponse =
183            parse_http_response(&mut resp_buf[..], &mut rcv).await?;
184
185        if response.assigned_addresses.is_empty() {
186            return Err(ControlError::AddressAssignmentFailed(
187                AddrAssignError::NoAddressAssigned,
188            ));
189        }
190        let assigned_addresses = response
191            .assigned_addresses
192            .iter()
193            .map(|address_range| {
194                TryInto::<EndhostAddr>::try_into(address_range).map_err(|e| {
195                    ControlError::AddressAssignmentFailed(AddrAssignError::InvalidAddr(e))
196                })
197            })
198            .collect::<Result<Vec<_>, _>>()?;
199        debug!(?assigned_addresses, "Got address assignment");
200
201        self.state.write().expect("no fail").assigned_addresses = assigned_addresses;
202        Ok(())
203    }
204
205    /// Sends a session renewal request to the snaptun server.
206    pub async fn renew_session(&mut self) -> Result<(), ControlError> {
207        let token = self.state.read().expect("no fail").session_token.clone();
208        self.set_session_expiry(renew_session(&self.conn.clone(), &token).await?);
209        Ok(())
210    }
211
212    fn start_auto_session_renewal(
213        &mut self,
214        config: AutoSessionRenewal,
215        mut expiry_notifier: watch::Receiver<()>,
216    ) {
217        let conn = self.conn.clone();
218        let conn_state = self.state.clone();
219
220        self.session_renewal_task = Some(tokio::spawn(async move {
221            // Maximum number of retries for session renewal.
222            const MAX_RETRIES: u32 = 5;
223            // Base retry delay used for exponential backoff.
224            const BASE_RETRY_DELAY_SECS: u64 = 3;
225            // Fraction of the remaining time to sleep before retrying.
226            const SLEEP_FRACTION: f32 = 0.75; // Sleep for 3/4 of the remaining time
227
228            let mut retries: u32 = 0;
229            loop {
230                let secs_until_expiry = {
231                    let expiry = conn_state.read().expect("no fail").session_expiry;
232                    // Calculate how long until the session expires
233                    match expiry.duration_since(SystemTime::now()) {
234                        Ok(duration) => duration.as_secs(),
235                        Err(_) => {
236                            // As long as the auto session renewal works correctly, this should
237                            // never happen.
238                            tracing::error!("Session expiry already passed, stopping auto-renewal");
239                            return Err(RenewTaskError::SessionExpired);
240                        }
241                    }
242                };
243
244                // Renew immediately if the remaining seconds are less than the wait threshold.
245                let sleep_secs = if secs_until_expiry < config.renew_wait_threshold.as_secs() {
246                    0
247                } else {
248                    (secs_until_expiry as f32 * SLEEP_FRACTION) as u64
249                };
250                debug!("Next session renewal in {sleep_secs} seconds");
251
252                tokio::select! {
253                    _ = expiry_notifier.changed() => continue,
254                    _ = tokio::time::sleep(Duration::from_secs(sleep_secs)) => {
255                        debug!("Renewing token and snaptun session");
256
257                        // renew token
258                        let token = match (config.token_renewer)().await {
259                            Ok(token) => token,
260                            Err(err) => {
261                                debug!(%err, "Failed to renew token, retry");
262                                retries += 1;
263                                if retries >= MAX_RETRIES {
264                                    return Err(RenewTaskError::MaxRetriesReached);
265                                }
266                                tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
267                                continue;
268                            },
269                        };
270
271                        // renew session
272                        let new_expiry = match renew_session(&conn, &token).await {
273                            Ok(exp) => exp,
274                            Err(err) => {
275                                debug!(%err, "Failed to renew session, retry");
276                                retries += 1;
277                                if retries >= MAX_RETRIES {
278                                    return Err(RenewTaskError::MaxRetriesReached);
279                                }
280                                tokio::time::sleep(Duration::from_secs(BASE_RETRY_DELAY_SECS.pow(retries))).await;
281                                continue;
282                            }
283                        };
284
285                        debug!(new_expiry=%chrono::DateTime::<chrono::Utc>::from(new_expiry).to_rfc3339(), "auto session renewal successful");
286                        conn_state.write().expect("no fail").session_expiry = new_expiry;
287                        retries = 0;
288                    }
289                }
290            }
291        }));
292    }
293
294    fn set_session_expiry(&mut self, expiry: SystemTime) {
295        self.state.write().expect("no fail").session_expiry = expiry;
296        if self
297            .state
298            .read()
299            .expect("no fail")
300            .expiry_notifier
301            .send(())
302            .is_err()
303        {
304            // This happens only if the channel is closed, which means that the session has
305            // expired and the receiver is no longer interested in updates.
306            debug!("Failed to notify session expiry update");
307        }
308    }
309}
310
311/// Token renew task error.
312#[derive(Debug, thiserror::Error)]
313pub enum RenewTaskError {
314    /// Session expired.
315    #[error("session expired")]
316    SessionExpired,
317    /// Maximum number of retries reached.
318    #[error("maximum number of retries reached")]
319    MaxRetriesReached,
320}
321
322/// Renew SNAP tunnel session.
323///
324/// This opens a new bi-directional stream to the server, sends a session renewal request, and waits
325/// for the response. On success, it returns the new session expiry time.
326pub async fn renew_session(
327    conn: &quinn::Connection,
328    token: &str,
329) -> Result<SystemTime, ControlError> {
330    let (mut snd, mut rcv) = conn.open_bi().await?;
331
332    let body = vec![];
333    send_control_request(&mut snd, crate::PATH_SESSION_RENEWAL, &body, token).await?;
334    let mut resp_buf = [0u8; CTRL_RESPONSE_BUF_SIZE];
335    let response: SessionRenewalResponse = parse_http_response(&mut resp_buf[..], &mut rcv).await?;
336
337    Ok(system_time_from_unix_epoch_secs(response.valid_until))
338}
339
340impl Drop for Control {
341    fn drop(&mut self) {
342        if let Some(task) = self.session_renewal_task.take() {
343            // Cancel the session renewal task
344            task.abort();
345        }
346    }
347}
348
349/// Connection state.
350#[derive(Debug, Clone)]
351struct ConnState {
352    session_token: String,
353    session_expiry: SystemTime,
354    assigned_addresses: Vec<EndhostAddr>,
355    expiry_notifier: watch::Sender<()>,
356}
357
358impl ConnState {
359    fn new(expiry_notifier: watch::Sender<()>) -> Self {
360        Self {
361            session_token: String::new(),
362            session_expiry: SystemTime::UNIX_EPOCH,
363            assigned_addresses: Vec::new(),
364            expiry_notifier,
365        }
366    }
367}
368
369#[derive(Debug, Clone)]
370struct SharedConnState(Arc<RwLock<ConnState>>);
371
372impl SharedConnState {
373    fn new(conn_state: ConnState) -> Self {
374        Self(Arc::new(RwLock::new(conn_state)))
375    }
376}
377
378impl Deref for SharedConnState {
379    type Target = Arc<RwLock<ConnState>>;
380
381    fn deref(&self) -> &Self::Target {
382        &self.0
383    }
384}
385
386/// SNAP tunnel sender.
387pub struct Sender {
388    conn: quinn::Connection,
389}
390
391impl Sender {
392    /// Creates a new sender.
393    pub fn new(conn: quinn::Connection) -> Self {
394        Self { conn }
395    }
396
397    /// Sends a datagram to the connection.
398    pub fn send_datagram(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
399        self.conn.send_datagram(data)?;
400        Ok(())
401    }
402
403    /// Sends a datagram to the connection and waits for the datagram to be sent.
404    pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), quinn::SendDatagramError> {
405        self.conn.send_datagram_wait(data).await?;
406        Ok(())
407    }
408}
409
410/// SNAP tunnel receiver.
411pub struct Receiver {
412    conn: quinn::Connection,
413}
414
415impl Receiver {
416    /// Reads a datagram from the connection.
417    pub async fn read_datagram(&self) -> Result<Bytes, quinn::ConnectionError> {
418        let packet = self.conn.read_datagram().await?;
419        Ok(packet)
420    }
421}
422
423/// Parse response error.
424#[derive(Debug, thiserror::Error)]
425pub enum ParseResponseError {
426    /// Parsing HTTP envelope failed.
427    #[error("parsing HTTP envelope failed: {0}")]
428    HTTParseError(#[from] httparse::Error),
429    /// QUIC read error.
430    #[error("read error: {0}")]
431    ReadError(#[from] quinn::ReadError),
432    /// Protobuf decode error.
433    #[error("parsing control message failed: {0}")]
434    ParseError(#[from] prost::DecodeError),
435}
436
437async fn parse_http_response<M: prost::Message + Default>(
438    buf: &mut [u8],
439    rcv: &mut RecvStream,
440) -> Result<M, ParseResponseError> {
441    let mut cursor = 0usize;
442    let mut body_offset = 0usize;
443    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
444        cursor += n;
445        let mut headers = [httparse::EMPTY_HEADER; 16];
446        let mut resp = httparse::Response::new(&mut headers);
447        body_offset = match resp.parse(&buf[..cursor]) {
448            Ok(httparse::Status::Partial) => continue,
449            Ok(httparse::Status::Complete(n)) => n,
450            Err(e) => return Err(ParseResponseError::HTTParseError(e)),
451        };
452    }
453    // we want to keep this method cancel-safe, so we use repeated reads.
454    while let Some(n) = rcv.read(&mut buf[cursor..]).await? {
455        cursor += n;
456    }
457    let m = M::decode(&buf[body_offset..cursor])?;
458    Ok(m)
459}
460
461/// Send control request error.
462#[derive(Debug, thiserror::Error)]
463pub enum SendControlRequestError {
464    /// I/O error.
465    #[error("i/o error: {0}")]
466    IoError(#[from] std::io::Error),
467    /// QUIC closed stream error.
468    #[error("stream closed: {0}")]
469    ClosedStream(#[from] quinn::ClosedStream),
470}
471
472/// Send a control request to the server using `snd` as the request-stream.
473async fn send_control_request(
474    snd: &mut SendStream,
475    method: &str,
476    body: &[u8],
477    token: &str,
478) -> Result<(), SendControlRequestError> {
479    write_all(
480        snd,
481        format!(
482            "POST {method} HTTP/1.1\r\n\
483content-type: application/proto\r\n\
484connect-protocol-version: 1\r\n\
485content-encoding: identity\r\n\
486accept-encoding: identity\r\n\
487content-length: {}\r\n\
488Authorization: Bearer {token}\r\n\r\n",
489            body.len()
490        )
491        .as_bytes(),
492    )
493    .await?;
494    write_all(snd, body).await?;
495    snd.finish()?;
496    Ok(())
497}
498
499// SendStream::write_all is not cancel-safe, so we use loops instead.
500async fn write_all(stream: &mut SendStream, data: &[u8]) -> std::io::Result<()> {
501    let mut cursor = 0;
502    while cursor < data.len() {
503        cursor += stream.write(&data[cursor..]).await?;
504    }
505    Ok(())
506}
507
508/// SNAP tunnel errors.
509#[derive(Debug, thiserror::Error)]
510pub enum SnapTunError {
511    /// Initial token error.
512    #[error("initial token error: {0}")]
513    InitialTokenError(#[from] TokenRenewError),
514    /// Control error.
515    #[error("control error: {0}")]
516    ControlError(#[from] ControlError),
517}
518
519/// SNAP tunnel control errors.
520#[derive(Debug, thiserror::Error)]
521pub enum ControlError {
522    /// QUIC connection error.
523    #[error("quinn connection error: {0}")]
524    ConnectionError(#[from] quinn::ConnectionError),
525    /// Address assignment failed.
526    #[error("address assignment failed: {0}")]
527    AddressAssignmentFailed(#[from] AddrAssignError),
528    /// Parse control request response error.
529    #[error("parse control request response: {0}")]
530    ParseResponse(#[from] ParseResponseError),
531    /// Send control request error.
532    #[error("send control request error: {0}")]
533    SendRequestError(#[from] SendControlRequestError),
534}
535
536/// Address assignment error.
537#[derive(Debug, thiserror::Error)]
538pub enum AddrAssignError {
539    /// Invalid address.
540    #[error("invalid addr: {0}")]
541    InvalidAddr(#[from] AddrError),
542    /// No address assigned.
543    #[error("no address assigned")]
544    NoAddressAssigned,
545}