Skip to main content

systemprompt_api/services/middleware/analytics/
mod.rs

1mod detection;
2mod events;
3
4use axum::extract::Request;
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::Response;
8use std::sync::Arc;
9
10use systemprompt_analytics::SessionRepository;
11use systemprompt_identifiers::SessionId;
12use systemprompt_logging::AnalyticsRepository;
13use systemprompt_models::{RequestContext, RouteClassifier};
14use systemprompt_runtime::AppContext;
15use systemprompt_security::ScannerDetector;
16
17pub use events::AnalyticsEventParams;
18
19struct TrackingParams<'a> {
20    req_ctx: &'a RequestContext,
21    uri: &'a http::Uri,
22    method: &'a http::Method,
23    status_code: u16,
24    response_time_ms: u64,
25    user_agent: Option<String>,
26    referer: Option<String>,
27    is_scanner: bool,
28}
29
30#[derive(Debug, Clone)]
31pub struct AnalyticsMiddleware {
32    session_repo: Arc<SessionRepository>,
33    analytics_repo: Arc<AnalyticsRepository>,
34    route_classifier: Arc<RouteClassifier>,
35}
36
37impl AnalyticsMiddleware {
38    pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
39        let db_pool = app_context.db_pool();
40        let session_repo = Arc::new(SessionRepository::new(db_pool)?);
41        let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
42        let route_classifier = Arc::clone(app_context.route_classifier());
43
44        Ok(Self {
45            session_repo,
46            analytics_repo,
47            route_classifier,
48        })
49    }
50
51    pub async fn track_request(
52        &self,
53        request: Request,
54        next: Next,
55    ) -> Result<Response, StatusCode> {
56        let method = request.method().clone();
57        let uri = request.uri().clone();
58
59        let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
60            return Ok(next.run(request).await);
61        };
62
63        if !req_ctx.request.is_tracked {
64            return Ok(next.run(request).await);
65        }
66
67        let user_agent = request
68            .headers()
69            .get("user-agent")
70            .and_then(|v| v.to_str().ok())
71            .map(ToString::to_string);
72
73        let referer = request
74            .headers()
75            .get("referer")
76            .and_then(|v| v.to_str().ok())
77            .map(ToString::to_string);
78
79        let start_time = std::time::Instant::now();
80        let response = next.run(request).await;
81        let response_time_ms = start_time.elapsed().as_millis() as u64;
82        let status_code = response.status();
83
84        let should_track = self
85            .route_classifier
86            .should_track_analytics(uri.path(), method.as_str());
87
88        let is_scanner =
89            ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
90
91        if should_track {
92            self.spawn_tracking_tasks(TrackingParams {
93                req_ctx: &req_ctx,
94                uri: &uri,
95                method: &method,
96                status_code: status_code.as_u16(),
97                response_time_ms,
98                user_agent,
99                referer,
100                is_scanner,
101            });
102        }
103
104        Ok(response)
105    }
106
107    fn spawn_tracking_tasks(&self, params: TrackingParams<'_>) {
108        let req_ctx = params.req_ctx;
109        let uri = params.uri;
110        let method = params.method;
111        let status_code = params.status_code;
112        let response_time_ms = params.response_time_ms;
113        let user_agent = params.user_agent;
114        let referer = params.referer;
115        let is_scanner = params.is_scanner;
116        let endpoint = format!("{} {}", method, uri.path());
117        let path = uri.path().to_string();
118
119        if is_scanner {
120            self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
121        }
122
123        self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
124
125        self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
126
127        detection::spawn_behavioral_detection_task(
128            Arc::clone(&self.session_repo),
129            req_ctx.request.session_id.clone(),
130            req_ctx.request.fingerprint_hash.clone(),
131            user_agent.clone(),
132            1,
133        );
134
135        events::spawn_analytics_event_task(
136            Arc::clone(&self.analytics_repo),
137            Arc::clone(&self.route_classifier),
138            AnalyticsEventParams {
139                req_ctx: req_ctx.clone(),
140                endpoint,
141                path,
142                method: method.to_string(),
143                uri: uri.clone(),
144                status_code,
145                response_time_ms,
146                user_agent,
147                referer,
148            },
149        );
150    }
151
152    fn spawn_session_tracking_task(&self, session_id: SessionId) {
153        let session_repo = Arc::clone(&self.session_repo);
154
155        tokio::spawn(async move {
156            if let Err(e) = session_repo.update_activity(&session_id).await {
157                tracing::error!(error = %e, "Failed to update session activity");
158            }
159
160            if let Err(e) = session_repo.increment_request_count(&session_id).await {
161                tracing::error!(error = %e, "Failed to increment request count");
162            }
163        });
164    }
165
166    fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
167        let session_repo = Arc::clone(&self.session_repo);
168
169        tokio::spawn(async move {
170            let (request_count, duration_seconds) = session_repo
171                .get_session_velocity(&session_id)
172                .await
173                .unwrap_or((None, None));
174
175            if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
176                if ScannerDetector::is_high_velocity(count, duration) {
177                    if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
178                        tracing::warn!(
179                            error = %e,
180                            session_id = %session_id,
181                            "Failed to mark high-velocity session as scanner"
182                        );
183                    }
184                }
185            }
186        });
187    }
188
189    fn spawn_mark_scanner_task(&self, session_id: SessionId) {
190        let session_repo = Arc::clone(&self.session_repo);
191
192        tokio::spawn(async move {
193            if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
194                tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
195            }
196        });
197    }
198}