Skip to main content

systemprompt_api/services/middleware/analytics/
mod.rs

1//! Request analytics middleware.
2//!
3//! [`AnalyticsMiddleware`] records tracked requests after the response is
4//! produced, spawning detached tasks for session activity, velocity-based
5//! scanner detection, behavioural bot scoring, and analytics-event capture so
6//! the request path is never blocked on persistence.
7
8mod detection;
9mod events;
10
11use axum::extract::Request;
12use axum::http::StatusCode;
13use axum::middleware::Next;
14use axum::response::Response;
15use std::sync::Arc;
16
17use systemprompt_analytics::SessionRepository;
18use systemprompt_identifiers::SessionId;
19use systemprompt_logging::AnalyticsRepository;
20use systemprompt_models::{RequestContext, RouteClassifier};
21use systemprompt_runtime::AppContext;
22use systemprompt_security::ScannerDetector;
23
24pub use events::AnalyticsEventParams;
25
26struct TrackingParams<'a> {
27    req_ctx: &'a RequestContext,
28    uri: &'a http::Uri,
29    method: &'a http::Method,
30    status_code: u16,
31    response_time_ms: u64,
32    user_agent: Option<String>,
33    referer: Option<String>,
34    is_scanner: bool,
35}
36
37#[derive(Debug, Clone)]
38pub struct AnalyticsMiddleware {
39    session_repo: Arc<SessionRepository>,
40    analytics_repo: Arc<AnalyticsRepository>,
41    route_classifier: Arc<RouteClassifier>,
42}
43
44impl AnalyticsMiddleware {
45    pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
46        let db_pool = app_context.db_pool();
47        let session_repo = Arc::new(SessionRepository::new(db_pool)?);
48        let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
49        let route_classifier = Arc::clone(app_context.route_classifier());
50
51        Ok(Self {
52            session_repo,
53            analytics_repo,
54            route_classifier,
55        })
56    }
57
58    pub async fn track_request(
59        &self,
60        request: Request,
61        next: Next,
62    ) -> Result<Response, StatusCode> {
63        let method = request.method().clone();
64        let uri = request.uri().clone();
65
66        let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
67            return Ok(next.run(request).await);
68        };
69
70        if !req_ctx.request.is_tracked {
71            return Ok(next.run(request).await);
72        }
73
74        let user_agent = request
75            .headers()
76            .get("user-agent")
77            .and_then(|v| v.to_str().ok())
78            .map(str::to_owned);
79
80        let referer = request
81            .headers()
82            .get("referer")
83            .and_then(|v| v.to_str().ok())
84            .map(str::to_owned);
85
86        let start_time = std::time::Instant::now();
87        let response = next.run(request).await;
88        let response_time_ms = start_time.elapsed().as_millis() as u64;
89        let status_code = response.status();
90
91        let should_track = self
92            .route_classifier
93            .should_track_analytics(uri.path(), method.as_str());
94
95        let is_scanner =
96            ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
97
98        if should_track {
99            self.spawn_tracking_tasks(TrackingParams {
100                req_ctx: &req_ctx,
101                uri: &uri,
102                method: &method,
103                status_code: status_code.as_u16(),
104                response_time_ms,
105                user_agent,
106                referer,
107                is_scanner,
108            });
109        }
110
111        Ok(response)
112    }
113
114    fn spawn_tracking_tasks(&self, params: TrackingParams<'_>) {
115        let req_ctx = params.req_ctx;
116        let uri = params.uri;
117        let method = params.method;
118        let status_code = params.status_code;
119        let response_time_ms = params.response_time_ms;
120        let user_agent = params.user_agent;
121        let referer = params.referer;
122        let is_scanner = params.is_scanner;
123        let endpoint = format!("{} {}", method, uri.path());
124        let path = uri.path().to_owned();
125
126        if is_scanner {
127            self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
128        }
129
130        self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
131
132        self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
133
134        detection::spawn_behavioral_detection_task(
135            Arc::clone(&self.session_repo),
136            req_ctx.request.session_id.clone(),
137            req_ctx.request.fingerprint_hash.clone(),
138            user_agent.clone(),
139            1,
140        );
141
142        events::spawn_analytics_event_task(
143            Arc::clone(&self.analytics_repo),
144            Arc::clone(&self.route_classifier),
145            AnalyticsEventParams {
146                req_ctx: req_ctx.clone(),
147                endpoint,
148                path,
149                method: method.to_string(),
150                uri: uri.clone(),
151                status_code,
152                response_time_ms,
153                user_agent,
154                referer,
155            },
156        );
157    }
158
159    fn spawn_session_tracking_task(&self, session_id: SessionId) {
160        let session_repo = Arc::clone(&self.session_repo);
161
162        tokio::spawn(async move {
163            if let Err(e) = session_repo.update_activity(&session_id).await {
164                tracing::error!(error = %e, "Failed to update session activity");
165            }
166
167            if let Err(e) = session_repo.increment_request_count(&session_id).await {
168                tracing::error!(error = %e, "Failed to increment request count");
169            }
170        });
171    }
172
173    fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
174        let session_repo = Arc::clone(&self.session_repo);
175
176        tokio::spawn(async move {
177            let (request_count, duration_seconds) = session_repo
178                .get_session_velocity(&session_id)
179                .await
180                .unwrap_or((None, None));
181
182            if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
183                if ScannerDetector::is_high_velocity(count, duration) {
184                    if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
185                        tracing::warn!(
186                            error = %e,
187                            session_id = %session_id,
188                            "Failed to mark high-velocity session as scanner"
189                        );
190                    }
191                }
192            }
193        });
194    }
195
196    fn spawn_mark_scanner_task(&self, session_id: SessionId) {
197        let session_repo = Arc::clone(&self.session_repo);
198
199        tokio::spawn(async move {
200            if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
201                tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
202            }
203        });
204    }
205}