spacegate_kernel/
listener.rs

1use std::{net::SocketAddr, sync::Arc};
2
3use futures_util::TryFutureExt;
4use tokio_util::sync::CancellationToken;
5use tracing::{instrument, Instrument};
6
7use crate::{service::TcpService, BoxError, BoxResult};
8
9/// Listener embodies the concept of a logical endpoint where a Gateway accepts network connections.
10#[derive(Clone)]
11pub struct SgListen {
12    pub socket_addr: SocketAddr,
13    pub services: Vec<Arc<dyn TcpService>>,
14    pub listener_id: String,
15    cancel_token: CancellationToken,
16}
17
18impl std::fmt::Debug for SgListen {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("SgListen")
21            .field("socket_addr", &self.socket_addr)
22            .field("listener_id", &self.listener_id)
23            .field("services", &self.services.iter().map(|s| s.protocol_name()).collect::<Vec<_>>())
24            .finish_non_exhaustive()
25    }
26}
27
28impl SgListen {
29    /// we only have 65535 ports for a console, so it's a safe size
30    pub fn new(socket_addr: SocketAddr, cancel_token: CancellationToken) -> Self {
31        Self {
32            socket_addr,
33            services: Vec::new(),
34            cancel_token,
35            listener_id: Default::default(),
36        }
37    }
38
39    pub fn with_service<S: TcpService>(mut self, service: S) -> Self {
40        self.services.push(Arc::new(service));
41        self
42    }
43
44    pub fn add_service<S: TcpService>(&mut self, service: S) {
45        self.services.push(Arc::new(service));
46    }
47
48    pub fn with_services(mut self, services: Vec<Arc<dyn TcpService>>) -> Self {
49        self.services.extend(services);
50        self
51    }
52
53    pub fn extend_services(&mut self, services: Vec<Arc<dyn TcpService>>) {
54        self.services.extend(services);
55    }
56
57    pub fn with_listener_id(mut self, listener_id: impl Into<String>) -> Self {
58        self.listener_id = listener_id.into();
59        self
60    }
61}
62
63impl SgListen {
64    /// Spawn the listener on the tokio runtime.
65    ///
66    /// It's a shortcut for `tokio::spawn(listener.listen())`.
67    pub fn spawn(self) -> tokio::task::JoinHandle<Result<(), BoxError>> {
68        tokio::spawn(self.listen())
69    }
70
71    /// Listen on the socket address.
72    #[instrument(skip(self), fields(bind=%self.socket_addr))]
73    pub async fn listen(self) -> Result<(), BoxError> {
74        tracing::debug!("start binding...");
75        let listener = tokio::net::TcpListener::bind(self.socket_addr).await?;
76        let cancel_token = self.cancel_token;
77        tracing::debug!("start listening...");
78        let peek_size = self.services.iter().fold(0, |acc, s| acc.max(s.sniff_peek_size()));
79        let services: Arc<[Arc<dyn TcpService>]> = self.services.clone().into();
80        loop {
81            let accepted = tokio::select! {
82                () = cancel_token.cancelled() => {
83                    tracing::warn!("cancelled");
84                    return Ok(());
85                },
86                accepted = listener.accept() => accepted
87            };
88            match accepted {
89                Ok((stream, peer_addr)) => {
90                    let services = services.clone();
91                    let _task = tokio::spawn(
92                        async move {
93                            let mut peek_buf = vec![0u8; peek_size];
94                            stream.peek(&mut peek_buf).await?;
95                            for s in services.iter() {
96                                if s.sniff(&peek_buf) {
97                                    tracing::debug!(tcp_service=%s.protocol_name(), "accepted");
98                                    s.handle(stream, peer_addr).await?;
99                                    break;
100                                }
101                            }
102                            BoxResult::Ok(())
103                        }
104                        .inspect_err(|e| {
105                            tracing::warn!("TcpService error: {:?}", e);
106                        })
107                        .instrument(tracing::info_span!("connection", peer = %peer_addr)),
108                    );
109                }
110                Err(e) => {
111                    tracing::warn!("Accept tcp connection error: {:?}", e);
112                }
113            }
114        }
115    }
116}