Skip to main content

trojan_server/
cli.rs

1//! CLI module for trojan-server.
2//!
3//! This module provides the command-line interface that can be used either
4//! as a standalone binary or as a subcommand of the main trojan-rs CLI.
5
6use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
18/// Trojan server CLI arguments.
19#[derive(Parser, Debug, Clone)]
20#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
21pub struct ServerArgs {
22    /// Config file path (json/jsonc/yaml/toml)
23    #[arg(short, long, default_value = "config.toml")]
24    pub config: PathBuf,
25
26    #[command(flatten)]
27    pub overrides: CliOverrides,
28}
29
30/// Run the trojan server with the given arguments.
31///
32/// This is the main entry point for the server CLI, used by both the
33/// standalone binary and the unified trojan-rs CLI.
34pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35    let mut config = load_config(&args.config)?;
36    apply_overrides(&mut config, &args.overrides);
37    validate_config(&config)?;
38
39    init_tracing(&config.logging);
40
41    // Set up graceful shutdown on SIGTERM/SIGINT
42    let shutdown = CancellationToken::new();
43    let shutdown_signal = shutdown.clone();
44
45    tokio::spawn(async move {
46        shutdown_signal_handler().await;
47        info!("shutdown signal received");
48        shutdown_signal.cancel();
49    });
50
51    // Create reloadable auth backend
52    let auth = Arc::new(ReloadableAuth::new(build_memory_auth(&config.auth)));
53
54    // Set up SIGHUP handler for config reload
55    #[cfg(unix)]
56    {
57        let config_path = args.config.clone();
58        let overrides = args.overrides.clone();
59        let auth_reload = auth.clone();
60        tokio::spawn(async move {
61            reload_signal_handler(config_path, overrides, auth_reload).await;
62        });
63    }
64
65    run_with_shutdown(config, auth, shutdown).await?;
66    Ok(())
67}
68
69/// Wait for shutdown signals (SIGTERM, SIGINT).
70async fn shutdown_signal_handler() {
71    let ctrl_c = async {
72        if let Err(e) = tokio::signal::ctrl_c().await {
73            warn!("failed to listen for Ctrl+C: {}", e);
74            // Fall back to waiting forever
75            std::future::pending::<()>().await;
76        }
77    };
78
79    #[cfg(unix)]
80    let terminate = async {
81        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
82            Ok(mut sig) => {
83                sig.recv().await;
84            }
85            Err(e) => {
86                warn!("failed to listen for SIGTERM: {}", e);
87                // Fall back to waiting forever
88                std::future::pending::<()>().await;
89            }
90        }
91    };
92
93    #[cfg(not(unix))]
94    let terminate = std::future::pending::<()>();
95
96    tokio::select! {
97        _ = ctrl_c => {}
98        _ = terminate => {}
99    }
100}
101
102/// Handle SIGHUP for config reload (Unix only).
103#[cfg(unix)]
104async fn reload_signal_handler(
105    config_path: PathBuf,
106    overrides: CliOverrides,
107    auth: Arc<ReloadableAuth>,
108) {
109    use tokio::signal::unix::{SignalKind, signal};
110
111    let mut sighup = match signal(SignalKind::hangup()) {
112        Ok(sig) => sig,
113        Err(e) => {
114            warn!(
115                "failed to install SIGHUP handler: {}, config reload disabled",
116                e
117            );
118            return;
119        }
120    };
121
122    loop {
123        sighup.recv().await;
124        info!("SIGHUP received, reloading configuration");
125
126        match reload_config(&config_path, &overrides, &auth) {
127            Ok(()) => info!("configuration reloaded successfully"),
128            Err(e) => warn!("failed to reload configuration: {}", e),
129        }
130    }
131}
132
133/// Reload configuration from file.
134#[cfg(unix)]
135fn reload_config(
136    config_path: &PathBuf,
137    overrides: &CliOverrides,
138    auth: &Arc<ReloadableAuth>,
139) -> Result<(), Box<dyn std::error::Error>> {
140    let mut config = load_config(config_path)?;
141    apply_overrides(&mut config, overrides);
142    validate_config(&config)?;
143
144    // Reload auth passwords + users
145    let new_auth = build_memory_auth(&config.auth);
146    auth.reload(new_auth);
147    info!(
148        password_count = config.auth.passwords.len(),
149        user_count = config.auth.users.len(),
150        "auth reloaded"
151    );
152
153    // Note: TLS certificates and other settings require server restart
154    // Future enhancement: implement TLS cert hot-reload via rustls ResolvesServerCert
155
156    Ok(())
157}
158
159/// Build a `MemoryAuth` from both `passwords` and `users` in the config.
160fn build_memory_auth(auth: &trojan_config::AuthConfig) -> MemoryAuth {
161    let mut mem = MemoryAuth::new();
162    for pw in &auth.passwords {
163        mem.add_password(pw, None);
164    }
165    for u in &auth.users {
166        mem.add_password(&u.password, Some(u.id.clone()));
167    }
168    mem
169}
170
171/// Initialize tracing subscriber with the given logging configuration.
172///
173/// Supports:
174/// - `level`: Base log level (trace, debug, info, warn, error)
175/// - `format`: Output format (json, pretty, compact). Default: pretty
176/// - `output`: Output target (stdout, stderr). Default: stderr
177/// - `filters`: Per-module log level overrides
178fn init_tracing(config: &LoggingConfig) {
179    // Build the env filter from base level and per-module filters
180    let base_level = config.level.as_deref().unwrap_or("info");
181    let mut filter_str = base_level.to_string();
182
183    for (module, level) in &config.filters {
184        filter_str.push(',');
185        filter_str.push_str(module);
186        filter_str.push('=');
187        filter_str.push_str(level);
188    }
189
190    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
191
192    let format = config.format.as_deref().unwrap_or("pretty");
193    let output = config.output.as_deref().unwrap_or("stderr");
194
195    // Create the subscriber based on format and output
196    match (format, output) {
197        ("json", "stdout") => {
198            tracing_subscriber::registry()
199                .with(filter)
200                .with(fmt::layer().json().with_writer(io::stdout))
201                .init();
202        }
203        ("json", _) => {
204            tracing_subscriber::registry()
205                .with(filter)
206                .with(fmt::layer().json().with_writer(io::stderr))
207                .init();
208        }
209        ("compact", "stdout") => {
210            tracing_subscriber::registry()
211                .with(filter)
212                .with(fmt::layer().compact().with_writer(io::stdout))
213                .init();
214        }
215        ("compact", _) => {
216            tracing_subscriber::registry()
217                .with(filter)
218                .with(fmt::layer().compact().with_writer(io::stderr))
219                .init();
220        }
221        (_, "stdout") => {
222            // pretty is default
223            tracing_subscriber::registry()
224                .with(filter)
225                .with(fmt::layer().with_writer(io::stdout))
226                .init();
227        }
228        _ => {
229            // pretty to stderr is default
230            tracing_subscriber::registry()
231                .with(filter)
232                .with(fmt::layer().with_writer(io::stderr))
233                .init();
234        }
235    }
236}