Skip to main content

trojan_analytics/
collector.rs

1//! Event collector for non-blocking event recording.
2
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5use std::time::Instant;
6
7use rand::Rng;
8use tokio::sync::mpsc;
9use tracing::debug;
10use trojan_config::{AnalyticsConfig, AnalyticsPrivacyConfig};
11
12use crate::event::{AuthResult, CloseReason, ConnectionEvent, Protocol, TargetType, Transport};
13
14/// Event collector for recording connection events.
15///
16/// This struct is cheap to clone and can be shared across threads.
17/// Events are sent through a bounded channel to a background writer.
18#[derive(Clone)]
19pub struct EventCollector {
20    sender: mpsc::Sender<ConnectionEvent>,
21    config: Arc<AnalyticsConfig>,
22}
23
24impl EventCollector {
25    /// Create a new event collector.
26    pub(crate) fn new(sender: mpsc::Sender<ConnectionEvent>, config: Arc<AnalyticsConfig>) -> Self {
27        Self { sender, config }
28    }
29
30    /// Record a connection event (non-blocking).
31    ///
32    /// Returns `true` if the event was queued, `false` if the buffer is full.
33    #[inline]
34    pub fn record(&self, event: ConnectionEvent) -> bool {
35        self.sender.try_send(event).is_ok()
36    }
37
38    /// Create a connection event builder for the given connection.
39    ///
40    /// The builder will automatically send the event when dropped.
41    pub fn connection(&self, conn_id: u64, peer: SocketAddr) -> ConnectionEventBuilder {
42        ConnectionEventBuilder::new(self.clone(), conn_id, peer, &self.config)
43    }
44
45    /// Check if an event should be recorded based on sampling configuration.
46    ///
47    /// Returns `true` if the event should be recorded.
48    pub fn should_sample(&self, user_id: Option<&str>) -> bool {
49        let sampling = &self.config.sampling;
50
51        // Always record specified users
52        if let Some(uid) = user_id
53            && sampling.always_record_users.iter().any(|u| u == uid)
54        {
55            return true;
56        }
57
58        // Sample based on rate
59        if sampling.rate >= 1.0 {
60            return true;
61        }
62        if sampling.rate <= 0.0 {
63            return false;
64        }
65
66        rand::thread_rng().r#gen::<f64>() < sampling.rate
67    }
68
69    /// Get the privacy configuration.
70    pub fn privacy(&self) -> &AnalyticsPrivacyConfig {
71        &self.config.privacy
72    }
73
74    /// Get the server ID.
75    pub fn server_id(&self) -> Option<&str> {
76        self.config.server_id.as_deref()
77    }
78}
79
80/// Builder for constructing connection events.
81///
82/// Events are automatically sent when the builder is dropped,
83/// or can be explicitly sent with `finish()`.
84pub struct ConnectionEventBuilder {
85    collector: EventCollector,
86    event: ConnectionEvent,
87    start_time: Instant,
88    sent: bool,
89}
90
91impl ConnectionEventBuilder {
92    /// Create a new connection event builder.
93    fn new(
94        collector: EventCollector,
95        conn_id: u64,
96        peer: SocketAddr,
97        config: &AnalyticsConfig,
98    ) -> Self {
99        let peer_ip = if config.privacy.record_peer_ip {
100            peer.ip()
101        } else {
102            // Use unspecified address if not recording
103            match peer {
104                SocketAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
105                SocketAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
106            }
107        };
108
109        let mut event = ConnectionEvent::new(conn_id, peer_ip, peer.port());
110        event.server_id = config.server_id.clone().unwrap_or_default();
111
112        Self {
113            collector,
114            event,
115            start_time: Instant::now(),
116            sent: false,
117        }
118    }
119
120    /// Set the user ID.
121    pub fn user(mut self, user_id: impl Into<String>) -> Self {
122        let uid = user_id.into();
123        let privacy = self.collector.privacy();
124
125        self.event.user_id = if privacy.full_user_id {
126            uid
127        } else {
128            // Truncate to prefix length
129            let len = privacy.user_id_prefix_len.min(uid.len());
130            uid[..len].to_string()
131        };
132        self.event.auth_result = AuthResult::Success;
133        self
134    }
135
136    /// Set authentication as failed.
137    pub fn auth_failed(mut self) -> Self {
138        self.event.auth_result = AuthResult::Failed;
139        self
140    }
141
142    /// Set the target information.
143    pub fn target(mut self, host: impl Into<String>, port: u16, target_type: TargetType) -> Self {
144        self.event.target_host = host.into();
145        self.event.target_port = port;
146        self.event.target_type = target_type;
147        self
148    }
149
150    /// Set the SNI (Server Name Indication).
151    pub fn sni(mut self, sni: impl Into<String>) -> Self {
152        if self.collector.privacy().record_sni {
153            self.event.sni = sni.into();
154        }
155        self
156    }
157
158    /// Set the protocol type.
159    pub fn protocol(mut self, protocol: Protocol) -> Self {
160        self.event.protocol = protocol;
161        self
162    }
163
164    /// Set the transport type.
165    pub fn transport(mut self, transport: Transport) -> Self {
166        self.event.transport = transport;
167        self
168    }
169
170    /// Mark as fallback connection.
171    pub fn fallback(mut self) -> Self {
172        self.event.is_fallback = true;
173        self.event.auth_result = AuthResult::Skipped;
174        self
175    }
176
177    /// Add bytes to the traffic counters.
178    #[inline]
179    pub fn add_bytes(&mut self, sent: u64, recv: u64) {
180        self.event.bytes_sent += sent;
181        self.event.bytes_recv += recv;
182    }
183
184    /// Add packets to the packet counters (for UDP).
185    #[inline]
186    pub fn add_packets(&mut self, sent: u64, recv: u64) {
187        self.event.packets_sent += sent;
188        self.event.packets_recv += recv;
189    }
190
191    /// Get a mutable reference to the event for direct modification.
192    pub fn event_mut(&mut self) -> &mut ConnectionEvent {
193        &mut self.event
194    }
195
196    /// Finish and send the event with the given close reason.
197    pub fn finish(mut self, close_reason: CloseReason) {
198        self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
199        self.event.close_reason = close_reason;
200        self.send();
201    }
202
203    /// Send the event.
204    fn send(&mut self) {
205        if self.sent {
206            return;
207        }
208        self.sent = true;
209
210        if !self.collector.record(self.event.clone()) {
211            debug!(
212                conn_id = self.event.conn_id,
213                "analytics buffer full, event dropped"
214            );
215        }
216    }
217}
218
219impl Drop for ConnectionEventBuilder {
220    fn drop(&mut self) {
221        if !self.sent {
222            self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
223            self.send();
224        }
225    }
226}