1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_dns::DnsResolver;
26use trojan_metrics::{
27 ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
28 record_connection_rejected, record_error, record_tls_handshake_duration,
29 set_connection_queue_depth,
30};
31
/// How long graceful shutdown waits for in-flight connections to drain
/// before giving up (see the drain phase at the end of `run_with_shutdown`).
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

// Process-wide monotonically increasing connection id; starts at 1 so id 0
// never appears in logs/spans.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Returns the next unique connection id for logging/tracing spans.
#[inline]
fn next_conn_id() -> u64 {
    // Relaxed ordering is sufficient: only uniqueness of the returned value
    // matters, not ordering relative to other memory operations.
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
43
/// Runs the server until the `shutdown` token is cancelled.
///
/// Startup order:
/// 1. TLS config and acceptor.
/// 2. Listen-address parse, DNS resolver, one-time fallback address resolution.
/// 3. Optional subsystems, each gated on config and/or a cargo feature:
///    fallback connection pool, analytics, rule engine (+ background HTTP rule
///    updates), named outbounds, GeoIP databases, metrics HTTP server.
/// 4. Shared `ServerState`, global connection cap, per-IP rate limiter,
///    listener(s) — including an optional dedicated websocket listener in
///    "split" mode.
/// 5. Accept loop: per connection apply TCP options, rate-limit check,
///    semaphore permit, then spawn a task that performs the TLS handshake
///    (with timeout) and hands off to `handle_conn`.
///
/// On cancellation the accept loop exits and active connections are given
/// [`DEFAULT_SHUTDOWN_TIMEOUT`] to drain before the function returns.
///
/// # Errors
/// Returns `ServerError` for invalid configuration, TLS/listener setup
/// failure, fatal rule/outbound configuration errors, or a fatal `accept()`
/// error on the main listener.
pub async fn run_with_shutdown(
    config: Config,
    auth: impl AuthBackend + 'static,
    shutdown: CancellationToken,
) -> Result<(), ServerError> {
    let tls_config = load_tls_config(&config.tls)?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listen: SocketAddr = config
        .server
        .listen
        .parse()
        .map_err(|_| ServerError::Config("invalid listen address".into()))?;

    // Back-compat shim: the deprecated `server.tcp.prefer_ipv4` flag is folded
    // into the DNS config so older configs keep their behavior.
    let mut dns_config = config.dns.clone();
    if config.server.tcp.prefer_ipv4 && !dns_config.prefer_ipv4 {
        dns_config.prefer_ipv4 = true;
        info!(
            "server.tcp.prefer_ipv4 is deprecated; mapped to dns.prefer_ipv4 for backward compatibility"
        );
    }
    let dns_resolver = DnsResolver::new(&dns_config)
        .map_err(|e| ServerError::Config(format!("dns resolver: {e}")))?;
    info!(
        dns = ?dns_config.strategy,
        prefer_ipv4 = dns_config.prefer_ipv4,
        "dns resolver initialized"
    );

    // The fallback target is resolved once at startup, not per connection.
    let fallback_addr = resolve_sockaddr(&config.server.fallback, &dns_resolver).await?;

    // Optional pre-warmed pool of connections to the fallback target.
    let fallback_pool: Option<Arc<ConnectionPool>> =
        config.server.fallback_pool.as_ref().map(|pool_cfg| {
            info!(
                max_idle = pool_cfg.max_idle,
                max_age_secs = pool_cfg.max_age_secs,
                fill_batch = pool_cfg.fill_batch,
                fill_delay_ms = pool_cfg.fill_delay_ms,
                "fallback connection pool enabled"
            );
            let pool = Arc::new(ConnectionPool::new(
                fallback_addr,
                pool_cfg.max_idle,
                pool_cfg.max_age_secs,
                pool_cfg.fill_batch,
                pool_cfg.fill_delay_ms,
            ));
            // Background eviction of pooled connections older than max_age.
            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
            pool
        });

    // Buffer sizes / backlog: explicit resource_limits win, otherwise defaults.
    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
        match &config.server.resource_limits {
            Some(rl) => {
                info!(
                    relay_buffer = rl.relay_buffer_size,
                    tcp_send_buffer = rl.tcp_send_buffer,
                    tcp_recv_buffer = rl.tcp_recv_buffer,
                    connection_backlog = rl.connection_backlog,
                    "resource limits configured"
                );
                (
                    rl.relay_buffer_size,
                    rl.tcp_send_buffer,
                    rl.tcp_recv_buffer,
                    rl.connection_backlog,
                )
            }
            None => (
                defaults::DEFAULT_RELAY_BUFFER_SIZE,
                defaults::DEFAULT_TCP_SEND_BUFFER,
                defaults::DEFAULT_TCP_RECV_BUFFER,
                defaults::DEFAULT_CONNECTION_BACKLOG,
            ),
        };

    // Analytics init failure is non-fatal: the server runs with it disabled.
    #[cfg(feature = "analytics")]
    let analytics = if config.analytics.enabled {
        match trojan_analytics::init(config.analytics.clone()).await {
            Ok(collector) => {
                info!("analytics enabled, sending to ClickHouse");
                Some(collector)
            }
            Err(e) => {
                warn!("failed to init analytics: {}, disabled", e);
                None
            }
        }
    } else {
        debug!("analytics disabled in config");
        None
    };

    // Rule-engine init failure IS fatal — running with misconfigured routing
    // rules would silently change traffic handling.
    #[cfg(feature = "rules")]
    let rule_engine = if !config.server.rules.is_empty() {
        match crate::rules::build_rule_engine(&config.server) {
            Ok(engine) => {
                info!(
                    rule_sets = engine.rule_set_count(),
                    rules = engine.rule_count(),
                    "rule engine initialized"
                );
                Some(Arc::new(trojan_rules::HotRuleEngine::new(engine)))
            }
            Err(e) => {
                return Err(ServerError::Rules(format!("failed to init rules: {e}")));
            }
        }
    } else {
        debug!("no routing rules configured");
        None
    };

    // If any rule set comes from an HTTP provider, refresh it in the
    // background on a fixed interval (default 3600 s).
    #[cfg(feature = "rules")]
    if let Some(ref hot_engine) = rule_engine
        && crate::rules::has_http_providers(&config.server)
    {
        let interval_secs = crate::rules::http_update_interval(&config.server).unwrap_or(3600);
        let engine_ref = hot_engine.clone();
        let server_cfg = config.server.clone();
        let update_shutdown = shutdown.clone();
        info!(interval_secs, "starting background rule update task");
        tokio::spawn(async move {
            rule_update_loop(engine_ref, server_cfg, interval_secs, update_shutdown).await;
        });
    }

    // Named outbound connectors referenced by routing rules; any invalid
    // outbound aborts startup.
    #[cfg(feature = "rules")]
    let outbounds = {
        let mut map = std::collections::HashMap::new();
        for (name, outbound_cfg) in &config.server.outbounds {
            match crate::outbound::Outbound::from_config(name, outbound_cfg) {
                Ok(outbound) => {
                    info!(name = %name, "outbound connector configured");
                    map.insert(name.clone(), Arc::new(outbound));
                }
                Err(e) => {
                    return Err(ServerError::Config(format!("outbound '{name}': {e}")));
                }
            }
        }
        map
    };

    // Depending on which features are enabled, some of these bindings may be
    // unused — hence the allow.
    #[cfg(feature = "geoip")]
    #[allow(unused_variables)]
    let (geoip_server, geoip_metrics, geoip_analytics) =
        load_geoip_databases(&config, &shutdown).await;

    // Optional metrics/debug HTTP server; failure to start it is non-fatal.
    if let Some(ref listen) = config.metrics.listen {
        #[cfg(feature = "rules")]
        let extra_routes = rule_engine
            .as_ref()
            .map(|engine| crate::debug_api::debug_routes(engine.clone()));
        #[cfg(not(feature = "rules"))]
        let extra_routes: Option<axum::Router> = None;

        match trojan_metrics::init_metrics_server(listen, extra_routes) {
            Ok(_handle) => {
                #[cfg(feature = "rules")]
                let endpoints = if rule_engine.is_some() {
                    "/metrics, /health, /ready, /debug/rules/match"
                } else {
                    "/metrics, /health, /ready"
                };
                #[cfg(not(feature = "rules"))]
                let endpoints = "/metrics, /health, /ready";
                info!("metrics server listening on {} ({})", listen, endpoints);
            }
            Err(e) => warn!("failed to start metrics server: {}", e),
        }
    }

    let tcp_cfg = &config.server.tcp;
    info!(
        no_delay = tcp_cfg.no_delay,
        keepalive_secs = tcp_cfg.keepalive_secs,
        reuse_port = tcp_cfg.reuse_port,
        fast_open = tcp_cfg.fast_open,
        "TCP options configured"
    );

    // Immutable shared state handed (via Arc) to every connection task.
    let state = Arc::new(ServerState {
        fallback_addr,
        max_udp_payload: config.server.max_udp_payload,
        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
        max_header_bytes: config.server.max_header_bytes,
        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
        fallback_pool,
        relay_buffer_size,
        tcp_send_buffer,
        tcp_recv_buffer,
        tcp_config: config.server.tcp.clone(),
        websocket: config.websocket.clone(),
        dns_resolver,
        #[cfg(feature = "analytics")]
        analytics,
        #[cfg(feature = "rules")]
        rule_engine,
        #[cfg(feature = "rules")]
        outbounds,
        #[cfg(feature = "geoip")]
        geoip_metrics,
        #[cfg(all(feature = "geoip", feature = "analytics"))]
        geoip_analytics,
    });
    let auth = Arc::new(auth);
    let tracker = ConnectionTracker::new();

    // Global concurrent-connection cap: available permits == free slots.
    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
        info!("max_connections set to {}", n);
        Arc::new(Semaphore::new(n))
    });

    // Per-IP windowed rate limiting with its own background cleanup task.
    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
        info!(
            max_per_ip = rl.max_connections_per_ip,
            window_secs = rl.window_secs,
            "rate limiting enabled"
        );
        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
        limiter
    });

    let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
    info!(address = %listen, backlog = connection_backlog, "listening");

    // Dedicated websocket listener in "split" mode; mirrors the main accept
    // loop below but hands connections to `handle_ws_only`.
    #[cfg(feature = "ws")]
    if config.websocket.enabled && config.websocket.mode == "split" {
        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
        let ws_addr: SocketAddr = ws_listen
            .parse()
            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
        let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
        let ws_acceptor = acceptor.clone();
        let ws_state = state.clone();
        let ws_auth = auth.clone();
        let ws_tracker = tracker.clone();
        let ws_conn_limit = conn_limit.clone();
        let ws_rate_limiter = rate_limiter.clone();
        let ws_shutdown = shutdown.clone();

        info!(address = %ws_addr, "websocket split listener started");
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    // `biased` so a pending shutdown always wins over accept.
                    biased;
                    _ = ws_shutdown.cancelled() => break,
                    result = ws_listener.accept() => {
                        let (tcp, peer) = match result {
                            Ok(v) => v,
                            // Accept errors on the ws listener are ignored
                            // (unlike the main listener, which treats them as
                            // fatal — see the main loop below).
                            Err(_) => continue,
                        };

                        if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
                            tracing::debug!(error = %e, "failed to apply TCP options");
                        }

                        // Per-IP rate limit: reject before any TLS work is spent.
                        if let Some(ref limiter) = ws_rate_limiter {
                            let ip = peer.ip();
                            if !limiter.check_and_increment(ip) {
                                record_connection_rejected("rate_limit");
                                drop(tcp);
                                continue;
                            }
                        }

                        // Global cap: take a permit now; it is released when the
                        // spawned task finishes and drops `_permit`.
                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
                            Some(sem) => match sem.clone().try_acquire_owned() {
                                Ok(p) => Some(p),
                                Err(_) => {
                                    record_connection_rejected("max_connections");
                                    drop(tcp);
                                    continue;
                                }
                            },
                            None => None,
                        };

                        let conn_id = next_conn_id();
                        let acceptor = ws_acceptor.clone();
                        let state = ws_state.clone();
                        let auth = ws_auth.clone();
                        // Count is incremented here; the guard decrements it
                        // when the task ends (even on panic/cancel).
                        ws_tracker.increment();
                        let guard = ConnectionGuard::new(ws_tracker.clone());

                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
                        tokio::spawn(
                            async move {
                                let _guard = guard;
                                let _permit = permit;
                                record_connection_accepted();
                                let start = Instant::now();

                                let result = async {
                                    let tls_start = Instant::now();
                                    let tls_timeout =
                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                                    // Handshake failure/timeout is counted and
                                    // logged but treated as a clean close (Ok).
                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
                                    {
                                        Ok(Ok(tls)) => {
                                            let tls_duration = tls_start.elapsed().as_secs_f64();
                                            record_tls_handshake_duration(tls_duration);
                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
                                        }
                                        Ok(Err(err)) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(error = %err, "TLS handshake failed");
                                            Ok(())
                                        }
                                        Err(_) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(
                                                timeout_secs = tls_timeout.as_secs(),
                                                "TLS handshake timed out"
                                            );
                                            Ok(())
                                        }
                                    }
                                }
                                .await;

                                let duration_secs = start.elapsed().as_secs_f64();
                                record_connection_closed(duration_secs);

                                if let Err(ref err) = result {
                                    warn!(error = %err, "connection error");
                                }
                            }
                            .instrument(span),
                        );
                    }
                }
            }
        });
    }

    #[cfg(not(feature = "ws"))]
    if config.websocket.enabled {
        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
    }

    // Main accept loop (trojan-over-TLS).
    loop {
        tokio::select! {
            // `biased` so a pending shutdown always wins over accept.
            biased;

            _ = shutdown.cancelled() => {
                info!("shutdown signal received, stopping accept loop");
                break;
            }

            result = listener.accept() => {
                // NOTE(review): `?` makes any accept() error fatal for the
                // whole server — confirm this is intended for transient
                // conditions such as EMFILE (the ws listener just continues).
                let (tcp, peer) = result?;

                if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
                    debug!(error = %e, "failed to apply TCP options");
                }

                // Export remaining permits as the queue-depth gauge.
                if let Some(ref sem) = conn_limit {
                    let available = sem.available_permits();
                    set_connection_queue_depth(available as f64);
                }

                // Per-IP rate limit: reject before any TLS work is spent.
                if let Some(ref limiter) = rate_limiter {
                    let ip = peer.ip();
                    if !limiter.check_and_increment(ip) {
                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
                        record_connection_rejected("rate_limit");
                        drop(tcp);
                        continue;
                    }
                }

                // Global cap: take a permit now; released when the spawned
                // task finishes and drops `_permit`.
                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
                    Some(sem) => match sem.clone().try_acquire_owned() {
                        Ok(p) => Some(p),
                        Err(_) => {
                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
                            record_connection_rejected("max_connections");
                            drop(tcp);
                            continue;
                        }
                    },
                    None => None,
                };

                let conn_id = next_conn_id();
                debug!(conn_id, peer = %peer, "new connection");

                let acceptor = acceptor.clone();
                let state = state.clone();
                let auth = auth.clone();
                // Count is incremented here; the guard decrements it when the
                // task ends (even on panic/cancel).
                tracker.increment();
                let guard = ConnectionGuard::new(tracker.clone());

                let span = info_span!("conn", id = conn_id, peer = %peer);
                tokio::spawn(
                    async move {
                        let _guard = guard;
                        let _permit = permit;
                        record_connection_accepted();
                        let start = Instant::now();

                        let result = async {
                            let tls_start = Instant::now();
                            let tls_timeout =
                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                            // Handshake failure/timeout is counted and logged
                            // but treated as a clean close (Ok) for the outer
                            // result below.
                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
                                Ok(Ok(tls)) => {
                                    let tls_duration = tls_start.elapsed().as_secs_f64();
                                    record_tls_handshake_duration(tls_duration);
                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
                                    handle_conn(tls, state, auth, peer).await
                                }
                                Ok(Err(err)) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(error = %err, "TLS handshake failed");
                                    Ok(())
                                }
                                Err(_) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
                                    Ok(())
                                }
                            }
                        }
                        .await;

                        let duration_secs = start.elapsed().as_secs_f64();
                        record_connection_closed(duration_secs);

                        if let Err(ref err) = result {
                            record_error(err.error_type());
                            warn!(duration_secs, error = %err, "connection closed with error");
                        } else {
                            debug!(duration_secs, "connection closed");
                        }
                    }
                    .instrument(span),
                );
            }
        }
    }

    // Stop the rate limiter's background cleanup before draining.
    if let Some(ref limiter) = rate_limiter {
        limiter.shutdown();
    }

    // Graceful drain: give in-flight connections a bounded window to finish.
    let active = tracker.count();
    if active > 0 {
        info!("waiting for {} active connections to drain", active);
        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
            info!("all connections drained");
        } else {
            warn!(
                "shutdown timeout, {} connections still active",
                tracker.count()
            );
        }
    }

    info!("server stopped");
    Ok(())
}
535
536pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
539 run_with_shutdown(config, auth, CancellationToken::new()).await
540}
541
/// Loads the GeoIP databases referenced by the server, metrics, and analytics
/// sections of the config, deduplicating by `(path, url, source)` so identical
/// configs share a single `Arc<GeoipDb>`.
///
/// Returns `(server_db, metrics_db, analytics_db)`. Any individual load
/// failure is logged and yields `None` instead of failing startup. When the
/// metrics section has no geoip config of its own it reuses the server
/// database. For URL-sourced configs (no local `path`) with `auto_update`
/// enabled, one background refresh task is spawned per distinct database.
#[cfg(feature = "geoip")]
#[allow(unused_variables)]
async fn load_geoip_databases(
    config: &Config,
    shutdown: &CancellationToken,
) -> (
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
) {
    use std::collections::HashMap;
    use trojan_rules::geoip_db::GeoipDb;

    // Dedup key: configs with identical (path, url, source) share one db.
    type Key = (Option<String>, Option<String>, String);
    let mut loaded: HashMap<Key, Arc<GeoipDb>> = HashMap::new();

    // (config, db) pairs that requested auto-update; deduped by Arc identity
    // before spawning tasks below.
    let mut auto_update_configs: Vec<(trojan_config::GeoipConfig, Arc<GeoipDb>)> = Vec::new();

    // Loads a database, or returns the already-loaded instance for an
    // identical config. Load errors are logged and mapped to None.
    async fn load_or_share(
        cfg: &trojan_config::GeoipConfig,
        loaded: &mut HashMap<Key, Arc<GeoipDb>>,
    ) -> Option<Arc<GeoipDb>> {
        let key: Key = (cfg.path.clone(), cfg.url.clone(), cfg.source.clone());
        if let Some(existing) = loaded.get(&key) {
            return Some(existing.clone());
        }
        match trojan_rules::geoip_db::load_geoip(cfg).await {
            Ok(db) => {
                let arc = Arc::new(db);
                loaded.insert(key, arc.clone());
                Some(arc)
            }
            Err(e) => {
                warn!(source = %cfg.source, error = %e, "failed to load GeoIP database");
                None
            }
        }
    }

    let server_geoip = if let Some(cfg) = config.server.geoip.as_ref() {
        load_or_share(cfg, &mut loaded).await
    } else {
        None
    };

    // Metrics: own config if present, otherwise share the server database.
    let metrics_geoip = if let Some(cfg) = config.metrics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        // Auto-update only applies to URL-sourced (path-less) databases.
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        server_geoip.clone()
    };

    #[cfg(feature = "analytics")]
    let analytics_geoip = if let Some(cfg) = config.analytics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        // Same auto-update condition as the metrics database above.
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        None
    };
    #[cfg(not(feature = "analytics"))]
    let analytics_geoip: Option<Arc<GeoipDb>> = None;

    if !loaded.is_empty() {
        info!(
            databases = loaded.len(),
            "GeoIP databases loaded (deduplicated)"
        );
    }

    {
        // Spawn at most one auto-update task per distinct database even if
        // several config sections reference it (dedup by Arc pointer identity).
        let mut seen_ptrs = std::collections::HashSet::new();
        for (cfg, db) in auto_update_configs {
            let ptr = Arc::as_ptr(&db) as usize;
            if !seen_ptrs.insert(ptr) {
                continue;
            }
            let cancel = shutdown.clone();
            let source = cfg.source.clone();
            info!(source = %source, "spawning GeoIP auto-update task");
            // NOTE(review): the ArcSwap created here is not handed back to the
            // consumers holding the original Arc<GeoipDb> returned by this
            // function — confirm refreshed data actually becomes visible to
            // them, or whether GeoipDb shares state internally.
            let swappable = Arc::new(arc_swap::ArcSwap::from(db));
            tokio::spawn(trojan_rules::geoip_db::geoip_auto_update_task(
                cfg,
                swappable,
                cancel,
                move |success| {
                    if success {
                        trojan_metrics::record_rule_update();
                    } else {
                        trojan_metrics::record_rule_update_error();
                    }
                },
            ));
        }
    }

    (server_geoip, metrics_geoip, analytics_geoip)
}
666
/// Background task that periodically rebuilds the rule engine from config
/// (including HTTP rule providers) and hot-swaps it into `engine`.
///
/// Performs one immediate rebuild at startup so HTTP-sourced rules are fetched
/// right away, then repeats every `interval_secs`. A failed rebuild keeps the
/// currently active rules and bumps the error metric. Exits when `shutdown`
/// is cancelled.
#[cfg(feature = "rules")]
async fn rule_update_loop(
    engine: Arc<trojan_rules::HotRuleEngine>,
    server_config: trojan_config::ServerConfig,
    interval_secs: u64,
    shutdown: CancellationToken,
) {
    use std::time::Duration;
    use trojan_metrics::{record_rule_update, record_rule_update_error};

    // Initial fetch: replaces the startup engine on success.
    match crate::rules::build_rule_engine_async(&server_config).await {
        Ok(new_engine) => {
            info!(
                rule_sets = new_engine.rule_set_count(),
                rules = new_engine.rule_count(),
                "initial rule fetch completed, engine updated"
            );
            engine.update(new_engine);
            record_rule_update();
        }
        Err(e) => {
            warn!(error = %e, "initial rule fetch failed, keeping startup rules");
            record_rule_update_error();
        }
    }

    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
    // A tokio interval's first tick completes immediately; consume it so the
    // next rebuild happens a full interval after the initial fetch above.
    interval.tick().await;
    loop {
        tokio::select! {
            // `biased` so a pending shutdown wins over the timer.
            biased;
            _ = shutdown.cancelled() => {
                debug!("rule update task shutting down");
                return;
            }
            _ = interval.tick() => {
                debug!("starting scheduled rule update");
                match crate::rules::build_rule_engine_async(&server_config).await {
                    Ok(new_engine) => {
                        info!(
                            rule_sets = new_engine.rule_set_count(),
                            rules = new_engine.rule_count(),
                            "rule update completed, engine swapped"
                        );
                        engine.update(new_engine);
                        record_rule_update();
                    }
                    Err(e) => {
                        warn!(error = %e, "rule update failed, keeping current rules");
                        record_rule_update_error();
                    }
                }
            }
        }
    }
}