systemprompt_api/services/middleware/analytics/
mod.rs1mod detection;
2mod events;
3
4use axum::extract::Request;
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::Response;
8use std::sync::Arc;
9
10use systemprompt_analytics::SessionRepository;
11use systemprompt_identifiers::SessionId;
12use systemprompt_logging::AnalyticsRepository;
13use systemprompt_models::{RequestContext, RouteClassifier};
14use systemprompt_runtime::AppContext;
15use systemprompt_security::ScannerDetector;
16
17pub use events::AnalyticsEventParams;
18
19#[derive(Debug, Clone)]
20pub struct AnalyticsMiddleware {
21 session_repo: Arc<SessionRepository>,
22 analytics_repo: Arc<AnalyticsRepository>,
23 route_classifier: Arc<RouteClassifier>,
24}
25
26impl AnalyticsMiddleware {
27 pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
28 let db_pool = app_context.db_pool();
29 let session_repo = Arc::new(SessionRepository::new(db_pool)?);
30 let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
31 let route_classifier = app_context.route_classifier().clone();
32
33 Ok(Self {
34 session_repo,
35 analytics_repo,
36 route_classifier,
37 })
38 }
39
40 pub async fn track_request(
41 &self,
42 request: Request,
43 next: Next,
44 ) -> Result<Response, StatusCode> {
45 let method = request.method().clone();
46 let uri = request.uri().clone();
47
48 let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
49 return Ok(next.run(request).await);
50 };
51
52 if !req_ctx.request.is_tracked {
53 return Ok(next.run(request).await);
54 }
55
56 let user_agent = request
57 .headers()
58 .get("user-agent")
59 .and_then(|v| v.to_str().ok())
60 .map(ToString::to_string);
61
62 let referer = request
63 .headers()
64 .get("referer")
65 .and_then(|v| v.to_str().ok())
66 .map(ToString::to_string);
67
68 let start_time = std::time::Instant::now();
69 let response = next.run(request).await;
70 let response_time_ms = start_time.elapsed().as_millis() as u64;
71 let status_code = response.status();
72
73 let should_track = self
74 .route_classifier
75 .should_track_analytics(uri.path(), method.as_str());
76
77 let is_scanner =
78 ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
79
80 if should_track {
81 self.spawn_tracking_tasks(
82 &req_ctx,
83 &uri,
84 &method,
85 status_code.as_u16(),
86 response_time_ms,
87 user_agent,
88 referer,
89 is_scanner,
90 );
91 }
92
93 Ok(response)
94 }
95
96 fn spawn_tracking_tasks(
97 &self,
98 req_ctx: &RequestContext,
99 uri: &http::Uri,
100 method: &http::Method,
101 status_code: u16,
102 response_time_ms: u64,
103 user_agent: Option<String>,
104 referer: Option<String>,
105 is_scanner: bool,
106 ) {
107 let endpoint = format!("{} {}", method, uri.path());
108 let path = uri.path().to_string();
109
110 if is_scanner {
111 self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
112 }
113
114 self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
115
116 self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
117
118 detection::spawn_behavioral_detection_task(
119 self.session_repo.clone(),
120 req_ctx.request.session_id.clone(),
121 req_ctx.request.fingerprint_hash.clone(),
122 user_agent.clone(),
123 1,
124 );
125
126 events::spawn_analytics_event_task(
127 self.analytics_repo.clone(),
128 self.route_classifier.clone(),
129 AnalyticsEventParams {
130 req_ctx: req_ctx.clone(),
131 endpoint,
132 path,
133 method: method.to_string(),
134 uri: uri.clone(),
135 status_code,
136 response_time_ms,
137 user_agent,
138 referer,
139 },
140 );
141 }
142
143 fn spawn_session_tracking_task(&self, session_id: SessionId) {
144 let session_repo = self.session_repo.clone();
145
146 tokio::spawn(async move {
147 if let Err(e) = session_repo.update_activity(&session_id).await {
148 tracing::error!(error = %e, "Failed to update session activity");
149 }
150
151 if let Err(e) = session_repo.increment_request_count(&session_id).await {
152 tracing::error!(error = %e, "Failed to increment request count");
153 }
154 });
155 }
156
157 fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
158 let session_repo = self.session_repo.clone();
159
160 tokio::spawn(async move {
161 let (request_count, duration_seconds) = session_repo
162 .get_session_velocity(&session_id)
163 .await
164 .unwrap_or((None, None));
165
166 if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
167 if ScannerDetector::is_high_velocity(count, duration) {
168 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
169 tracing::warn!(
170 error = %e,
171 session_id = %session_id,
172 "Failed to mark high-velocity session as scanner"
173 );
174 }
175 }
176 }
177 });
178 }
179
180 fn spawn_mark_scanner_task(&self, session_id: SessionId) {
181 let session_repo = self.session_repo.clone();
182
183 tokio::spawn(async move {
184 if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
185 tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
186 }
187 });
188 }
189}