Skip to main content

trojan_metrics/
lib.rs

//! Metrics collection and Prometheus exporter for trojan-rs.
//!
//! This crate provides metrics instrumentation for the trojan server
//! (connection counts, bytes transferred, error rates, latencies), plus an
//! HTTP server that exposes them together with liveness and readiness probes.
5
6use std::net::SocketAddr;
7
8use axum::{Router, http::StatusCode, response::IntoResponse, routing::get};
9use metrics::{counter, gauge, histogram};
10use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
11
12/// Initialize metrics server with Prometheus exporter and health check endpoints.
13///
14/// Starts an HTTP server on the given address with:
15/// - `/metrics` - Prometheus metrics endpoint
16/// - `/health` - Liveness probe (always returns 200 OK)
17/// - `/ready` - Readiness probe (always returns 200 READY)
18///
19/// Additional routes can be merged via the `extra_routes` parameter.
20///
21/// Returns a tokio JoinHandle for the server task.
22pub fn init_metrics_server(
23    listen: &str,
24    extra_routes: Option<Router>,
25) -> Result<tokio::task::JoinHandle<()>, String> {
26    let addr: SocketAddr = listen
27        .parse()
28        .map_err(|e| format!("invalid metrics listen address: {}", e))?;
29
30    // Build the Prometheus recorder and get a handle for rendering
31    let builder = PrometheusBuilder::new();
32    let handle = builder
33        .install_recorder()
34        .map_err(|e| format!("failed to install prometheus recorder: {}", e))?;
35
36    // Build the Axum router with metrics and health endpoints
37    let mut app = Router::new()
38        .route("/metrics", get(move || metrics_handler(handle.clone())))
39        .route("/health", get(health_handler))
40        .route("/ready", get(ready_handler));
41
42    if let Some(extra) = extra_routes {
43        app = app.merge(extra);
44    }
45
46    // Spawn the server
47    let server_handle = tokio::spawn(async move {
48        let listener = match tokio::net::TcpListener::bind(addr).await {
49            Ok(l) => l,
50            Err(e) => {
51                eprintln!("failed to bind metrics server to {}: {}", addr, e);
52                return;
53            }
54        };
55        if let Err(e) = axum::serve(
56            listener,
57            app.into_make_service_with_connect_info::<SocketAddr>(),
58        )
59        .await
60        {
61            eprintln!("metrics server error: {}", e);
62        }
63    });
64
65    Ok(server_handle)
66}
67
68/// Handler for /metrics endpoint - returns Prometheus format metrics.
69async fn metrics_handler(handle: PrometheusHandle) -> impl IntoResponse {
70    handle.render()
71}
72
73/// Handler for /health endpoint - liveness probe.
74async fn health_handler() -> impl IntoResponse {
75    (StatusCode::OK, "OK")
76}
77
78/// Handler for /ready endpoint - readiness probe.
79async fn ready_handler() -> impl IntoResponse {
80    (StatusCode::OK, "READY")
81}
82
83/// Initialize Prometheus metrics exporter (legacy function).
84///
85/// Starts an HTTP server on the given address to expose metrics.
86/// Returns an error message if binding fails.
87#[deprecated(since = "0.2.0", note = "Use init_metrics_server instead")]
88pub fn init_prometheus(listen: &str) -> Result<(), String> {
89    let addr: SocketAddr = listen
90        .parse()
91        .map_err(|e| format!("invalid metrics listen address: {}", e))?;
92
93    PrometheusBuilder::new()
94        .with_http_listener(addr)
95        .install()
96        .map_err(|e| format!("failed to install prometheus exporter: {}", e))?;
97
98    Ok(())
99}
100
// ============================================================================
// Metric Names
// ============================================================================

// --- Connection lifecycle ---------------------------------------------------

/// Counter: TCP connections accepted.
pub const CONNECTIONS_TOTAL: &str = "trojan_connections_total";
/// Gauge: connections currently open.
pub const CONNECTIONS_ACTIVE: &str = "trojan_connections_active";
/// Histogram: connection lifetime, in seconds.
pub const CONNECTION_DURATION_SECONDS: &str = "trojan_connection_duration_seconds";
/// Counter: connections refused, labelled by `reason` (rate limit, max connections).
pub const CONNECTIONS_REJECTED_TOTAL: &str = "trojan_connections_rejected_total";
/// Gauge: connection queue depth (pending connections in the accept backlog).
pub const CONNECTION_QUEUE_DEPTH: &str = "trojan_connection_queue_depth";
/// Histogram: TLS handshake duration, in seconds.
pub const TLS_HANDSHAKE_DURATION_SECONDS: &str = "trojan_tls_handshake_duration_seconds";

// --- Authentication and fallback --------------------------------------------

/// Counter: successful authentications.
pub const AUTH_SUCCESS_TOTAL: &str = "trojan_auth_success_total";
/// Counter: failed authentications.
pub const AUTH_FAILURE_TOTAL: &str = "trojan_auth_failure_total";
/// Counter: connections routed to the fallback (non-trojan traffic).
pub const FALLBACK_TOTAL: &str = "trojan_fallback_total";
/// Gauge: current size of the fallback warm pool.
pub const FALLBACK_POOL_SIZE: &str = "trojan_fallback_pool_size";
/// Counter: warm-fill connection failures for the fallback pool.
pub const FALLBACK_POOL_WARM_FAIL_TOTAL: &str = "trojan_fallback_pool_warm_fail_total";

// --- Traffic, requests and errors -------------------------------------------

/// Counter: bytes received from clients.
pub const BYTES_RECEIVED_TOTAL: &str = "trojan_bytes_received_total";
/// Counter: bytes sent to clients.
pub const BYTES_SENT_TOTAL: &str = "trojan_bytes_sent_total";
/// Counter: CONNECT requests.
pub const CONNECT_REQUESTS_TOTAL: &str = "trojan_connect_requests_total";
/// Counter: UDP associate requests.
pub const UDP_ASSOCIATE_REQUESTS_TOTAL: &str = "trojan_udp_associate_requests_total";
/// Counter: UDP packets relayed, labelled by `direction`.
pub const UDP_PACKETS_TOTAL: &str = "trojan_udp_packets_total";
/// Counter: errors, labelled by `type`.
pub const ERRORS_TOTAL: &str = "trojan_errors_total";

// --- Outbound targets -------------------------------------------------------

/// Counter: connections per destination, labelled by `target`.
pub const TARGET_CONNECTIONS_TOTAL: &str = "trojan_target_connections_total";
/// Counter: bytes per destination, labelled by `target` and `direction`.
pub const TARGET_BYTES_TOTAL: &str = "trojan_target_bytes_total";
/// Histogram: DNS resolution duration, in seconds.
pub const DNS_RESOLVE_DURATION_SECONDS: &str = "trojan_dns_resolve_duration_seconds";
/// Histogram: target connection establishment duration, in seconds.
pub const TARGET_CONNECT_DURATION_SECONDS: &str = "trojan_target_connect_duration_seconds";

// --- Rule engine hot-reload -------------------------------------------------

/// Counter: successful rule engine updates.
pub const RULE_UPDATES_TOTAL: &str = "trojan_rule_updates_total";
/// Counter: failed rule engine update attempts.
pub const RULE_UPDATE_ERRORS_TOTAL: &str = "trojan_rule_update_errors_total";

// --- Source-country breakdowns ----------------------------------------------

/// Counter: connections, labelled by source `country`.
pub const CONNECTIONS_BY_COUNTRY: &str = "trojan_connections_by_country_total";
/// Counter: bytes, labelled by source `country` and `direction`.
pub const BYTES_BY_COUNTRY: &str = "trojan_bytes_by_country_total";
/// Counter: authentication failures, labelled by source `country`.
pub const AUTH_FAILURE_BY_COUNTRY: &str = "trojan_auth_failure_by_country_total";
157
158// ============================================================================
159// Metric Recording Functions
160// ============================================================================
161
162/// Record a new connection accepted.
163#[inline]
164pub fn record_connection_accepted() {
165    counter!(CONNECTIONS_TOTAL).increment(1);
166    gauge!(CONNECTIONS_ACTIVE).increment(1.0);
167}
168
169/// Record a connection closed.
170#[inline]
171pub fn record_connection_closed(duration_secs: f64) {
172    gauge!(CONNECTIONS_ACTIVE).decrement(1.0);
173    histogram!(CONNECTION_DURATION_SECONDS).record(duration_secs);
174}
175
176/// Record successful authentication.
177#[inline]
178pub fn record_auth_success() {
179    counter!(AUTH_SUCCESS_TOTAL).increment(1);
180}
181
182/// Record failed authentication (triggers fallback).
183#[inline]
184pub fn record_auth_failure() {
185    counter!(AUTH_FAILURE_TOTAL).increment(1);
186}
187
188/// Record fallback to HTTP backend.
189#[inline]
190pub fn record_fallback() {
191    counter!(FALLBACK_TOTAL).increment(1);
192}
193
194/// Record bytes received from client.
195#[inline]
196pub fn record_bytes_received(bytes: u64) {
197    counter!(BYTES_RECEIVED_TOTAL).increment(bytes);
198}
199
200/// Record bytes sent to client.
201#[inline]
202pub fn record_bytes_sent(bytes: u64) {
203    counter!(BYTES_SENT_TOTAL).increment(bytes);
204}
205
206/// Record a CONNECT request.
207#[inline]
208pub fn record_connect_request() {
209    counter!(CONNECT_REQUESTS_TOTAL).increment(1);
210}
211
212/// Record a UDP associate request.
213#[inline]
214pub fn record_udp_associate_request() {
215    counter!(UDP_ASSOCIATE_REQUESTS_TOTAL).increment(1);
216}
217
218/// Record UDP packets relayed (direction: "inbound" or "outbound").
219#[inline]
220pub fn record_udp_packet(direction: &'static str) {
221    counter!(UDP_PACKETS_TOTAL, "direction" => direction).increment(1);
222}
223
224/// Record an error by type.
225#[inline]
226pub fn record_error(error_type: &'static str) {
227    counter!(ERRORS_TOTAL, "type" => error_type).increment(1);
228}
229
230/// Record a rejected connection (reason: "max_connections", "rate_limit").
231#[inline]
232pub fn record_connection_rejected(reason: &'static str) {
233    counter!(CONNECTIONS_REJECTED_TOTAL, "reason" => reason).increment(1);
234}
235
236/// Record TLS handshake duration.
237#[inline]
238pub fn record_tls_handshake_duration(duration_secs: f64) {
239    histogram!(TLS_HANDSHAKE_DURATION_SECONDS).record(duration_secs);
240}
241
242/// Set connection queue depth gauge.
243#[inline]
244pub fn set_connection_queue_depth(depth: f64) {
245    gauge!(CONNECTION_QUEUE_DEPTH).set(depth);
246}
247
248/// Record a connection to a target (by destination host).
249/// The target should be sanitized (e.g., IP address or domain without port).
250/// Note: This function allocates a String for the label. For hot paths with repeated calls,
251/// consider caching the String at the call site.
252#[inline]
253pub fn record_target_connection(target: &str) {
254    counter!(TARGET_CONNECTIONS_TOTAL, "target" => target.to_owned()).increment(1);
255}
256
257/// Record bytes transferred to/from a target.
258/// Direction: "sent" or "received".
259/// Note: This function allocates a String for the label. For hot paths with repeated calls,
260/// consider caching the String at the call site.
261#[inline]
262pub fn record_target_bytes(target: &str, direction: &'static str, bytes: u64) {
263    counter!(TARGET_BYTES_TOTAL, "target" => target.to_owned(), "direction" => direction)
264        .increment(bytes);
265}
266
267/// Set current fallback pool size.
268#[inline]
269pub fn set_fallback_pool_size(size: usize) {
270    gauge!(FALLBACK_POOL_SIZE).set(size as f64);
271}
272
273/// Record warm-fill connection failure.
274#[inline]
275pub fn record_fallback_pool_warm_fail() {
276    counter!(FALLBACK_POOL_WARM_FAIL_TOTAL).increment(1);
277}
278
279/// Record DNS resolution duration.
280#[inline]
281pub fn record_dns_resolve_duration(duration_secs: f64) {
282    histogram!(DNS_RESOLVE_DURATION_SECONDS).record(duration_secs);
283}
284
285/// Record target connection establishment duration.
286#[inline]
287pub fn record_target_connect_duration(duration_secs: f64) {
288    histogram!(TARGET_CONNECT_DURATION_SECONDS).record(duration_secs);
289}
290
291/// Record a successful rule engine update (hot-reload).
292#[inline]
293pub fn record_rule_update() {
294    counter!(RULE_UPDATES_TOTAL).increment(1);
295}
296
297/// Record a failed rule engine update attempt (hot-reload).
298#[inline]
299pub fn record_rule_update_error() {
300    counter!(RULE_UPDATE_ERRORS_TOTAL).increment(1);
301}
302
303/// Record a connection with source country label.
304#[inline]
305pub fn record_connection_with_geo(country: &str) {
306    counter!(CONNECTIONS_BY_COUNTRY, "country" => country.to_owned()).increment(1);
307}
308
309/// Record bytes transferred with source country label.
310/// Direction: "sent" or "received".
311#[inline]
312pub fn record_bytes_with_geo(country: &str, direction: &'static str, bytes: u64) {
313    counter!(BYTES_BY_COUNTRY, "country" => country.to_owned(), "direction" => direction)
314        .increment(bytes);
315}
316
317/// Record an authentication failure with source country label.
318#[inline]
319pub fn record_auth_failure_with_geo(country: &str) {
320    counter!(AUTH_FAILURE_BY_COUNTRY, "country" => country.to_owned()).increment(1);
321}
322
323// ============================================================================
324// Error Type Constants (re-exported from trojan-core)
325// ============================================================================
326
327pub use trojan_core::{
328    ERROR_AUTH, ERROR_CONFIG, ERROR_IO, ERROR_PROTOCOL, ERROR_RESOLVE, ERROR_TIMEOUT,
329    ERROR_TLS_HANDSHAKE,
330};