1use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
// Command-line arguments for the `trojan-server` binary.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, so these notes stay out of the CLI output.
#[derive(Parser, Debug, Clone)]
#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
pub struct ServerArgs {
    // Path of the configuration file handed to `load_config`;
    // `config.toml` when `-c/--config` is not given.
    #[arg(short, long, default_value = "config.toml")]
    pub config: PathBuf,

    // Extra flags merged into the loaded configuration by `apply_overrides`,
    // both at startup and (on unix) whenever the config is reloaded on SIGHUP.
    #[command(flatten)]
    pub overrides: CliOverrides,
}
29
30pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35 let mut config = load_config(&args.config)?;
36 apply_overrides(&mut config, &args.overrides);
37 validate_config(&config)?;
38
39 init_tracing(&config.logging);
40
41 let shutdown = CancellationToken::new();
43 let shutdown_signal = shutdown.clone();
44
45 tokio::spawn(async move {
46 shutdown_signal_handler().await;
47 info!("shutdown signal received");
48 shutdown_signal.cancel();
49 });
50
51 let auth = Arc::new(ReloadableAuth::new(build_memory_auth(&config.auth)));
53
54 #[cfg(unix)]
56 {
57 let config_path = args.config.clone();
58 let overrides = args.overrides.clone();
59 let auth_reload = auth.clone();
60 tokio::spawn(async move {
61 reload_signal_handler(config_path, overrides, auth_reload).await;
62 });
63 }
64
65 run_with_shutdown(config, auth, shutdown).await?;
66 Ok(())
67}
68
69async fn shutdown_signal_handler() {
71 let ctrl_c = async {
72 if let Err(e) = tokio::signal::ctrl_c().await {
73 warn!("failed to listen for Ctrl+C: {}", e);
74 std::future::pending::<()>().await;
76 }
77 };
78
79 #[cfg(unix)]
80 let terminate = async {
81 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
82 Ok(mut sig) => {
83 sig.recv().await;
84 }
85 Err(e) => {
86 warn!("failed to listen for SIGTERM: {}", e);
87 std::future::pending::<()>().await;
89 }
90 }
91 };
92
93 #[cfg(not(unix))]
94 let terminate = std::future::pending::<()>();
95
96 tokio::select! {
97 _ = ctrl_c => {}
98 _ = terminate => {}
99 }
100}
101
102#[cfg(unix)]
104async fn reload_signal_handler(
105 config_path: PathBuf,
106 overrides: CliOverrides,
107 auth: Arc<ReloadableAuth>,
108) {
109 use tokio::signal::unix::{SignalKind, signal};
110
111 let mut sighup = match signal(SignalKind::hangup()) {
112 Ok(sig) => sig,
113 Err(e) => {
114 warn!(
115 "failed to install SIGHUP handler: {}, config reload disabled",
116 e
117 );
118 return;
119 }
120 };
121
122 loop {
123 sighup.recv().await;
124 info!("SIGHUP received, reloading configuration");
125
126 match reload_config(&config_path, &overrides, &auth) {
127 Ok(()) => info!("configuration reloaded successfully"),
128 Err(e) => warn!("failed to reload configuration: {}", e),
129 }
130 }
131}
132
133#[cfg(unix)]
135fn reload_config(
136 config_path: &PathBuf,
137 overrides: &CliOverrides,
138 auth: &Arc<ReloadableAuth>,
139) -> Result<(), Box<dyn std::error::Error>> {
140 let mut config = load_config(config_path)?;
141 apply_overrides(&mut config, overrides);
142 validate_config(&config)?;
143
144 let new_auth = build_memory_auth(&config.auth);
146 auth.reload(new_auth);
147 info!(
148 password_count = config.auth.passwords.len(),
149 user_count = config.auth.users.len(),
150 "auth reloaded"
151 );
152
153 Ok(())
157}
158
159fn build_memory_auth(auth: &trojan_config::AuthConfig) -> MemoryAuth {
161 let mut mem = MemoryAuth::new();
162 for pw in &auth.passwords {
163 mem.add_password(pw, None);
164 }
165 for u in &auth.users {
166 mem.add_password(&u.password, Some(u.id.clone()));
167 }
168 mem
169}
170
171fn init_tracing(config: &LoggingConfig) {
179 let base_level = config.level.as_deref().unwrap_or("info");
181 let mut filter_str = base_level.to_string();
182
183 for (module, level) in &config.filters {
184 filter_str.push(',');
185 filter_str.push_str(module);
186 filter_str.push('=');
187 filter_str.push_str(level);
188 }
189
190 let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
191
192 let format = config.format.as_deref().unwrap_or("pretty");
193 let output = config.output.as_deref().unwrap_or("stderr");
194
195 match (format, output) {
197 ("json", "stdout") => {
198 tracing_subscriber::registry()
199 .with(filter)
200 .with(fmt::layer().json().with_writer(io::stdout))
201 .init();
202 }
203 ("json", _) => {
204 tracing_subscriber::registry()
205 .with(filter)
206 .with(fmt::layer().json().with_writer(io::stderr))
207 .init();
208 }
209 ("compact", "stdout") => {
210 tracing_subscriber::registry()
211 .with(filter)
212 .with(fmt::layer().compact().with_writer(io::stdout))
213 .init();
214 }
215 ("compact", _) => {
216 tracing_subscriber::registry()
217 .with(filter)
218 .with(fmt::layer().compact().with_writer(io::stderr))
219 .init();
220 }
221 (_, "stdout") => {
222 tracing_subscriber::registry()
224 .with(filter)
225 .with(fmt::layer().with_writer(io::stdout))
226 .init();
227 }
228 _ => {
229 tracing_subscriber::registry()
231 .with(filter)
232 .with(fmt::layer().with_writer(io::stderr))
233 .init();
234 }
235 }
236}