systemprompt_api/services/middleware/analytics/
mod.rs1mod detection;
2mod events;
3
4use axum::extract::Request;
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::Response;
8use std::sync::Arc;
9
10use systemprompt_analytics::SessionRepository;
11use systemprompt_identifiers::SessionId;
12use systemprompt_logging::AnalyticsRepository;
13use systemprompt_models::{RequestContext, RouteClassifier};
14use systemprompt_runtime::AppContext;
15use systemprompt_security::ScannerDetector;
16
17pub use events::AnalyticsEventParams;
18
19struct TrackingParams<'a> {
20 req_ctx: &'a RequestContext,
21 uri: &'a http::Uri,
22 method: &'a http::Method,
23 status_code: u16,
24 response_time_ms: u64,
25 user_agent: Option<String>,
26 referer: Option<String>,
27 is_scanner: bool,
28}
29
30#[derive(Debug, Clone)]
31pub struct AnalyticsMiddleware {
32 session_repo: Arc<SessionRepository>,
33 analytics_repo: Arc<AnalyticsRepository>,
34 route_classifier: Arc<RouteClassifier>,
35}
36
37impl AnalyticsMiddleware {
38 pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
39 let db_pool = app_context.db_pool();
40 let session_repo = Arc::new(SessionRepository::new(db_pool)?);
41 let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
42 let route_classifier = Arc::clone(app_context.route_classifier());
43
44 Ok(Self {
45 session_repo,
46 analytics_repo,
47 route_classifier,
48 })
49 }
50
51 pub async fn track_request(
52 &self,
53 request: Request,
54 next: Next,
55 ) -> Result<Response, StatusCode> {
56 let method = request.method().clone();
57 let uri = request.uri().clone();
58
59 let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
60 return Ok(next.run(request).await);
61 };
62
63 if !req_ctx.request.is_tracked {
64 return Ok(next.run(request).await);
65 }
66
67 let user_agent = request
68 .headers()
69 .get("user-agent")
70 .and_then(|v| v.to_str().ok())
71 .map(ToString::to_string);
72
73 let referer = request
74 .headers()
75 .get("referer")
76 .and_then(|v| v.to_str().ok())
77 .map(ToString::to_string);
78
79 let start_time = std::time::Instant::now();
80 let response = next.run(request).await;
81 let response_time_ms = start_time.elapsed().as_millis() as u64;
82 let status_code = response.status();
83
84 let should_track = self
85 .route_classifier
86 .should_track_analytics(uri.path(), method.as_str());
87
88 let is_scanner =
89 ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
90
91 if should_track {
92 self.spawn_tracking_tasks(TrackingParams {
93 req_ctx: &req_ctx,
94 uri: &uri,
95 method: &method,
96 status_code: status_code.as_u16(),
97 response_time_ms,
98 user_agent,
99 referer,
100 is_scanner,
101 });
102 }
103
104 Ok(response)
105 }
106
107 fn spawn_tracking_tasks(&self, params: TrackingParams<'_>) {
108 let req_ctx = params.req_ctx;
109 let uri = params.uri;
110 let method = params.method;
111 let status_code = params.status_code;
112 let response_time_ms = params.response_time_ms;
113 let user_agent = params.user_agent;
114 let referer = params.referer;
115 let is_scanner = params.is_scanner;
116 let endpoint = format!("{} {}", method, uri.path());
117 let path = uri.path().to_string();
118
119 if is_scanner {
120 self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
121 }
122
123 self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
124
125 self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
126
127 detection::spawn_behavioral_detection_task(
128 Arc::clone(&self.session_repo),
129 req_ctx.request.session_id.clone(),
130 req_ctx.request.fingerprint_hash.clone(),
131 user_agent.clone(),
132 1,
133 );
134
135 events::spawn_analytics_event_task(
136 Arc::clone(&self.analytics_repo),
137 Arc::clone(&self.route_classifier),
138 AnalyticsEventParams {
139 req_ctx: req_ctx.clone(),
140 endpoint,
141 path,
142 method: method.to_string(),
143 uri: uri.clone(),
144 status_code,
145 response_time_ms,
146 user_agent,
147 referer,
148 },
149 );
150 }
151
152 fn spawn_session_tracking_task(&self, session_id: SessionId) {
153 let session_repo = Arc::clone(&self.session_repo);
154
155 tokio::spawn(async move {
156 if let Err(e) = session_repo.update_activity(&session_id).await {
157 tracing::error!(error = %e, "Failed to update session activity");
158 }
159
160 if let Err(e) = session_repo.increment_request_count(&session_id).await {
161 tracing::error!(error = %e, "Failed to increment request count");
162 }
163 });
164 }
165
166 fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
167 let session_repo = Arc::clone(&self.session_repo);
168
169 tokio::spawn(async move {
170 let (request_count, duration_seconds) = session_repo
171 .get_session_velocity(&session_id)
172 .await
173 .unwrap_or((None, None));
174
175 if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
176 if ScannerDetector::is_high_velocity(count, duration) {
177 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
178 tracing::warn!(
179 error = %e,
180 session_id = %session_id,
181 "Failed to mark high-velocity session as scanner"
182 );
183 }
184 }
185 }
186 });
187 }
188
189 fn spawn_mark_scanner_task(&self, session_id: SessionId) {
190 let session_repo = Arc::clone(&self.session_repo);
191
192 tokio::spawn(async move {
193 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
194 tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
195 }
196 });
197 }
198}