Skip to main content

trojan_server/
server.rs

1//! Main server loop and connection handling.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_metrics::{
26    ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
27    record_connection_rejected, record_error, record_tls_handshake_duration,
28    set_connection_queue_depth,
29};
30
/// How long a graceful shutdown waits for in-flight connections to drain
/// before giving up and returning anyway (30 seconds).
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(30_000);
33
/// Process-wide counter backing [`next_conn_id`]; the first ID issued is 1.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hand out a connection ID that no earlier call in this process has returned
/// (short of `u64` wrap-around).
///
/// The IDs only label tracing spans, so uniqueness is all that matters.
/// `fetch_add` is atomic under any ordering, which is why `Relaxed` is enough.
#[inline]
fn next_conn_id() -> u64 {
    AtomicU64::fetch_add(&CONN_ID, 1, Ordering::Relaxed)
}
42
43/// Run the server with a cancellation token for graceful shutdown.
44pub async fn run_with_shutdown(
45    config: Config,
46    auth: impl AuthBackend + 'static,
47    shutdown: CancellationToken,
48) -> Result<(), ServerError> {
49    let tls_config = load_tls_config(&config.tls)?;
50    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
51
52    let listen: SocketAddr = config
53        .server
54        .listen
55        .parse()
56        .map_err(|_| ServerError::Config("invalid listen address".into()))?;
57
58    let fallback_addr = resolve_sockaddr(&config.server.fallback, config.server.tcp.prefer_ipv4).await?;
59
60    // Initialize fallback connection pool if configured
61    let fallback_pool: Option<Arc<ConnectionPool>> =
62        config.server.fallback_pool.as_ref().map(|pool_cfg| {
63            info!(
64                max_idle = pool_cfg.max_idle,
65                max_age_secs = pool_cfg.max_age_secs,
66                fill_batch = pool_cfg.fill_batch,
67                fill_delay_ms = pool_cfg.fill_delay_ms,
68                "fallback connection pool enabled"
69            );
70            let pool = Arc::new(ConnectionPool::new(
71                fallback_addr,
72                pool_cfg.max_idle,
73                pool_cfg.max_age_secs,
74                pool_cfg.fill_batch,
75                pool_cfg.fill_delay_ms,
76            ));
77            // Use max_age_secs as cleanup interval
78            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
79            pool
80        });
81
82    // Extract resource limits with defaults
83    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
84        match &config.server.resource_limits {
85            Some(rl) => {
86                info!(
87                    relay_buffer = rl.relay_buffer_size,
88                    tcp_send_buffer = rl.tcp_send_buffer,
89                    tcp_recv_buffer = rl.tcp_recv_buffer,
90                    connection_backlog = rl.connection_backlog,
91                    "resource limits configured"
92                );
93                (
94                    rl.relay_buffer_size,
95                    rl.tcp_send_buffer,
96                    rl.tcp_recv_buffer,
97                    rl.connection_backlog,
98                )
99            }
100            None => (
101                defaults::DEFAULT_RELAY_BUFFER_SIZE,
102                defaults::DEFAULT_TCP_SEND_BUFFER,
103                defaults::DEFAULT_TCP_RECV_BUFFER,
104                defaults::DEFAULT_CONNECTION_BACKLOG,
105            ),
106        };
107
108    // Initialize analytics if feature enabled and configured
109    #[cfg(feature = "analytics")]
110    let analytics = if config.analytics.enabled {
111        match trojan_analytics::init(config.analytics.clone()).await {
112            Ok(collector) => {
113                info!("analytics enabled, sending to ClickHouse");
114                Some(collector)
115            }
116            Err(e) => {
117                warn!("failed to init analytics: {}, disabled", e);
118                None
119            }
120        }
121    } else {
122        debug!("analytics disabled in config");
123        None
124    };
125
126    // Log TCP options
127    let tcp_cfg = &config.server.tcp;
128    info!(
129        no_delay = tcp_cfg.no_delay,
130        keepalive_secs = tcp_cfg.keepalive_secs,
131        reuse_port = tcp_cfg.reuse_port,
132        fast_open = tcp_cfg.fast_open,
133        "TCP options configured"
134    );
135
136    let state = Arc::new(ServerState {
137        fallback_addr,
138        max_udp_payload: config.server.max_udp_payload,
139        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
140        max_header_bytes: config.server.max_header_bytes,
141        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
142        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
143        fallback_pool,
144        relay_buffer_size,
145        tcp_send_buffer,
146        tcp_recv_buffer,
147        tcp_config: config.server.tcp.clone(),
148        websocket: config.websocket.clone(),
149        #[cfg(feature = "analytics")]
150        analytics,
151    });
152    let auth = Arc::new(auth);
153    let tracker = ConnectionTracker::new();
154
155    // Connection limiter (None = unlimited)
156    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
157        info!("max_connections set to {}", n);
158        Arc::new(Semaphore::new(n))
159    });
160
161    // Rate limiter (None = disabled)
162    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
163        info!(
164            max_per_ip = rl.max_connections_per_ip,
165            window_secs = rl.window_secs,
166            "rate limiting enabled"
167        );
168        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
169        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
170        limiter
171    });
172
173    // Create listener with custom backlog and TCP options using socket2
174    let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
175    info!(address = %listen, backlog = connection_backlog, "listening");
176
177    #[cfg(feature = "ws")]
178    if config.websocket.enabled && config.websocket.mode == "split" {
179        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
180        let ws_addr: SocketAddr = ws_listen
181            .parse()
182            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
183        let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
184        let ws_acceptor = acceptor.clone();
185        let ws_state = state.clone();
186        let ws_auth = auth.clone();
187        let ws_tracker = tracker.clone();
188        let ws_conn_limit = conn_limit.clone();
189        let ws_rate_limiter = rate_limiter.clone();
190        let ws_shutdown = shutdown.clone();
191
192        info!(address = %ws_addr, "websocket split listener started");
193        tokio::spawn(async move {
194            loop {
195                tokio::select! {
196                    biased;
197                    _ = ws_shutdown.cancelled() => break,
198                    result = ws_listener.accept() => {
199                        let (tcp, peer) = match result {
200                            Ok(v) => v,
201                            Err(_) => continue,
202                        };
203
204                        // Apply TCP socket options
205                        if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
206                            tracing::debug!(error = %e, "failed to apply TCP options");
207                        }
208
209                        if let Some(ref limiter) = ws_rate_limiter {
210                            let ip = peer.ip();
211                            if !limiter.check_and_increment(ip) {
212                                record_connection_rejected("rate_limit");
213                                drop(tcp);
214                                continue;
215                            }
216                        }
217
218                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
219                            Some(sem) => match sem.clone().try_acquire_owned() {
220                                Ok(p) => Some(p),
221                                Err(_) => {
222                                    record_connection_rejected("max_connections");
223                                    drop(tcp);
224                                    continue;
225                                }
226                            },
227                            None => None,
228                        };
229
230                        let conn_id = next_conn_id();
231                        let acceptor = ws_acceptor.clone();
232                        let state = ws_state.clone();
233                        let auth = ws_auth.clone();
234                        ws_tracker.increment();
235                        let guard = ConnectionGuard::new(ws_tracker.clone());
236
237                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
238                        tokio::spawn(
239                            async move {
240                                let _guard = guard;
241                                let _permit = permit;
242                                record_connection_accepted();
243                                let start = Instant::now();
244
245                                let result = async {
246                                    let tls_start = Instant::now();
247                                    let tls_timeout =
248                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
249                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
250                                    {
251                                        Ok(Ok(tls)) => {
252                                            let tls_duration = tls_start.elapsed().as_secs_f64();
253                                            record_tls_handshake_duration(tls_duration);
254                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
255                                        }
256                                        Ok(Err(err)) => {
257                                            record_error(ERROR_TLS_HANDSHAKE);
258                                            warn!(error = %err, "TLS handshake failed");
259                                            Ok(())
260                                        }
261                                        Err(_) => {
262                                            record_error(ERROR_TLS_HANDSHAKE);
263                                            warn!(
264                                                timeout_secs = tls_timeout.as_secs(),
265                                                "TLS handshake timed out"
266                                            );
267                                            Ok(())
268                                        }
269                                    }
270                                }
271                                .await;
272
273                                let duration_secs = start.elapsed().as_secs_f64();
274                                record_connection_closed(duration_secs);
275
276                                if let Err(ref err) = result {
277                                    warn!(error = %err, "connection error");
278                                }
279                            }
280                            .instrument(span),
281                        );
282                    }
283                }
284            }
285        });
286    }
287
288    #[cfg(not(feature = "ws"))]
289    if config.websocket.enabled {
290        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
291    }
292
293    loop {
294        tokio::select! {
295            biased;
296
297            _ = shutdown.cancelled() => {
298                info!("shutdown signal received, stopping accept loop");
299                break;
300            }
301
302            result = listener.accept() => {
303                let (tcp, peer) = result?;
304
305                // Apply TCP socket options (no_delay, keepalive)
306                if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
307                    debug!(error = %e, "failed to apply TCP options");
308                }
309
310                // Update connection queue depth metric (based on semaphore usage)
311                if let Some(ref sem) = conn_limit {
312                    let available = sem.available_permits();
313                    set_connection_queue_depth(available as f64);
314                }
315
316                // Check rate limit first
317                if let Some(ref limiter) = rate_limiter {
318                    let ip = peer.ip();
319                    if !limiter.check_and_increment(ip) {
320                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
321                        record_connection_rejected("rate_limit");
322                        drop(tcp);
323                        continue;
324                    }
325                }
326
327                // Try to acquire connection permit
328                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
329                    Some(sem) => match sem.clone().try_acquire_owned() {
330                        Ok(p) => Some(p),
331                        Err(_) => {
332                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
333                            record_connection_rejected("max_connections");
334                            drop(tcp); // close immediately
335                            continue;
336                        }
337                    },
338                    None => None,
339                };
340
341                let conn_id = next_conn_id();
342                debug!(conn_id, peer = %peer, "new connection");
343
344                let acceptor = acceptor.clone();
345                let state = state.clone();
346                let auth = auth.clone();
347                tracker.increment();
348                let guard = ConnectionGuard::new(tracker.clone());
349
350                let span = info_span!("conn", id = conn_id, peer = %peer);
351                tokio::spawn(
352                    async move {
353                        let _guard = guard; // ensure decrement on drop
354                        let _permit = permit; // hold permit until connection closes
355                        record_connection_accepted();
356                        let start = Instant::now();
357
358                        let result = async {
359                            // Measure TLS handshake duration with timeout
360                            let tls_start = Instant::now();
361                            let tls_timeout =
362                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
363                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
364                                Ok(Ok(tls)) => {
365                                    let tls_duration = tls_start.elapsed().as_secs_f64();
366                                    record_tls_handshake_duration(tls_duration);
367                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
368                                    handle_conn(tls, state, auth, peer).await
369                                }
370                                Ok(Err(err)) => {
371                                    record_error(ERROR_TLS_HANDSHAKE);
372                                    warn!(error = %err, "TLS handshake failed");
373                                    Ok(())
374                                }
375                                Err(_) => {
376                                    record_error(ERROR_TLS_HANDSHAKE);
377                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
378                                    Ok(())
379                                }
380                            }
381                        }
382                        .await;
383
384                        let duration_secs = start.elapsed().as_secs_f64();
385                        record_connection_closed(duration_secs);
386
387                        if let Err(ref err) = result {
388                            record_error(err.error_type());
389                            warn!(duration_secs, error = %err, "connection closed with error");
390                        } else {
391                            debug!(duration_secs, "connection closed");
392                        }
393                    }
394                    .instrument(span),
395                );
396            }
397        }
398    }
399
400    // Shutdown rate limiter cleanup task
401    if let Some(ref limiter) = rate_limiter {
402        limiter.shutdown();
403    }
404
405    // Graceful drain: wait for active connections
406    let active = tracker.count();
407    if active > 0 {
408        info!("waiting for {} active connections to drain", active);
409        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
410            info!("all connections drained");
411        } else {
412            warn!(
413                "shutdown timeout, {} connections still active",
414                tracker.count()
415            );
416        }
417    }
418
419    info!("server stopped");
420    Ok(())
421}
422
423/// Run the server (blocking until error, no graceful shutdown).
424/// For backward compatibility with existing code.
425pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
426    run_with_shutdown(config, auth, CancellationToken::new()).await
427}