// trojan_server/server.rs

//! Main server loop and connection handling.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_dns::DnsResolver;
26use trojan_metrics::{
27    ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
28    record_connection_rejected, record_error, record_tls_handshake_duration,
29    set_connection_queue_depth,
30};
31
/// Default graceful shutdown timeout.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

/// Monotonic counter backing [`next_conn_id`]. Starts at 1 so that a
/// connection ID of 0 never appears in logs.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Generate a unique connection ID.
#[inline]
fn next_conn_id() -> u64 {
    // Relaxed ordering suffices: IDs only need to be unique, they do not
    // synchronize any other memory accesses.
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
43
/// Run the server with a cancellation token for graceful shutdown.
///
/// Sets up TLS, DNS, optional subsystems (fallback pool, analytics, rules,
/// outbounds, GeoIP, metrics, DDNS, websocket split listener), then accepts
/// connections until `shutdown` is cancelled, finally draining active
/// connections for up to [`DEFAULT_SHUTDOWN_TIMEOUT`].
///
/// # Errors
///
/// Returns `ServerError` for invalid configuration, TLS setup failure,
/// rule/outbound initialization failure, or a fatal `accept()` error on the
/// main listener. TLS handshake failures on individual connections are
/// logged and counted but never abort the server.
pub async fn run_with_shutdown(
    config: Config,
    auth: impl AuthBackend + 'static,
    shutdown: CancellationToken,
) -> Result<(), ServerError> {
    let tls_config = load_tls_config(&config.tls)?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listen: SocketAddr = config
        .server
        .listen
        .parse()
        .map_err(|_| ServerError::Config("invalid listen address".into()))?;

    // Build DNS resolver from config.
    // Backward compatibility: preserve legacy `server.tcp.prefer_ipv4` behavior.
    let mut dns_config = config.dns.clone();
    if config.server.tcp.prefer_ipv4 && !dns_config.prefer_ipv4 {
        dns_config.prefer_ipv4 = true;
        info!(
            "server.tcp.prefer_ipv4 is deprecated; mapped to dns.prefer_ipv4 for backward compatibility"
        );
    }
    let dns_resolver = DnsResolver::new(&dns_config)
        .map_err(|e| ServerError::Config(format!("dns resolver: {e}")))?;
    info!(
        dns = ?dns_config.strategy,
        prefer_ipv4 = dns_config.prefer_ipv4,
        "dns resolver initialized"
    );

    // Fallback target is resolved once at startup, not per connection.
    let fallback_addr = resolve_sockaddr(&config.server.fallback, &dns_resolver).await?;

    // Initialize fallback connection pool if configured
    let fallback_pool: Option<Arc<ConnectionPool>> =
        config.server.fallback_pool.as_ref().map(|pool_cfg| {
            info!(
                max_idle = pool_cfg.max_idle,
                max_age_secs = pool_cfg.max_age_secs,
                fill_batch = pool_cfg.fill_batch,
                fill_delay_ms = pool_cfg.fill_delay_ms,
                "fallback connection pool enabled"
            );
            let pool = Arc::new(ConnectionPool::new(
                fallback_addr,
                pool_cfg.max_idle,
                pool_cfg.max_age_secs,
                pool_cfg.fill_batch,
                pool_cfg.fill_delay_ms,
            ));
            // Use max_age_secs as cleanup interval
            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
            pool
        });

    // Extract resource limits with defaults
    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
        match &config.server.resource_limits {
            Some(rl) => {
                info!(
                    relay_buffer = rl.relay_buffer_size,
                    tcp_send_buffer = rl.tcp_send_buffer,
                    tcp_recv_buffer = rl.tcp_recv_buffer,
                    connection_backlog = rl.connection_backlog,
                    "resource limits configured"
                );
                (
                    rl.relay_buffer_size,
                    rl.tcp_send_buffer,
                    rl.tcp_recv_buffer,
                    rl.connection_backlog,
                )
            }
            None => (
                defaults::DEFAULT_RELAY_BUFFER_SIZE,
                defaults::DEFAULT_TCP_SEND_BUFFER,
                defaults::DEFAULT_TCP_RECV_BUFFER,
                defaults::DEFAULT_CONNECTION_BACKLOG,
            ),
        };

    // Initialize analytics if feature enabled and configured.
    // Analytics init failure is non-fatal: the server runs without it.
    #[cfg(feature = "analytics")]
    let analytics = if config.analytics.enabled {
        match trojan_analytics::init(config.analytics.clone()).await {
            Ok(collector) => {
                info!("analytics enabled, sending to ClickHouse");
                Some(collector)
            }
            Err(e) => {
                warn!("failed to init analytics: {}, disabled", e);
                None
            }
        }
    } else {
        debug!("analytics disabled in config");
        None
    };

    // Initialize rule engine if feature enabled and rules configured.
    // Unlike analytics, a broken rule config is fatal (fail closed).
    #[cfg(feature = "rules")]
    let rule_engine = if !config.server.rules.is_empty() {
        match crate::rules::build_rule_engine(&config.server) {
            Ok(engine) => {
                info!(
                    rule_sets = engine.rule_set_count(),
                    rules = engine.rule_count(),
                    "rule engine initialized"
                );
                Some(Arc::new(trojan_rules::HotRuleEngine::new(engine)))
            }
            Err(e) => {
                return Err(ServerError::Rules(format!("failed to init rules: {e}")));
            }
        }
    } else {
        debug!("no routing rules configured");
        None
    };

    // Spawn background rule update task for HTTP providers
    #[cfg(feature = "rules")]
    if let Some(ref hot_engine) = rule_engine
        && crate::rules::has_http_providers(&config.server)
    {
        let interval_secs = crate::rules::http_update_interval(&config.server).unwrap_or(3600); // default: 1 hour
        let engine_ref = hot_engine.clone();
        let server_cfg = config.server.clone();
        let update_shutdown = shutdown.clone();
        info!(interval_secs, "starting background rule update task");
        tokio::spawn(async move {
            rule_update_loop(engine_ref, server_cfg, interval_secs, update_shutdown).await;
        });
    }

    // Build outbound connectors from config.
    // Any invalid outbound aborts startup (fail closed).
    #[cfg(feature = "rules")]
    let outbounds = {
        let mut map = std::collections::HashMap::new();
        for (name, outbound_cfg) in &config.server.outbounds {
            match crate::outbound::Outbound::from_config(name, outbound_cfg) {
                Ok(outbound) => {
                    info!(name = %name, "outbound connector configured");
                    map.insert(name.clone(), Arc::new(outbound));
                }
                Err(e) => {
                    return Err(ServerError::Config(format!("outbound '{name}': {e}")));
                }
            }
        }
        map
    };

    // Load GeoIP databases with deduplication.
    // geoip_server is used indirectly (metrics fallback shares it).
    #[cfg(feature = "geoip")]
    #[allow(unused_variables)]
    let (geoip_server, geoip_metrics, geoip_analytics) =
        load_geoip_databases(&config, &shutdown).await;

    // Start metrics server (with debug routes if rules feature is enabled).
    // Failure to start metrics is logged but does not abort the server.
    if let Some(ref listen) = config.metrics.listen {
        #[cfg(feature = "rules")]
        let extra_routes = rule_engine
            .as_ref()
            .map(|engine| crate::debug_api::debug_routes(engine.clone()));
        #[cfg(not(feature = "rules"))]
        let extra_routes: Option<axum::Router> = None;

        match trojan_metrics::init_metrics_server(listen, extra_routes) {
            Ok(_handle) => {
                #[cfg(feature = "rules")]
                let endpoints = if rule_engine.is_some() {
                    "/metrics, /health, /ready, /debug/rules/match"
                } else {
                    "/metrics, /health, /ready"
                };
                #[cfg(not(feature = "rules"))]
                let endpoints = "/metrics, /health, /ready";
                info!("metrics server listening on {} ({})", listen, endpoints);
            }
            Err(e) => warn!("failed to start metrics server: {}", e),
        }
    }

    // Spawn DDNS update task if enabled
    #[cfg(feature = "ddns")]
    if config.ddns.enabled {
        let ddns_config = config.ddns.clone();
        let ddns_shutdown = shutdown.clone();
        info!("starting DDNS update task");
        tokio::spawn(async move {
            trojan_ddns::ddns_loop(ddns_config, ddns_shutdown).await;
        });
    }

    // Log TCP options
    let tcp_cfg = &config.server.tcp;
    info!(
        no_delay = tcp_cfg.no_delay,
        keepalive_secs = tcp_cfg.keepalive_secs,
        reuse_port = tcp_cfg.reuse_port,
        fast_open = tcp_cfg.fast_open,
        "TCP options configured"
    );

    // Immutable per-connection state, shared by every spawned handler task.
    let state = Arc::new(ServerState {
        fallback_addr,
        max_udp_payload: config.server.max_udp_payload,
        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
        max_header_bytes: config.server.max_header_bytes,
        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
        fallback_pool,
        relay_buffer_size,
        tcp_send_buffer,
        tcp_recv_buffer,
        tcp_config: config.server.tcp.clone(),
        websocket: config.websocket.clone(),
        dns_resolver,
        #[cfg(feature = "analytics")]
        analytics,
        #[cfg(feature = "rules")]
        rule_engine,
        #[cfg(feature = "rules")]
        outbounds,
        #[cfg(feature = "geoip")]
        geoip_metrics,
        #[cfg(all(feature = "geoip", feature = "analytics"))]
        geoip_analytics,
    });
    let auth = Arc::new(auth);
    let tracker = ConnectionTracker::new();

    // Connection limiter (None = unlimited)
    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
        info!("max_connections set to {}", n);
        Arc::new(Semaphore::new(n))
    });

    // Rate limiter (None = disabled)
    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
        info!(
            max_per_ip = rl.max_connections_per_ip,
            window_secs = rl.window_secs,
            "rate limiting enabled"
        );
        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
        limiter
    });

    // Create listener with custom backlog and TCP options using socket2
    let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
    info!(address = %listen, backlog = connection_backlog, "listening");

    // Optional second listener when websocket runs in "split" mode (its own
    // port). The accept logic mirrors the main loop below, but handshake/relay
    // goes through handle_ws_only.
    #[cfg(feature = "ws")]
    if config.websocket.enabled && config.websocket.mode == "split" {
        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
        let ws_addr: SocketAddr = ws_listen
            .parse()
            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
        let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
        let ws_acceptor = acceptor.clone();
        let ws_state = state.clone();
        let ws_auth = auth.clone();
        let ws_tracker = tracker.clone();
        let ws_conn_limit = conn_limit.clone();
        let ws_rate_limiter = rate_limiter.clone();
        let ws_shutdown = shutdown.clone();

        info!(address = %ws_addr, "websocket split listener started");
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    // biased: always observe shutdown before accepting new work.
                    biased;
                    _ = ws_shutdown.cancelled() => break,
                    result = ws_listener.accept() => {
                        let (tcp, peer) = match result {
                            Ok(v) => v,
                            // NOTE(review): accept errors on the ws listener are
                            // silently retried (no metric/log) — confirm intended.
                            Err(_) => continue,
                        };

                        // Apply TCP socket options
                        if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
                            tracing::debug!(error = %e, "failed to apply TCP options");
                        }

                        if let Some(ref limiter) = ws_rate_limiter {
                            let ip = peer.ip();
                            if !limiter.check_and_increment(ip) {
                                record_connection_rejected("rate_limit");
                                drop(tcp);
                                continue;
                            }
                        }

                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
                            Some(sem) => match sem.clone().try_acquire_owned() {
                                Ok(p) => Some(p),
                                Err(_) => {
                                    record_connection_rejected("max_connections");
                                    drop(tcp);
                                    continue;
                                }
                            },
                            None => None,
                        };

                        let conn_id = next_conn_id();
                        let acceptor = ws_acceptor.clone();
                        let state = ws_state.clone();
                        let auth = ws_auth.clone();
                        ws_tracker.increment();
                        let guard = ConnectionGuard::new(ws_tracker.clone());

                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
                        tokio::spawn(
                            async move {
                                let _guard = guard;
                                // Permit is held until this task ends, bounding
                                // concurrent connections across both listeners.
                                let _permit = permit;
                                record_connection_accepted();
                                let start = Instant::now();

                                let result = async {
                                    let tls_start = Instant::now();
                                    let tls_timeout =
                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
                                    {
                                        Ok(Ok(tls)) => {
                                            let tls_duration = tls_start.elapsed().as_secs_f64();
                                            record_tls_handshake_duration(tls_duration);
                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
                                        }
                                        // Handshake failure/timeout is recorded and
                                        // logged, but reported as Ok so it is not
                                        // double-counted as a connection error.
                                        Ok(Err(err)) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(error = %err, "TLS handshake failed");
                                            Ok(())
                                        }
                                        Err(_) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(
                                                timeout_secs = tls_timeout.as_secs(),
                                                "TLS handshake timed out"
                                            );
                                            Ok(())
                                        }
                                    }
                                }
                                .await;

                                let duration_secs = start.elapsed().as_secs_f64();
                                record_connection_closed(duration_secs);

                                if let Err(ref err) = result {
                                    warn!(error = %err, "connection error");
                                }
                            }
                            .instrument(span),
                        );
                    }
                }
            }
        });
    }

    #[cfg(not(feature = "ws"))]
    if config.websocket.enabled {
        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
    }

    // Main accept loop: runs until shutdown is signalled or accept() fails.
    loop {
        tokio::select! {
            // biased: check shutdown before accepting another connection.
            biased;

            _ = shutdown.cancelled() => {
                info!("shutdown signal received, stopping accept loop");
                break;
            }

            result = listener.accept() => {
                // A fatal accept error aborts the whole server (propagated via `?`).
                let (tcp, peer) = result?;

                // Apply TCP socket options (no_delay, keepalive)
                if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
                    debug!(error = %e, "failed to apply TCP options");
                }

                // Update connection queue depth metric (based on semaphore usage)
                if let Some(ref sem) = conn_limit {
                    let available = sem.available_permits();
                    set_connection_queue_depth(available as f64);
                }

                // Check rate limit first
                if let Some(ref limiter) = rate_limiter {
                    let ip = peer.ip();
                    if !limiter.check_and_increment(ip) {
                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
                        record_connection_rejected("rate_limit");
                        drop(tcp);
                        continue;
                    }
                }

                // Try to acquire connection permit
                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
                    Some(sem) => match sem.clone().try_acquire_owned() {
                        Ok(p) => Some(p),
                        Err(_) => {
                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
                            record_connection_rejected("max_connections");
                            drop(tcp); // close immediately
                            continue;
                        }
                    },
                    None => None,
                };

                let conn_id = next_conn_id();
                debug!(conn_id, peer = %peer, "new connection");

                let acceptor = acceptor.clone();
                let state = state.clone();
                let auth = auth.clone();
                tracker.increment();
                let guard = ConnectionGuard::new(tracker.clone());

                let span = info_span!("conn", id = conn_id, peer = %peer);
                tokio::spawn(
                    async move {
                        let _guard = guard; // ensure decrement on drop
                        let _permit = permit; // hold permit until connection closes
                        record_connection_accepted();
                        let start = Instant::now();

                        let result = async {
                            // Measure TLS handshake duration with timeout
                            let tls_start = Instant::now();
                            let tls_timeout =
                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
                                Ok(Ok(tls)) => {
                                    let tls_duration = tls_start.elapsed().as_secs_f64();
                                    record_tls_handshake_duration(tls_duration);
                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
                                    handle_conn(tls, state, auth, peer).await
                                }
                                // Handshake failure/timeout is recorded here and
                                // mapped to Ok(()) so the generic connection-error
                                // path below doesn't count it twice.
                                Ok(Err(err)) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(error = %err, "TLS handshake failed");
                                    Ok(())
                                }
                                Err(_) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
                                    Ok(())
                                }
                            }
                        }
                        .await;

                        let duration_secs = start.elapsed().as_secs_f64();
                        record_connection_closed(duration_secs);

                        if let Err(ref err) = result {
                            record_error(err.error_type());
                            warn!(duration_secs, error = %err, "connection closed with error");
                        } else {
                            debug!(duration_secs, "connection closed");
                        }
                    }
                    .instrument(span),
                );
            }
        }
    }

    // Shutdown rate limiter cleanup task
    if let Some(ref limiter) = rate_limiter {
        limiter.shutdown();
    }

    // Graceful drain: wait for active connections (bounded by the default
    // shutdown timeout; stragglers are abandoned, not force-killed here).
    let active = tracker.count();
    if active > 0 {
        info!("waiting for {} active connections to drain", active);
        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
            info!("all connections drained");
        } else {
            warn!(
                "shutdown timeout, {} connections still active",
                tracker.count()
            );
        }
    }

    info!("server stopped");
    Ok(())
}
546
547/// Run the server (blocking until error, no graceful shutdown).
548/// For backward compatibility with existing code.
549pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
550    run_with_shutdown(config, auth, CancellationToken::new()).await
551}
552
/// Load GeoIP databases from config with deduplication.
///
/// Returns `(server_geoip, metrics_geoip, analytics_geoip)`.
/// If multiple configs point to the same source, the same `Arc` is shared.
///
/// Databases can be downloaded from CDN or custom URLs. Auto-update tasks
/// are spawned for configs with `auto_update = true` and no local `path` set.
/// Load failures are non-fatal: a failed slot is returned as `None` after a
/// warning is logged.
#[cfg(feature = "geoip")]
#[allow(unused_variables)]
async fn load_geoip_databases(
    config: &Config,
    shutdown: &CancellationToken,
) -> (
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
) {
    use std::collections::HashMap;
    use trojan_rules::geoip_db::GeoipDb;

    // Deduplication key: (path, url, source) tuple identifies a unique database
    type Key = (Option<String>, Option<String>, String);
    let mut loaded: HashMap<Key, Arc<GeoipDb>> = HashMap::new();

    // Track configs that need auto-update tasks
    let mut auto_update_configs: Vec<(trojan_config::GeoipConfig, Arc<GeoipDb>)> = Vec::new();

    // Load a single GeoIP config, deduplicating by key.
    // On load failure, logs a warning and returns None (non-fatal).
    async fn load_or_share(
        cfg: &trojan_config::GeoipConfig,
        loaded: &mut HashMap<Key, Arc<GeoipDb>>,
    ) -> Option<Arc<GeoipDb>> {
        let key: Key = (cfg.path.clone(), cfg.url.clone(), cfg.source.clone());
        if let Some(existing) = loaded.get(&key) {
            return Some(existing.clone());
        }
        match trojan_rules::geoip_db::load_geoip(cfg).await {
            Ok(db) => {
                let arc = Arc::new(db);
                loaded.insert(key, arc.clone());
                Some(arc)
            }
            Err(e) => {
                warn!(source = %cfg.source, error = %e, "failed to load GeoIP database");
                None
            }
        }
    }

    // Server GeoIP (for rule matching — also shared by metrics/analytics)
    let server_geoip = if let Some(cfg) = config.server.geoip.as_ref() {
        load_or_share(cfg, &mut loaded).await
    } else {
        None
    };

    // Metrics GeoIP.
    // Auto-update is only scheduled for remotely-sourced databases
    // (`auto_update = true` and no local `path` override).
    let metrics_geoip = if let Some(cfg) = config.metrics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        server_geoip.clone() // fallback to server's GeoIP
    };

    // Analytics GeoIP
    #[cfg(feature = "analytics")]
    let analytics_geoip = if let Some(cfg) = config.analytics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        None
    };
    #[cfg(not(feature = "analytics"))]
    let analytics_geoip: Option<Arc<GeoipDb>> = None;

    if !loaded.is_empty() {
        info!(
            databases = loaded.len(),
            "GeoIP databases loaded (deduplicated)"
        );
    }

    // Spawn auto-update tasks for configs that need them
    {
        // Deduplicate auto-update tasks by Arc pointer identity
        let mut seen_ptrs = std::collections::HashSet::new();
        for (cfg, db) in auto_update_configs {
            let ptr = Arc::as_ptr(&db) as usize;
            if !seen_ptrs.insert(ptr) {
                continue; // already spawned for this database
            }
            let cancel = shutdown.clone();
            let source = cfg.source.clone();
            info!(source = %source, "spawning GeoIP auto-update task");
            // NOTE(review): the ArcSwap created here is local to the update
            // task — callers keep the Arc returned above, so a hot-swap may
            // not be observed by them; confirm this is the intended design.
            let swappable = Arc::new(arc_swap::ArcSwap::from(db));
            tokio::spawn(trojan_rules::geoip_db::geoip_auto_update_task(
                cfg,
                swappable,
                cancel,
                move |success| {
                    if success {
                        trojan_metrics::record_rule_update();
                    } else {
                        trojan_metrics::record_rule_update_error();
                    }
                },
            ));
        }
    }

    (server_geoip, metrics_geoip, analytics_geoip)
}
677
/// Background task that periodically re-fetches HTTP rule-sets and hot-swaps the engine.
#[cfg(feature = "rules")]
async fn rule_update_loop(
    engine: Arc<trojan_rules::HotRuleEngine>,
    server_config: trojan_config::ServerConfig,
    interval_secs: u64,
    shutdown: CancellationToken,
) {
    use std::time::Duration;
    use trojan_metrics::{record_rule_update, record_rule_update_error};

    // One immediate fetch so that any cache-only data loaded at startup is
    // replaced with fresh rule-sets as soon as possible. On failure the
    // engine keeps whatever it started with.
    match crate::rules::build_rule_engine_async(&server_config).await {
        Err(e) => {
            warn!(error = %e, "initial rule fetch failed, keeping startup rules");
            record_rule_update_error();
        }
        Ok(fresh) => {
            info!(
                rule_sets = fresh.rule_set_count(),
                rules = fresh.rule_count(),
                "initial rule fetch completed, engine updated"
            );
            engine.update(fresh);
            record_rule_update();
        }
    }

    let mut ticker = tokio::time::interval(Duration::from_secs(interval_secs));
    ticker.tick().await; // the first tick completes immediately; discard it

    loop {
        tokio::select! {
            // biased: prefer shutdown over another update round.
            biased;
            _ = shutdown.cancelled() => {
                debug!("rule update task shutting down");
                return;
            }
            _ = ticker.tick() => {
                debug!("starting scheduled rule update");
                match crate::rules::build_rule_engine_async(&server_config).await {
                    Err(e) => {
                        warn!(error = %e, "rule update failed, keeping current rules");
                        record_rule_update_error();
                    }
                    Ok(fresh) => {
                        info!(
                            rule_sets = fresh.rule_set_count(),
                            rules = fresh.rule_count(),
                            "rule update completed, engine swapped"
                        );
                        engine.update(fresh);
                        record_rule_update();
                    }
                }
            }
        }
    }
}