systemprompt_api/services/middleware/analytics/
mod.rs1mod detection;
9mod events;
10
11use axum::extract::Request;
12use axum::http::StatusCode;
13use axum::middleware::Next;
14use axum::response::Response;
15use std::sync::Arc;
16
17use systemprompt_analytics::SessionRepository;
18use systemprompt_identifiers::SessionId;
19use systemprompt_logging::AnalyticsRepository;
20use systemprompt_models::{RequestContext, RouteClassifier};
21use systemprompt_runtime::AppContext;
22use systemprompt_security::ScannerDetector;
23
24pub use events::AnalyticsEventParams;
25
26struct TrackingParams<'a> {
27 req_ctx: &'a RequestContext,
28 uri: &'a http::Uri,
29 method: &'a http::Method,
30 status_code: u16,
31 response_time_ms: u64,
32 user_agent: Option<String>,
33 referer: Option<String>,
34 is_scanner: bool,
35}
36
37#[derive(Debug, Clone)]
38pub struct AnalyticsMiddleware {
39 session_repo: Arc<SessionRepository>,
40 analytics_repo: Arc<AnalyticsRepository>,
41 route_classifier: Arc<RouteClassifier>,
42}
43
44impl AnalyticsMiddleware {
45 pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
46 let db_pool = app_context.db_pool();
47 let session_repo = Arc::new(SessionRepository::new(db_pool)?);
48 let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
49 let route_classifier = Arc::clone(app_context.route_classifier());
50
51 Ok(Self {
52 session_repo,
53 analytics_repo,
54 route_classifier,
55 })
56 }
57
58 pub async fn track_request(
59 &self,
60 request: Request,
61 next: Next,
62 ) -> Result<Response, StatusCode> {
63 let method = request.method().clone();
64 let uri = request.uri().clone();
65
66 let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
67 return Ok(next.run(request).await);
68 };
69
70 if !req_ctx.request.is_tracked {
71 return Ok(next.run(request).await);
72 }
73
74 let user_agent = request
75 .headers()
76 .get("user-agent")
77 .and_then(|v| v.to_str().ok())
78 .map(str::to_owned);
79
80 let referer = request
81 .headers()
82 .get("referer")
83 .and_then(|v| v.to_str().ok())
84 .map(str::to_owned);
85
86 let start_time = std::time::Instant::now();
87 let response = next.run(request).await;
88 let response_time_ms = start_time.elapsed().as_millis() as u64;
89 let status_code = response.status();
90
91 let should_track = self
92 .route_classifier
93 .should_track_analytics(uri.path(), method.as_str());
94
95 let is_scanner =
96 ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
97
98 if should_track {
99 self.spawn_tracking_tasks(TrackingParams {
100 req_ctx: &req_ctx,
101 uri: &uri,
102 method: &method,
103 status_code: status_code.as_u16(),
104 response_time_ms,
105 user_agent,
106 referer,
107 is_scanner,
108 });
109 }
110
111 Ok(response)
112 }
113
114 fn spawn_tracking_tasks(&self, params: TrackingParams<'_>) {
115 let req_ctx = params.req_ctx;
116 let uri = params.uri;
117 let method = params.method;
118 let status_code = params.status_code;
119 let response_time_ms = params.response_time_ms;
120 let user_agent = params.user_agent;
121 let referer = params.referer;
122 let is_scanner = params.is_scanner;
123 let endpoint = format!("{} {}", method, uri.path());
124 let path = uri.path().to_owned();
125
126 if is_scanner {
127 self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
128 }
129
130 self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
131
132 self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
133
134 detection::spawn_behavioral_detection_task(
135 Arc::clone(&self.session_repo),
136 req_ctx.request.session_id.clone(),
137 req_ctx.request.fingerprint_hash.clone(),
138 user_agent.clone(),
139 1,
140 );
141
142 events::spawn_analytics_event_task(
143 Arc::clone(&self.analytics_repo),
144 Arc::clone(&self.route_classifier),
145 AnalyticsEventParams {
146 req_ctx: req_ctx.clone(),
147 endpoint,
148 path,
149 method: method.to_string(),
150 uri: uri.clone(),
151 status_code,
152 response_time_ms,
153 user_agent,
154 referer,
155 },
156 );
157 }
158
159 fn spawn_session_tracking_task(&self, session_id: SessionId) {
160 let session_repo = Arc::clone(&self.session_repo);
161
162 tokio::spawn(async move {
163 if let Err(e) = session_repo.update_activity(&session_id).await {
164 tracing::error!(error = %e, "Failed to update session activity");
165 }
166
167 if let Err(e) = session_repo.increment_request_count(&session_id).await {
168 tracing::error!(error = %e, "Failed to increment request count");
169 }
170 });
171 }
172
173 fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
174 let session_repo = Arc::clone(&self.session_repo);
175
176 tokio::spawn(async move {
177 let (request_count, duration_seconds) = session_repo
178 .get_session_velocity(&session_id)
179 .await
180 .unwrap_or((None, None));
181
182 if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
183 if ScannerDetector::is_high_velocity(count, duration) {
184 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
185 tracing::warn!(
186 error = %e,
187 session_id = %session_id,
188 "Failed to mark high-velocity session as scanner"
189 );
190 }
191 }
192 }
193 });
194 }
195
196 fn spawn_mark_scanner_task(&self, session_id: SessionId) {
197 let session_repo = Arc::clone(&self.session_repo);
198
199 tokio::spawn(async move {
200 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
201 tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
202 }
203 });
204 }
205}