1use std::net::{Ipv4Addr, SocketAddr};
8use std::sync::Arc;
9use std::time::Duration;
10
11use futures_util::{SinkExt, StreamExt};
12use sha2::{Digest, Sha256};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc;
15use tokio::time::{interval, timeout};
16use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage, WebSocketStream};
17use uuid::Uuid;
18
19use crate::overlay::DynTunnelDnsRegistrar;
20use crate::{
21 ControlMessage, Message, Result, ServiceProtocol, TunnelError, TunnelRegistry,
22 TunnelServerConfig,
23};
24
25pub type TokenValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
27
/// Milliseconds elapsed since the Unix epoch, used as the heartbeat timestamp.
///
/// A system clock set before 1970 yields `0`. A value too large for `u64`
/// (not reachable in practice) saturates to `u64::MAX`.
fn current_timestamp_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
42
/// Serves the control channel of tunnel clients.
///
/// For each accepted TCP connection the handler performs the WebSocket
/// handshake and authenticates the client. It then registers the tunnel in the
/// [`TunnelRegistry`] and relays control and heartbeat traffic until the session
/// ends (see `handle_connection`).
pub struct ControlHandler {
    /// Shared registry in which tunnels and their services are recorded.
    registry: Arc<TunnelRegistry>,
    /// Server settings; the heartbeat interval and timeout are read from here.
    config: TunnelServerConfig,

    /// Checks the raw token from the client's `Auth` message; `Err` rejects it.
    token_validator: TokenValidator,

    /// Optional overlay DNS registrar. Services are published only when this
    /// and `local_overlay_ip` are both set.
    dns_registrar: Option<DynTunnelDnsRegistrar>,

    /// Overlay IP used as the address in published DNS records.
    local_overlay_ip: Option<Ipv4Addr>,
}
81
82impl ControlHandler {
83 #[must_use]
91 pub fn new(
92 registry: Arc<TunnelRegistry>,
93 config: TunnelServerConfig,
94 token_validator: TokenValidator,
95 ) -> Self {
96 Self {
97 registry,
98 config,
99 token_validator,
100 dns_registrar: None,
101 local_overlay_ip: None,
102 }
103 }
104
105 #[must_use]
107 pub fn with_dns_registrar(mut self, registrar: DynTunnelDnsRegistrar) -> Self {
108 self.dns_registrar = Some(registrar);
109 self
110 }
111
112 #[must_use]
114 pub fn with_local_overlay_ip(mut self, ip: Ipv4Addr) -> Self {
115 self.local_overlay_ip = Some(ip);
116 self
117 }
118
119 pub async fn handle_connection(
142 &self,
143 stream: TcpStream,
144 client_addr: SocketAddr,
145 ) -> Result<()> {
146 let ws_stream = accept_async(stream)
148 .await
149 .map_err(TunnelError::connection)?;
150
151 let (mut ws_sink, mut ws_stream) = ws_stream.split();
152
153 let auth_timeout = Duration::from_secs(10);
155 let auth_msg = timeout(auth_timeout, async {
156 while let Some(msg) = ws_stream.next().await {
157 match msg {
158 Ok(WsMessage::Binary(data)) => {
159 return Message::decode(&data).map(|(m, _)| m);
160 }
161 Ok(WsMessage::Close(_)) => {
162 return Err(TunnelError::connection_msg("Client closed connection"));
163 }
164 Ok(_) => {} Err(e) => return Err(TunnelError::connection(e)),
166 }
167 }
168 Err(TunnelError::connection_msg("Connection closed before auth"))
169 })
170 .await
171 .map_err(|_| TunnelError::timeout())??;
172
173 let Message::Auth {
175 token,
176 client_id: _,
177 } = auth_msg
178 else {
179 let fail = Message::AuthFail {
180 reason: "Expected AUTH message".to_string(),
181 };
182 let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
183 return Err(TunnelError::auth("Expected AUTH message"));
184 };
185
186 if let Err(e) = (self.token_validator)(&token) {
188 let fail = Message::AuthFail {
189 reason: e.to_string(),
190 };
191 let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
192 return Err(e);
193 }
194
195 let token_hash = hash_token(&token);
197
198 if self.registry.token_exists(&token_hash) {
200 let fail = Message::AuthFail {
201 reason: "Token already in use".to_string(),
202 };
203 let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
204 return Err(TunnelError::auth("Token already in use"));
205 }
206
207 let (control_tx, mut control_rx) = mpsc::channel::<ControlMessage>(256);
209
210 let tunnel = self.registry.register_tunnel(
212 token_hash.clone(),
213 None, control_tx,
215 Some(client_addr),
216 )?;
217
218 let tunnel_id = tunnel.id;
219
220 let auth_ok = Message::AuthOk { tunnel_id };
222 ws_sink
223 .send(WsMessage::Binary(auth_ok.encode().into()))
224 .await
225 .map_err(TunnelError::connection)?;
226
227 tracing::info!(
228 tunnel_id = %tunnel_id,
229 client_addr = %client_addr,
230 "Tunnel authenticated"
231 );
232
233 let result = self
235 .run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream, &mut control_rx)
236 .await;
237
238 self.registry.unregister_tunnel(tunnel_id);
240
241 tracing::info!(tunnel_id = %tunnel_id, "Tunnel disconnected");
242
243 result
244 }
245
246 async fn run_message_loop(
253 &self,
254 tunnel_id: Uuid,
255 ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
256 ws_stream: &mut futures_util::stream::SplitStream<WebSocketStream<TcpStream>>,
257 control_rx: &mut mpsc::Receiver<ControlMessage>,
258 ) -> Result<()> {
259 let mut heartbeat_interval = interval(self.config.heartbeat_interval);
260 let heartbeat_timeout = self.config.heartbeat_timeout;
261 let mut last_heartbeat_ack = std::time::Instant::now();
262
263 loop {
264 tokio::select! {
265 _ = heartbeat_interval.tick() => {
267 if last_heartbeat_ack.elapsed() > heartbeat_timeout {
269 tracing::warn!(tunnel_id = %tunnel_id, "Heartbeat timeout");
270 return Err(TunnelError::timeout());
271 }
272
273 let timestamp = current_timestamp_ms();
275 let hb = Message::Heartbeat { timestamp };
276 ws_sink
277 .send(WsMessage::Binary(hb.encode().into()))
278 .await
279 .map_err(TunnelError::connection)?;
280 }
281
282 Some(ctrl_msg) = control_rx.recv() => {
284 let msg = match ctrl_msg {
285 ControlMessage::Connect {
286 service_id,
287 connection_id,
288 client_addr,
289 } => Message::Connect {
290 service_id,
291 connection_id,
292 client_addr: client_addr.to_string(),
293 },
294 ControlMessage::Heartbeat { timestamp } => {
295 Message::Heartbeat { timestamp }
296 }
297 ControlMessage::Disconnect { reason } => {
298 let _ = ws_sink
299 .send(WsMessage::Binary(
300 Message::Disconnect { reason }.encode().into(),
301 ))
302 .await;
303 return Ok(());
304 }
305 };
306 ws_sink
307 .send(WsMessage::Binary(msg.encode().into()))
308 .await
309 .map_err(TunnelError::connection)?;
310 }
311
312 Some(msg_result) = ws_stream.next() => {
314 match msg_result {
315 Ok(WsMessage::Binary(data)) => {
316 let (msg, _) = Message::decode(&data)?;
317
318 if matches!(msg, Message::HeartbeatAck { .. }) {
320 last_heartbeat_ack = std::time::Instant::now();
321 }
322
323 self.handle_client_message(tunnel_id, msg, ws_sink).await?;
324
325 self.registry.touch_tunnel(tunnel_id);
327 }
328 Ok(WsMessage::Close(_)) => {
329 return Ok(());
330 }
331 Ok(WsMessage::Ping(data)) => {
332 ws_sink
333 .send(WsMessage::Pong(data))
334 .await
335 .map_err(TunnelError::connection)?;
336 }
337 Ok(_) => {} Err(e) => {
339 return Err(TunnelError::connection(e));
340 }
341 }
342 }
343
344 else => break,
345 }
346 }
347
348 Ok(())
349 }
350
351 #[allow(clippy::too_many_lines)]
353 async fn handle_client_message(
354 &self,
355 tunnel_id: Uuid,
356 msg: Message,
357 ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
358 ) -> Result<()> {
359 match msg {
360 Message::Register {
361 name,
362 protocol,
363 local_port,
364 remote_port,
365 } => {
366 self.handle_register(tunnel_id, &name, protocol, local_port, remote_port, ws_sink)
367 .await?;
368 }
369
370 Message::Unregister { service_id } => {
371 if let Err(e) = self.registry.remove_service(tunnel_id, service_id) {
372 tracing::warn!(
373 tunnel_id = %tunnel_id,
374 service_id = %service_id,
375 error = %e,
376 "Service unregistration failed"
377 );
378 } else {
379 tracing::info!(
380 tunnel_id = %tunnel_id,
381 service_id = %service_id,
382 "Service unregistered"
383 );
384 }
385 }
386
387 Message::ConnectAck { connection_id } => {
388 tracing::debug!(
389 tunnel_id = %tunnel_id,
390 connection_id = %connection_id,
391 "Connection acknowledged"
392 );
393 }
395
396 Message::ConnectFail {
397 connection_id,
398 reason,
399 } => {
400 tracing::warn!(
401 tunnel_id = %tunnel_id,
402 connection_id = %connection_id,
403 reason = %reason,
404 "Connection failed"
405 );
406 }
407
408 Message::HeartbeatAck { timestamp } => {
409 let now = current_timestamp_ms();
410 let latency_ms = now.saturating_sub(timestamp);
411 tracing::trace!(
412 tunnel_id = %tunnel_id,
413 latency_ms = latency_ms,
414 "Heartbeat ack received"
415 );
416 }
417
418 Message::Auth { .. }
420 | Message::AuthOk { .. }
421 | Message::AuthFail { .. }
422 | Message::RegisterOk { .. }
423 | Message::RegisterFail { .. }
424 | Message::Connect { .. }
425 | Message::Heartbeat { .. }
426 | Message::Disconnect { .. } => {
427 tracing::warn!(
428 tunnel_id = %tunnel_id,
429 msg_type = ?msg.message_type(),
430 "Unexpected message from client"
431 );
432 }
433 }
434
435 Ok(())
436 }
437
438 async fn handle_register(
440 &self,
441 tunnel_id: Uuid,
442 name: &str,
443 protocol: ServiceProtocol,
444 local_port: u16,
445 remote_port: u16,
446 ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
447 ) -> Result<()> {
448 let result = self
449 .registry
450 .add_service(tunnel_id, name, protocol, local_port, remote_port);
451
452 let response = match result {
453 Ok(service) => {
454 let assigned_port = service.assigned_port.unwrap_or(remote_port);
455 tracing::info!(
456 tunnel_id = %tunnel_id,
457 service_name = %name,
458 local_port = local_port,
459 remote_port = assigned_port,
460 "Service registered"
461 );
462
463 if let (Some(ref registrar), Some(overlay_ip)) =
465 (&self.dns_registrar, self.local_overlay_ip)
466 {
467 let dns_name = format!("tun-{name}");
468 if let Err(e) = registrar
469 .register_service(&dns_name, overlay_ip, assigned_port)
470 .await
471 {
472 tracing::warn!(
473 service_name = %name,
474 dns_name = %dns_name,
475 error = %e,
476 "Failed to register service in overlay DNS"
477 );
478 } else {
479 tracing::debug!(
480 dns_name = %dns_name,
481 overlay_ip = %overlay_ip,
482 port = assigned_port,
483 "Registered service in overlay DNS"
484 );
485 }
486 }
487
488 Message::RegisterOk {
489 service_id: service.id,
490 }
491 }
492 Err(e) => {
493 tracing::warn!(
494 tunnel_id = %tunnel_id,
495 service_name = %name,
496 error = %e,
497 "Service registration failed"
498 );
499 Message::RegisterFail {
500 reason: e.to_string(),
501 }
502 }
503 };
504
505 ws_sink
506 .send(WsMessage::Binary(response.encode().into()))
507 .await
508 .map_err(TunnelError::connection)?;
509
510 Ok(())
511 }
512}
513
514#[must_use]
528pub fn hash_token(token: &str) -> String {
529 let mut hasher = Sha256::new();
530 hasher.update(token.as_bytes());
531 hex::encode(hasher.finalize())
532}
533
534pub fn accept_all_tokens(token: &str) -> Result<()> {
553 if token.is_empty() {
554 return Err(TunnelError::auth("Token cannot be empty"));
555 }
556 Ok(())
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_token_consistent() {
        let token = "my-secret-token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);

        assert_eq!(hash1, hash2);
        // SHA-256 is 32 bytes, i.e. 64 hex characters.
        assert_eq!(hash1.len(), 64);
    }

    #[test]
    fn test_hash_token_different_tokens() {
        let hash1 = hash_token("token1");
        let hash2 = hash_token("token2");

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_hash_token_empty() {
        let hash = hash_token("");
        assert_eq!(hash.len(), 64);
    }

    #[test]
    fn test_hash_token_known_value() {
        // Published SHA-256 digest of the ASCII string "test".
        let hash = hash_token("test");
        assert_eq!(
            hash,
            "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
        );
    }

    #[test]
    fn test_hash_token_is_lowercase_hex() {
        let hash = hash_token("anything");
        assert!(hash
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_accept_all_tokens_valid() {
        assert!(accept_all_tokens("valid-token").is_ok());
        assert!(accept_all_tokens("a").is_ok());
        assert!(accept_all_tokens("very-long-token-with-many-characters").is_ok());
    }

    #[test]
    fn test_accept_all_tokens_empty() {
        let result = accept_all_tokens("");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_control_handler_creation() {
        let registry = Arc::new(TunnelRegistry::default());
        let config = TunnelServerConfig::default();
        let validator: TokenValidator = Arc::new(accept_all_tokens);

        let handler = ControlHandler::new(Arc::clone(&registry), config, validator);

        // The handler must share the caller's registry, not hold a private copy.
        // The old `strong_count >= 1` check was vacuous for any live Arc.
        assert!(Arc::ptr_eq(&handler.registry, &registry));
        assert_eq!(Arc::strong_count(&registry), 2);
        assert!(handler.dns_registrar.is_none());
        assert!(handler.local_overlay_ip.is_none());
    }

    #[test]
    fn test_control_handler_with_local_overlay_ip() {
        let registry = Arc::new(TunnelRegistry::default());
        let config = TunnelServerConfig::default();
        let validator: TokenValidator = Arc::new(accept_all_tokens);
        let ip = Ipv4Addr::new(10, 200, 0, 1);

        let handler = ControlHandler::new(registry, config, validator).with_local_overlay_ip(ip);

        assert_eq!(handler.local_overlay_ip, Some(ip));
        assert!(handler.dns_registrar.is_none());
    }

    #[test]
    fn test_current_timestamp_ms_is_recent() {
        // 1_600_000_000_000 ms is September 2020; any working clock is past it.
        assert!(current_timestamp_ms() > 1_600_000_000_000);
    }

    #[test]
    fn test_hash_token_unicode() {
        let hash = hash_token("token-with-unicode-\u{1F600}");
        assert_eq!(hash.len(), 64);
    }

    #[test]
    fn test_hash_token_special_chars() {
        let hash = hash_token("token!@#$%^&*()");
        assert_eq!(hash.len(), 64);
    }
}