1use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::mpsc;
8use tracing::debug;
9use trojan_config::{AnalyticsConfig, AnalyticsPrivacyConfig};
10
11use crate::event::{AuthResult, CloseReason, ConnectionEvent, Protocol, TargetType, Transport};
12
13#[derive(Debug, Clone)]
18pub struct EventCollector {
19 sender: mpsc::Sender<ConnectionEvent>,
20 config: Arc<AnalyticsConfig>,
21}
22
23impl EventCollector {
24 pub(crate) fn new(sender: mpsc::Sender<ConnectionEvent>, config: Arc<AnalyticsConfig>) -> Self {
26 Self { sender, config }
27 }
28
29 #[inline]
33 pub fn record(&self, event: ConnectionEvent) -> bool {
34 self.sender.try_send(event).is_ok()
35 }
36
37 pub fn connection(&self, conn_id: u64, peer: SocketAddr) -> ConnectionEventBuilder {
41 ConnectionEventBuilder::new(self.clone(), conn_id, peer, &self.config)
42 }
43
44 pub fn should_sample(&self, user_id: Option<&str>) -> bool {
48 let sampling = &self.config.sampling;
49
50 if let Some(uid) = user_id
52 && sampling.always_record_users.iter().any(|u| u == uid)
53 {
54 return true;
55 }
56
57 if sampling.rate >= 1.0 {
59 return true;
60 }
61 if sampling.rate <= 0.0 {
62 return false;
63 }
64
65 rand::random::<f64>() < sampling.rate
66 }
67
68 pub fn privacy(&self) -> &AnalyticsPrivacyConfig {
70 &self.config.privacy
71 }
72
73 pub fn server_id(&self) -> Option<&str> {
75 self.config.server_id.as_deref()
76 }
77}
78
79#[derive(Debug)]
84pub struct ConnectionEventBuilder {
85 collector: EventCollector,
86 event: ConnectionEvent,
87 start_time: Instant,
88 sent: bool,
89}
90
91impl ConnectionEventBuilder {
92 fn new(
94 collector: EventCollector,
95 conn_id: u64,
96 peer: SocketAddr,
97 config: &AnalyticsConfig,
98 ) -> Self {
99 let peer_ip = if config.privacy.record_peer_ip {
100 peer.ip()
101 } else {
102 match peer {
104 SocketAddr::V4(_) => IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
105 SocketAddr::V6(_) => IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
106 }
107 };
108
109 let mut event = ConnectionEvent::new(conn_id, peer_ip, peer.port());
110 event.server_id = config.server_id.clone().unwrap_or_default();
111
112 Self {
113 collector,
114 event,
115 start_time: Instant::now(),
116 sent: false,
117 }
118 }
119
120 pub fn user(mut self, user_id: impl Into<String>) -> Self {
122 let uid = user_id.into();
123 let privacy = self.collector.privacy();
124
125 self.event.user_id = if privacy.full_user_id {
126 uid
127 } else {
128 let len = privacy.user_id_prefix_len.min(uid.len());
130 uid[..len].to_string()
131 };
132 self.event.auth_result = AuthResult::Success;
133 self
134 }
135
136 pub fn auth_failed(mut self) -> Self {
138 self.event.auth_result = AuthResult::Failed;
139 self
140 }
141
142 pub fn target(mut self, host: impl Into<String>, port: u16, target_type: TargetType) -> Self {
144 self.event.target_host = host.into();
145 self.event.target_port = port;
146 self.event.target_type = target_type;
147 self
148 }
149
150 pub fn sni(mut self, sni: impl Into<String>) -> Self {
152 if self.collector.privacy().record_sni {
153 self.event.sni = sni.into();
154 }
155 self
156 }
157
158 pub fn protocol(mut self, protocol: Protocol) -> Self {
160 self.event.protocol = protocol;
161 self
162 }
163
164 pub fn transport(mut self, transport: Transport) -> Self {
166 self.event.transport = transport;
167 self
168 }
169
170 pub fn fallback(mut self) -> Self {
172 self.event.is_fallback = true;
173 self.event.auth_result = AuthResult::Skipped;
174 self
175 }
176
177 pub fn geo(mut self, result: trojan_config::GeoResult, precision: &str) -> Self {
184 match precision {
185 "city" => {
186 self.event.peer_country = result.country;
187 self.event.peer_region = result.region;
188 self.event.peer_city = result.city;
189 self.event.peer_asn = result.asn;
190 self.event.peer_org = result.org;
191 self.event.peer_longitude = result.longitude;
192 self.event.peer_latitude = result.latitude;
193 }
194 "country" => {
195 self.event.peer_country = result.country;
196 }
197 _ => {} }
199 self
200 }
201
202 #[inline]
204 pub fn add_bytes(&mut self, sent: u64, recv: u64) {
205 self.event.bytes_sent += sent;
206 self.event.bytes_recv += recv;
207 }
208
209 #[inline]
211 pub fn add_packets(&mut self, sent: u64, recv: u64) {
212 self.event.packets_sent += sent;
213 self.event.packets_recv += recv;
214 }
215
216 pub fn event_mut(&mut self) -> &mut ConnectionEvent {
218 &mut self.event
219 }
220
221 #[allow(clippy::cast_possible_truncation)]
223 pub fn finish(mut self, close_reason: CloseReason) {
224 self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
225 self.event.close_reason = close_reason;
226 self.send();
227 }
228
229 fn send(&mut self) {
231 if self.sent {
232 return;
233 }
234 self.sent = true;
235
236 if !self.collector.record(self.event.clone()) {
237 debug!(
238 conn_id = self.event.conn_id,
239 "analytics buffer full, event dropped"
240 );
241 }
242 }
243}
244
245impl Drop for ConnectionEventBuilder {
246 #[allow(clippy::cast_possible_truncation)]
247 fn drop(&mut self) {
248 if !self.sent {
249 self.event.duration_ms = self.start_time.elapsed().as_millis() as u64;
250 self.send();
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::sync::Arc;
    use trojan_config::{AnalyticsConfig, GeoResult};

    /// Collector with analytics enabled and default settings, plus the receiving
    /// end of its 64-slot channel.
    fn test_collector() -> (EventCollector, mpsc::Receiver<ConnectionEvent>) {
        let config = AnalyticsConfig {
            enabled: true,
            ..Default::default()
        };
        let (tx, rx) = mpsc::channel(64);
        (EventCollector::new(tx, Arc::new(config)), rx)
    }

    /// Collector with analytics enabled and the given sampling settings; the
    /// receiver is returned so the channel stays open for the test's duration.
    fn sampling_collector(
        sampling: trojan_config::AnalyticsSamplingConfig,
    ) -> (EventCollector, mpsc::Receiver<ConnectionEvent>) {
        let config = AnalyticsConfig {
            enabled: true,
            sampling,
            ..Default::default()
        };
        let (tx, rx) = mpsc::channel(1);
        (EventCollector::new(tx, Arc::new(config)), rx)
    }

    /// Fixed IPv4 peer used by every test.
    fn test_peer() -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 12345).into()
    }

    #[test]
    fn geo_builder_city_precision() {
        let (collector, _rx) = test_collector();
        let geo = GeoResult {
            country: "US".into(),
            region: "California".into(),
            city: "Los Angeles".into(),
            asn: 15169,
            org: "Google LLC".into(),
            longitude: -118.24,
            latitude: 34.05,
        };

        let builder = collector.connection(1, test_peer()).geo(geo, "city");
        let ev = &builder.event;
        assert_eq!(ev.peer_country, "US");
        assert_eq!(ev.peer_region, "California");
        assert_eq!(ev.peer_city, "Los Angeles");
        assert_eq!(ev.peer_asn, 15169);
        assert_eq!(ev.peer_org, "Google LLC");
        assert!((ev.peer_longitude + 118.24).abs() < 0.001);
        assert!((ev.peer_latitude - 34.05).abs() < 0.001);
    }

    #[test]
    fn geo_builder_country_precision() {
        let (collector, _rx) = test_collector();
        let geo = GeoResult {
            country: "CN".into(),
            region: "Shanghai".into(),
            city: "Shanghai".into(),
            asn: 4134,
            org: "China Telecom".into(),
            longitude: 121.47,
            latitude: 31.23,
        };

        let builder = collector.connection(2, test_peer()).geo(geo, "country");
        let ev = &builder.event;
        assert_eq!(ev.peer_country, "CN");
        assert!(ev.peer_region.is_empty());
        assert!(ev.peer_city.is_empty());
        assert_eq!(ev.peer_asn, 0);
    }

    #[test]
    fn geo_builder_none_precision() {
        let (collector, _rx) = test_collector();
        let geo = GeoResult {
            country: "JP".into(),
            region: "Tokyo".into(),
            city: "Tokyo".into(),
            asn: 2497,
            org: "IIJ".into(),
            longitude: 139.69,
            latitude: 35.69,
        };

        let builder = collector.connection(3, test_peer()).geo(geo, "none");
        let ev = &builder.event;
        assert!(ev.peer_country.is_empty());
        assert!(ev.peer_region.is_empty());
        assert_eq!(ev.peer_asn, 0);
    }

    #[tokio::test]
    async fn event_builder_sends_on_finish() {
        let (collector, mut rx) = test_collector();
        collector
            .connection(10, test_peer())
            .target("example.com".to_string(), 443, TargetType::Domain)
            .protocol(Protocol::Tcp)
            .finish(CloseReason::Normal);

        let event = rx.try_recv().expect("finish() should enqueue the event");
        assert_eq!(event.conn_id, 10);
        assert_eq!(event.target_host, "example.com");
        assert_eq!(event.target_port, 443);
        assert_eq!(event.protocol, Protocol::Tcp);
        assert_eq!(event.close_reason, CloseReason::Normal);
    }

    #[tokio::test]
    async fn event_builder_sends_on_drop() {
        let (collector, mut rx) = test_collector();
        drop(collector.connection(20, test_peer()));

        let event = rx.try_recv().expect("dropping the builder should enqueue the event");
        assert_eq!(event.conn_id, 20);
    }

    #[test]
    fn should_sample_always_record_user() {
        let (collector, _rx) = sampling_collector(trojan_config::AnalyticsSamplingConfig {
            rate: 0.0,
            always_record_users: vec!["vip-user".into()],
        });
        assert!(collector.should_sample(Some("vip-user")));
        assert!(!collector.should_sample(Some("normal-user")));
    }

    #[test]
    fn should_sample_rate_boundaries() {
        let (always, _rx_always) = sampling_collector(trojan_config::AnalyticsSamplingConfig {
            rate: 1.0,
            always_record_users: vec![],
        });
        assert!(always.should_sample(None));

        let (never, _rx_never) = sampling_collector(trojan_config::AnalyticsSamplingConfig {
            rate: 0.0,
            always_record_users: vec![],
        });
        assert!(!never.should_sample(None));
    }

    #[test]
    fn user_id_truncation() {
        let (collector, _rx) = test_collector();
        let builder = collector
            .connection(30, test_peer())
            .user("abcdef1234567890");
        assert_eq!(builder.event.user_id, "abcdef12");
    }

    #[test]
    fn add_bytes_and_packets() {
        let (collector, _rx) = test_collector();
        let mut builder = collector.connection(40, test_peer());
        for (sent, recv) in [(100, 200), (50, 25)] {
            builder.add_bytes(sent, recv);
        }
        builder.add_packets(3, 5);

        let ev = &builder.event;
        assert_eq!(ev.bytes_sent, 150);
        assert_eq!(ev.bytes_recv, 225);
        assert_eq!(ev.packets_sent, 3);
        assert_eq!(ev.packets_recv, 5);
    }
}