Skip to main content

snap_tun/
client.rs

1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! SNAP tunnel client.
15
16mod tunnel;
17
18use std::{
19    collections::{
20        BTreeMap,
21        btree_map::Entry::{Occupied, Vacant},
22    },
23    net::SocketAddr,
24    sync::{Arc, Mutex},
25};
26
27use ana_gotatun::{
28    packet::PacketBufPool,
29    x25519::{self, PublicKey},
30};
31
32/// Size of the packet buffer pool used for SNAP tunnels.
33/// This represents the maximum packet size that can be handled.
34pub const PACKET_BUF_POOL_SIZE: usize = 65535;
35
/// Default persistent keepalive interval, in seconds, for SNAP tunnels (10 seconds).
/// This should be short enough to keep NAT mappings alive.
38pub const PERSISTENT_KEEPALIVE_SECONDS: u16 = 10;
39
40use scion_sdk_reqwest_connect_rpc::{client::CrpcClientError, token_source::TokenSource};
41use scion_sdk_utils::backoff::{BackoffConfig, ExponentialBackoff};
42use tokio::task::{AbortHandle, JoinSet};
43pub use tunnel::{SnapTunnel, SnapTunnelDriverError, SnapTunnelReceiveError};
44
45/// Trait for a control plane client.
46#[async_trait::async_trait]
47pub trait SnapTunControlPlaneClient: Send + Sync {
48    /// Register an identity with the control plane.
49    async fn register_identity(
50        &self,
51        identity: PublicKey,
52        psk_share: Option<[u8; 32]>,
53    ) -> Result<Option<[u8; 32]>, CrpcClientError>;
54
55    /// Register an identity with the control plane with retries.
56    async fn register_identity_with_retries(
57        &self,
58        identity: PublicKey,
59        psk_share: Option<[u8; 32]>,
60        backoff: ExponentialBackoff,
61        max_attempts: u32,
62    ) -> Result<Option<[u8; 32]>, CrpcClientError> {
63        let mut attempt = 0u32;
64        loop {
65            match self.register_identity(identity, psk_share).await {
66                Ok(psk_share) => return Ok(psk_share),
67                Err(e) => {
68                    if attempt == max_attempts - 1 {
69                        return Err(e);
70                    }
71                    attempt += 1;
72                    tokio::time::sleep(backoff.duration(attempt)).await;
73                }
74            }
75        }
76    }
77}
78
79/// Struct to hold information about a snap-tun control plane
80/// and the number of tunnels connected to it.
81struct SnapTunControlPlane {
82    address: url::Url,
83    client: Arc<dyn SnapTunControlPlaneClient>,
84    tunnel_count: u64,
85}
86
87impl Clone for SnapTunControlPlane {
88    fn clone(&self) -> Self {
89        Self {
90            address: self.address.clone(),
91            client: self.client.clone(),
92            tunnel_count: self.tunnel_count,
93        }
94    }
95}
96
/// State shared between a SNAP tunnel endpoint, its identity registration task
/// and the guards of its tunnels.
struct SnapTunEndpointState {
    /// Control planes currently managed by the endpoint, keyed by their URL.
    control_planes: Arc<Mutex<BTreeMap<url::Url, SnapTunControlPlane>>>,
    /// Static private key of the endpoint; handed to every tunnel that is created.
    pub static_private: x25519::StaticSecret,
    /// Public key derived from `static_private`; this is the identity that is
    /// registered with the control planes.
    pub static_public: x25519::PublicKey,
    /// Backoff used between identity registration retries.
    pub backoff: ExponentialBackoff,
    /// Attempt budget passed to the identity registration retries.
    pub max_attempts: u32,
}
104
/// Guard that decrements the tunnel count when dropped.
/// When the tunnel count reaches 0, the control plane is removed from the map
/// of managed control planes.
pub(super) struct TunnelGuard {
    /// Endpoint state whose tunnel accounting is updated on drop.
    endpoint_state: Arc<SnapTunEndpointState>,
    /// URL of the control plane this guard is accounted to.
    control_plane: url::Url,
}
112
113impl Drop for TunnelGuard {
114    fn drop(&mut self) {
115        self.endpoint_state
116            .remove_tunnel(self.control_plane.clone());
117    }
118}
119
120impl SnapTunEndpointState {
121    /// Task to register the endpoints identity with all control planes
122    async fn identity_registration_loop(&self, token_source: Arc<dyn TokenSource>) {
123        let mut watch = token_source.watch();
124        // drop first bogus update, as initial registration is done in add_tunnel()
125        let _ = watch.borrow_and_update();
126        loop {
127            // register the identity with all managed control planes.
128            let control_planes = self
129                .control_planes
130                .lock()
131                .expect("lock poisoned")
132                .values()
133                .cloned()
134                .collect::<Vec<_>>();
135            let mut set = JoinSet::new();
136            for control_plane in control_planes {
137                let static_public = self.static_public;
138                let backoff = self.backoff;
139                let max_attempts = self.max_attempts;
140                set.spawn(async move {
141                    if let Err(e) = control_plane.client.register_identity_with_retries(static_public, None, backoff, max_attempts).await {
142                        tracing::error!(cp_address=%control_plane.address, err=?e, "error registering identity with control plane");
143                    }
144                });
145            }
146            set.join_all().await;
147            if watch.changed().await.is_err() {
148                tracing::info!(
149                    "token source watch channel closed, stopping identity registration loop"
150                );
151                return;
152            }
153            let r = watch.borrow();
154            if let Some(Ok(r)) = &*r {
155                // assume token is a JWT-token, the signature is a unique
156                // identifier for this token.
157                let token_sig = r.rsplit('.').next().unwrap_or("");
158                tracing::debug!(token_sig, "token renewal in registration loop");
159            }
160        }
161    }
162
163    async fn add_tunnel(
164        self: Arc<Self>,
165        address: url::Url,
166        client: Arc<dyn SnapTunControlPlaneClient>,
167    ) -> Result<TunnelGuard, CrpcClientError> {
168        let new = {
169            let mut control_planes = self.control_planes.lock().expect("lock poisoned");
170            match control_planes.entry(address.clone()) {
171                Occupied(mut entry) => {
172                    entry.get_mut().tunnel_count += 1;
173                    false
174                }
175                Vacant(entry) => {
176                    entry.insert(SnapTunControlPlane {
177                        address: address.clone(),
178                        client: client.clone(),
179                        tunnel_count: 1,
180                    });
181                    true
182                }
183            }
184        };
185        if new {
186            let static_public = self.static_public;
187            let backoff = self.backoff;
188            let max_attempts = self.max_attempts / 2;
189            client
190                .register_identity_with_retries(static_public, None, backoff, max_attempts)
191                .await?;
192        }
193        Ok(TunnelGuard {
194            endpoint_state: self.clone(),
195            control_plane: address,
196        })
197    }
198
199    fn remove_tunnel(&self, address: url::Url) {
200        let mut control_planes = self.control_planes.lock().unwrap();
201        if let Occupied(mut entry) = control_planes.entry(address) {
202            entry.get_mut().tunnel_count -= 1;
203            if entry.get().tunnel_count == 0 {
204                entry.remove();
205            }
206        }
207    }
208}
209
/// Snap tunnel endpoint that allows creating new snap tun connections.
/// It holds one static identity and manages the registration of this identity with all connected
/// control planes.
pub struct SnapTunEndpoint {
    /// State shared with the identity registration task and the tunnel guards.
    state: Arc<SnapTunEndpointState>,
    /// Handle used to abort the identity registration task when the endpoint is dropped.
    identity_registration_abort_handle: AbortHandle,
}
217
218impl Drop for SnapTunEndpoint {
219    fn drop(&mut self) {
220        self.identity_registration_abort_handle.abort();
221    }
222}
223
/// Error when connecting to a SNAP tunnel.
///
/// Both variants implement `From` for their source error, so the source errors
/// can be propagated with `?`.
#[derive(Debug, thiserror::Error)]
pub enum ConnectSnapTunSocketError {
    /// Error when connecting to the snaptun control plane to register the identity.
    #[error("error registering identity with control plane: {0}")]
    SnapTunControlPlaneClientError(#[from] CrpcClientError),
    /// Error when creating the SNAP tunnel connection.
    #[error("error connecting snap tunnel: {0}")]
    SnapTunConnectionError(#[from] SnapTunnelDriverError),
}
234
235impl SnapTunEndpoint {
236    /// Creates a new SNAP tunnel socket manager.
237    pub fn new(token_source: Arc<dyn TokenSource>, static_private: x25519::StaticSecret) -> Self {
238        let static_public = x25519::PublicKey::from(&static_private);
239        let state = Arc::new(SnapTunEndpointState {
240            control_planes: Arc::new(Mutex::new(BTreeMap::new())),
241            static_private,
242            static_public,
243            backoff: ExponentialBackoff::new_from_config(BackoffConfig {
244                minimum_delay_secs: 1.0,
245                maximum_delay_secs: 20.0,
246                factor: 1.2,
247                jitter_secs: 0.1,
248            }),
249            max_attempts: 10,
250        });
251        let state_clone = state.clone();
252        let abort_handle =
253            tokio::spawn(async move { state_clone.identity_registration_loop(token_source).await })
254                .abort_handle();
255        Self {
256            state,
257            identity_registration_abort_handle: abort_handle,
258        }
259    }
260
261    /// Connects a new SNAP tunnel. If the endpoints static identity is not already registered with
262    /// the selected snap-tun control plane, it is registered before this method returns.
263    pub async fn connect_tunnel(
264        &self,
265        peer_public: x25519::PublicKey,
266        dataplane_address: SocketAddr,
267        control_plane: url::Url,
268        control_plane_client: Arc<dyn SnapTunControlPlaneClient>,
269        underlay_socket: Arc<tokio::net::UdpSocket>,
270        receive_queue_capacity: usize,
271        pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
272    ) -> Result<SnapTunnel, ConnectSnapTunSocketError> {
273        let guard = self
274            .state
275            .clone()
276            .add_tunnel(control_plane, control_plane_client)
277            .await?;
278        Ok(SnapTunnel::new(
279            guard,
280            self.state.static_private.clone(),
281            peer_public,
282            underlay_socket,
283            dataplane_address,
284            receive_queue_capacity,
285            Some(PERSISTENT_KEEPALIVE_SECONDS),
286            pool,
287        )
288        .await?)
289    }
290}