1use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10
11use clap::Parser;
12use tracing::{info, warn};
13use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
14
15use trojan_auth::{
16 AuthBackend, MemoryAuth, ReloadableAuth,
17 http::{Codec, HttpAuth, HttpAuthConfig},
18};
19use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
20
21use crate::{CancellationToken, run_with_shutdown};
22
// Command-line arguments of the `trojan-server` binary, parsed with clap's
// derive API.
//
// Plain `//` comments are used here on purpose: clap turns `///` doc comments
// on a `Parser` struct and its fields into `--help` text.
#[derive(Parser, Debug, Clone)]
#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
pub struct ServerArgs {
    // Path of the configuration file (`-c` / `--config`); defaults to
    // `config.toml` when the flag is not given.
    #[arg(short, long, default_value = "config.toml")]
    pub config: PathBuf,

    // Overrides taken from the command line; their flags are flattened into
    // this command's own argument list.
    #[command(flatten)]
    pub overrides: CliOverrides,
}
34
35pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
40 let mut config = load_config(&args.config)?;
41 apply_overrides(&mut config, &args.overrides);
42 validate_config(&config)?;
43
44 init_tracing(&config.logging);
45
46 let shutdown = CancellationToken::new();
48 let shutdown_signal = shutdown.clone();
49
50 tokio::spawn(async move {
51 shutdown_signal_handler().await;
52 info!("shutdown signal received");
53 shutdown_signal.cancel();
54 });
55
56 let auth = Arc::new(ReloadableAuth::new(build_auth(&config.auth)));
58
59 #[cfg(unix)]
61 {
62 let config_path = args.config.clone();
63 let overrides = args.overrides.clone();
64 let auth_reload = auth.clone();
65 tokio::spawn(async move {
66 reload_signal_handler(config_path, overrides, auth_reload).await;
67 });
68 }
69
70 run_with_shutdown(config, auth, shutdown).await?;
71 Ok(())
72}
73
74async fn shutdown_signal_handler() {
76 let ctrl_c = async {
77 if let Err(e) = tokio::signal::ctrl_c().await {
78 warn!("failed to listen for Ctrl+C: {}", e);
79 std::future::pending::<()>().await;
81 }
82 };
83
84 #[cfg(unix)]
85 let terminate = async {
86 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
87 Ok(mut sig) => {
88 sig.recv().await;
89 }
90 Err(e) => {
91 warn!("failed to listen for SIGTERM: {}", e);
92 std::future::pending::<()>().await;
94 }
95 }
96 };
97
98 #[cfg(not(unix))]
99 let terminate = std::future::pending::<()>();
100
101 tokio::select! {
102 _ = ctrl_c => {}
103 _ = terminate => {}
104 }
105}
106
107#[cfg(unix)]
109async fn reload_signal_handler(
110 config_path: PathBuf,
111 overrides: CliOverrides,
112 auth: Arc<ReloadableAuth>,
113) {
114 use tokio::signal::unix::{SignalKind, signal};
115
116 let mut sighup = match signal(SignalKind::hangup()) {
117 Ok(sig) => sig,
118 Err(e) => {
119 warn!(
120 "failed to install SIGHUP handler: {}, config reload disabled",
121 e
122 );
123 return;
124 }
125 };
126
127 loop {
128 sighup.recv().await;
129 info!("SIGHUP received, reloading configuration");
130
131 match reload_config(&config_path, &overrides, &auth) {
132 Ok(()) => info!("configuration reloaded successfully"),
133 Err(e) => warn!("failed to reload configuration: {}", e),
134 }
135 }
136}
137
138#[cfg(unix)]
140fn reload_config(
141 config_path: &PathBuf,
142 overrides: &CliOverrides,
143 auth: &Arc<ReloadableAuth>,
144) -> Result<(), Box<dyn std::error::Error>> {
145 let mut config = load_config(config_path)?;
146 apply_overrides(&mut config, overrides);
147 validate_config(&config)?;
148
149 let new_auth = build_auth(&config.auth);
151 auth.reload(new_auth);
152 info!(
153 password_count = config.auth.passwords.len(),
154 user_count = config.auth.users.len(),
155 http = config.auth.http_url.is_some(),
156 "auth reloaded"
157 );
158
159 Ok(())
163}
164
165fn build_auth(auth: &trojan_config::AuthConfig) -> Box<dyn AuthBackend> {
170 if let Some(ref url) = auth.http_url {
171 let codec = match auth.http_codec.as_deref() {
172 Some("json") => Codec::Json,
173 _ => Codec::Bincode,
174 };
175 info!(
176 url = %url,
177 codec = ?codec,
178 cache_ttl = auth.http_cache_ttl_secs,
179 stale_ttl = auth.http_cache_stale_ttl_secs,
180 neg_cache_ttl = auth.http_cache_neg_ttl_secs,
181 "using HTTP auth backend"
182 );
183 let config = HttpAuthConfig {
184 base_url: url.clone(),
185 codec,
186 node_token: auth.http_node_token.clone(),
187 cache_ttl: Duration::from_secs(auth.http_cache_ttl_secs),
188 stale_ttl: Duration::from_secs(auth.http_cache_stale_ttl_secs),
189 neg_cache_ttl: Duration::from_secs(auth.http_cache_neg_ttl_secs),
190 };
191 Box::new(HttpAuth::new(config))
192 } else {
193 let mut mem = MemoryAuth::new();
194 for pw in &auth.passwords {
195 mem.add_password(pw, None);
196 }
197 for u in &auth.users {
198 mem.add_password(&u.password, Some(u.id.clone()));
199 }
200 Box::new(mem)
201 }
202}
203
204fn init_tracing(config: &LoggingConfig) {
212 let base_level = config.level.as_deref().unwrap_or("info");
214 let mut filter_str = base_level.to_string();
215
216 for (module, level) in &config.filters {
217 filter_str.push(',');
218 filter_str.push_str(module);
219 filter_str.push('=');
220 filter_str.push_str(level);
221 }
222
223 let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
224
225 let format = config.format.as_deref().unwrap_or("pretty");
226 let output = config.output.as_deref().unwrap_or("stderr");
227
228 match (format, output) {
230 ("json", "stdout") => {
231 tracing_subscriber::registry()
232 .with(filter)
233 .with(fmt::layer().json().with_writer(io::stdout))
234 .init();
235 }
236 ("json", _) => {
237 tracing_subscriber::registry()
238 .with(filter)
239 .with(fmt::layer().json().with_writer(io::stderr))
240 .init();
241 }
242 ("compact", "stdout") => {
243 tracing_subscriber::registry()
244 .with(filter)
245 .with(fmt::layer().compact().with_writer(io::stdout))
246 .init();
247 }
248 ("compact", _) => {
249 tracing_subscriber::registry()
250 .with(filter)
251 .with(fmt::layer().compact().with_writer(io::stderr))
252 .init();
253 }
254 (_, "stdout") => {
255 tracing_subscriber::registry()
257 .with(filter)
258 .with(fmt::layer().with_writer(io::stdout))
259 .init();
260 }
261 _ => {
262 tracing_subscriber::registry()
264 .with(filter)
265 .with(fmt::layer().with_writer(io::stderr))
266 .init();
267 }
268 }
269}