Skip to main content

trojan_client/
cli.rs

1//! CLI module for trojan-client.
2
3use std::io;
4use std::path::PathBuf;
5
6use clap::Parser;
7use tokio_util::sync::CancellationToken;
8use tracing::{info, warn};
9use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
10use trojan_config::LoggingConfig;
11
12use crate::config::load_client_config;
13
14/// Trojan client CLI arguments.
15#[derive(Parser, Debug, Clone)]
16#[command(name = "trojan-client", version, about = "Trojan SOCKS5 proxy client")]
17pub struct ClientArgs {
18    /// Config file path (toml/json/jsonc).
19    #[arg(short, long, default_value = "client.toml")]
20    pub config: PathBuf,
21
22    /// Override SOCKS5 listen address.
23    #[arg(short, long)]
24    pub listen: Option<String>,
25
26    /// Override remote trojan server address.
27    #[arg(short, long)]
28    pub remote: Option<String>,
29
30    /// Override password.
31    #[arg(short, long)]
32    pub password: Option<String>,
33
34    /// Skip TLS certificate verification.
35    #[arg(long)]
36    pub skip_verify: bool,
37
38    /// Log level override.
39    #[arg(long)]
40    pub log_level: Option<String>,
41}
42
43/// Run the trojan client with the given CLI arguments.
44pub async fn run(args: ClientArgs) -> Result<(), Box<dyn std::error::Error>> {
45    let mut config = load_client_config(&args.config)?;
46
47    // Apply CLI overrides
48    if let Some(listen) = &args.listen {
49        config.client.listen = listen.clone();
50    }
51    if let Some(remote) = &args.remote {
52        config.client.remote = remote.clone();
53    }
54    if let Some(password) = &args.password {
55        config.client.password = password.clone();
56    }
57    if args.skip_verify {
58        config.client.tls.skip_verify = true;
59    }
60    if let Some(level) = &args.log_level {
61        config.logging.level = Some(level.clone());
62    }
63
64    init_tracing(&config.logging);
65
66    // Graceful shutdown
67    let shutdown = CancellationToken::new();
68    let shutdown_signal = shutdown.clone();
69
70    tokio::spawn(async move {
71        shutdown_signal_handler().await;
72        info!("shutdown signal received");
73        shutdown_signal.cancel();
74    });
75
76    crate::run(config, shutdown).await?;
77    Ok(())
78}
79
80async fn shutdown_signal_handler() {
81    let ctrl_c = async {
82        if let Err(e) = tokio::signal::ctrl_c().await {
83            warn!("failed to listen for Ctrl+C: {e}");
84            std::future::pending::<()>().await;
85        }
86    };
87
88    #[cfg(unix)]
89    let terminate = async {
90        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
91            Ok(mut sig) => {
92                sig.recv().await;
93            }
94            Err(e) => {
95                warn!("failed to listen for SIGTERM: {e}");
96                std::future::pending::<()>().await;
97            }
98        }
99    };
100
101    #[cfg(not(unix))]
102    let terminate = std::future::pending::<()>();
103
104    tokio::select! {
105        _ = ctrl_c => {}
106        _ = terminate => {}
107    }
108}
109
110fn init_tracing(config: &LoggingConfig) {
111    let base_level = config.level.as_deref().unwrap_or("info");
112    let mut filter_str = base_level.to_string();
113
114    for (module, level) in &config.filters {
115        filter_str.push(',');
116        filter_str.push_str(module);
117        filter_str.push('=');
118        filter_str.push_str(level);
119    }
120
121    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
122
123    let format = config.format.as_deref().unwrap_or("pretty");
124    let output = config.output.as_deref().unwrap_or("stderr");
125
126    match (format, output) {
127        ("json", "stdout") => {
128            tracing_subscriber::registry()
129                .with(filter)
130                .with(fmt::layer().json().with_writer(io::stdout))
131                .init();
132        }
133        ("json", _) => {
134            tracing_subscriber::registry()
135                .with(filter)
136                .with(fmt::layer().json().with_writer(io::stderr))
137                .init();
138        }
139        ("compact", "stdout") => {
140            tracing_subscriber::registry()
141                .with(filter)
142                .with(fmt::layer().compact().with_writer(io::stdout))
143                .init();
144        }
145        ("compact", _) => {
146            tracing_subscriber::registry()
147                .with(filter)
148                .with(fmt::layer().compact().with_writer(io::stderr))
149                .init();
150        }
151        (_, "stdout") => {
152            tracing_subscriber::registry()
153                .with(filter)
154                .with(fmt::layer().with_writer(io::stdout))
155                .init();
156        }
157        _ => {
158            tracing_subscriber::registry()
159                .with(filter)
160                .with(fmt::layer().with_writer(io::stderr))
161                .init();
162        }
163    }
164}