1mod tunnel;
17
18use std::{
19 collections::{
20 BTreeMap,
21 btree_map::Entry::{Occupied, Vacant},
22 },
23 net::SocketAddr,
24 sync::{Arc, Mutex},
25};
26
27use ana_gotatun::{
28 packet::PacketBufPool,
29 x25519::{self, PublicKey},
30};
31
/// Const-generic size parameter of the [`PacketBufPool`] that callers pass to
/// [`SnapTunEndpoint::connect_tunnel`].
///
/// NOTE(review): presumably the per-buffer capacity in bytes (65535 == `u16::MAX`);
/// confirm against `PacketBufPool`.
pub const PACKET_BUF_POOL_SIZE: usize = 65535;

/// Persistent keepalive interval, in seconds, configured on every tunnel created by
/// [`SnapTunEndpoint::connect_tunnel`].
pub const PERSISTENT_KEEPALIVE_SECONDS: u16 = 10;
39
40use scion_sdk_reqwest_connect_rpc::{client::CrpcClientError, token_source::TokenSource};
41use scion_sdk_utils::backoff::{BackoffConfig, ExponentialBackoff};
42use tokio::task::{AbortHandle, JoinSet};
43pub use tunnel::{SnapTunnel, SnapTunnelDriverError, SnapTunnelReceiveError};
44
45#[async_trait::async_trait]
47pub trait SnapTunControlPlaneClient: Send + Sync {
48 async fn register_identity(
50 &self,
51 identity: PublicKey,
52 psk_share: Option<[u8; 32]>,
53 ) -> Result<Option<[u8; 32]>, CrpcClientError>;
54
55 async fn register_identity_with_retries(
57 &self,
58 identity: PublicKey,
59 psk_share: Option<[u8; 32]>,
60 backoff: ExponentialBackoff,
61 max_attempts: u32,
62 ) -> Result<Option<[u8; 32]>, CrpcClientError> {
63 let mut attempt = 0u32;
64 loop {
65 match self.register_identity(identity, psk_share).await {
66 Ok(psk_share) => return Ok(psk_share),
67 Err(e) => {
68 if attempt == max_attempts - 1 {
69 return Err(e);
70 }
71 attempt += 1;
72 tokio::time::sleep(backoff.duration(attempt)).await;
73 }
74 }
75 }
76 }
77}
78
79struct SnapTunControlPlane {
82 address: url::Url,
83 client: Arc<dyn SnapTunControlPlaneClient>,
84 tunnel_count: u64,
85}
86
87impl Clone for SnapTunControlPlane {
88 fn clone(&self) -> Self {
89 Self {
90 address: self.address.clone(),
91 client: self.client.clone(),
92 tunnel_count: self.tunnel_count,
93 }
94 }
95}
96
97struct SnapTunEndpointState {
98 control_planes: Arc<Mutex<BTreeMap<url::Url, SnapTunControlPlane>>>,
99 pub static_private: x25519::StaticSecret,
100 pub static_public: x25519::PublicKey,
101 pub backoff: ExponentialBackoff,
102 pub max_attempts: u32,
103}
104
105pub(super) struct TunnelGuard {
109 endpoint_state: Arc<SnapTunEndpointState>,
110 control_plane: url::Url,
111}
112
113impl Drop for TunnelGuard {
114 fn drop(&mut self) {
115 self.endpoint_state
116 .remove_tunnel(self.control_plane.clone());
117 }
118}
119
120impl SnapTunEndpointState {
121 async fn identity_registration_loop(&self, token_source: Arc<dyn TokenSource>) {
123 let mut watch = token_source.watch();
124 let _ = watch.borrow_and_update();
126 loop {
127 let control_planes = self
129 .control_planes
130 .lock()
131 .expect("lock poisoned")
132 .values()
133 .cloned()
134 .collect::<Vec<_>>();
135 let mut set = JoinSet::new();
136 for control_plane in control_planes {
137 let static_public = self.static_public;
138 let backoff = self.backoff;
139 let max_attempts = self.max_attempts;
140 set.spawn(async move {
141 if let Err(e) = control_plane.client.register_identity_with_retries(static_public, None, backoff, max_attempts).await {
142 tracing::error!(cp_address=%control_plane.address, err=?e, "error registering identity with control plane");
143 }
144 });
145 }
146 set.join_all().await;
147 if watch.changed().await.is_err() {
148 tracing::info!(
149 "token source watch channel closed, stopping identity registration loop"
150 );
151 return;
152 }
153 let r = watch.borrow();
154 if let Some(Ok(r)) = &*r {
155 let token_sig = r.rsplit('.').next().unwrap_or("");
158 tracing::debug!(token_sig, "token renewal in registration loop");
159 }
160 }
161 }
162
163 async fn add_tunnel(
164 self: Arc<Self>,
165 address: url::Url,
166 client: Arc<dyn SnapTunControlPlaneClient>,
167 ) -> Result<TunnelGuard, CrpcClientError> {
168 let new = {
169 let mut control_planes = self.control_planes.lock().expect("lock poisoned");
170 match control_planes.entry(address.clone()) {
171 Occupied(mut entry) => {
172 entry.get_mut().tunnel_count += 1;
173 false
174 }
175 Vacant(entry) => {
176 entry.insert(SnapTunControlPlane {
177 address: address.clone(),
178 client: client.clone(),
179 tunnel_count: 1,
180 });
181 true
182 }
183 }
184 };
185 if new {
186 let static_public = self.static_public;
187 let backoff = self.backoff;
188 let max_attempts = self.max_attempts / 2;
189 client
190 .register_identity_with_retries(static_public, None, backoff, max_attempts)
191 .await?;
192 }
193 Ok(TunnelGuard {
194 endpoint_state: self.clone(),
195 control_plane: address,
196 })
197 }
198
199 fn remove_tunnel(&self, address: url::Url) {
200 let mut control_planes = self.control_planes.lock().unwrap();
201 if let Occupied(mut entry) = control_planes.entry(address) {
202 entry.get_mut().tunnel_count -= 1;
203 if entry.get().tunnel_count == 0 {
204 entry.remove();
205 }
206 }
207 }
208}
209
210pub struct SnapTunEndpoint {
214 state: Arc<SnapTunEndpointState>,
215 identity_registration_abort_handle: AbortHandle,
216}
217
218impl Drop for SnapTunEndpoint {
219 fn drop(&mut self) {
220 self.identity_registration_abort_handle.abort();
221 }
222}
223
224#[derive(Debug, thiserror::Error)]
226pub enum ConnectSnapTunSocketError {
227 #[error("error registering identity with control plane: {0}")]
229 SnapTunControlPlaneClientError(#[from] CrpcClientError),
230 #[error("error connecting snap tunnel: {0}")]
232 SnapTunConnectionError(#[from] SnapTunnelDriverError),
233}
234
235impl SnapTunEndpoint {
236 pub fn new(token_source: Arc<dyn TokenSource>, static_private: x25519::StaticSecret) -> Self {
238 let static_public = x25519::PublicKey::from(&static_private);
239 let state = Arc::new(SnapTunEndpointState {
240 control_planes: Arc::new(Mutex::new(BTreeMap::new())),
241 static_private,
242 static_public,
243 backoff: ExponentialBackoff::new_from_config(BackoffConfig {
244 minimum_delay_secs: 1.0,
245 maximum_delay_secs: 20.0,
246 factor: 1.2,
247 jitter_secs: 0.1,
248 }),
249 max_attempts: 10,
250 });
251 let state_clone = state.clone();
252 let abort_handle =
253 tokio::spawn(async move { state_clone.identity_registration_loop(token_source).await })
254 .abort_handle();
255 Self {
256 state,
257 identity_registration_abort_handle: abort_handle,
258 }
259 }
260
261 pub async fn connect_tunnel(
264 &self,
265 peer_public: x25519::PublicKey,
266 dataplane_address: SocketAddr,
267 control_plane: url::Url,
268 control_plane_client: Arc<dyn SnapTunControlPlaneClient>,
269 underlay_socket: Arc<tokio::net::UdpSocket>,
270 receive_queue_capacity: usize,
271 pool: PacketBufPool<PACKET_BUF_POOL_SIZE>,
272 ) -> Result<SnapTunnel, ConnectSnapTunSocketError> {
273 let guard = self
274 .state
275 .clone()
276 .add_tunnel(control_plane, control_plane_client)
277 .await?;
278 Ok(SnapTunnel::new(
279 guard,
280 self.state.static_private.clone(),
281 peer_public,
282 underlay_socket,
283 dataplane_address,
284 receive_queue_capacity,
285 Some(PERSISTENT_KEEPALIVE_SECONDS),
286 pool,
287 )
288 .await?)
289 }
290}