Skip to main content

systemprompt_api/services/middleware/analytics/
mod.rs

1mod detection;
2mod events;
3
4use axum::extract::Request;
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::Response;
8use std::sync::Arc;
9
10use systemprompt_analytics::SessionRepository;
11use systemprompt_identifiers::SessionId;
12use systemprompt_logging::AnalyticsRepository;
13use systemprompt_models::{RequestContext, RouteClassifier};
14use systemprompt_runtime::AppContext;
15use systemprompt_security::ScannerDetector;
16
17pub use events::AnalyticsEventParams;
18
19#[derive(Debug, Clone)]
20pub struct AnalyticsMiddleware {
21    session_repo: Arc<SessionRepository>,
22    analytics_repo: Arc<AnalyticsRepository>,
23    route_classifier: Arc<RouteClassifier>,
24}
25
26impl AnalyticsMiddleware {
27    pub fn new(app_context: &AppContext) -> anyhow::Result<Self> {
28        let db_pool = app_context.db_pool();
29        let session_repo = Arc::new(SessionRepository::new(db_pool)?);
30        let analytics_repo = Arc::new(AnalyticsRepository::new(db_pool)?);
31        let route_classifier = app_context.route_classifier().clone();
32
33        Ok(Self {
34            session_repo,
35            analytics_repo,
36            route_classifier,
37        })
38    }
39
40    pub async fn track_request(
41        &self,
42        request: Request,
43        next: Next,
44    ) -> Result<Response, StatusCode> {
45        let method = request.method().clone();
46        let uri = request.uri().clone();
47
48        let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
49            return Ok(next.run(request).await);
50        };
51
52        if !req_ctx.request.is_tracked {
53            return Ok(next.run(request).await);
54        }
55
56        let user_agent = request
57            .headers()
58            .get("user-agent")
59            .and_then(|v| v.to_str().ok())
60            .map(ToString::to_string);
61
62        let referer = request
63            .headers()
64            .get("referer")
65            .and_then(|v| v.to_str().ok())
66            .map(ToString::to_string);
67
68        let start_time = std::time::Instant::now();
69        let response = next.run(request).await;
70        let response_time_ms = start_time.elapsed().as_millis() as u64;
71        let status_code = response.status();
72
73        let should_track = self
74            .route_classifier
75            .should_track_analytics(uri.path(), method.as_str());
76
77        let is_scanner =
78            ScannerDetector::is_scanner(Some(uri.path()), user_agent.as_deref(), None, None);
79
80        if should_track {
81            self.spawn_tracking_tasks(
82                &req_ctx,
83                &uri,
84                &method,
85                status_code.as_u16(),
86                response_time_ms,
87                user_agent,
88                referer,
89                is_scanner,
90            );
91        }
92
93        Ok(response)
94    }
95
96    fn spawn_tracking_tasks(
97        &self,
98        req_ctx: &RequestContext,
99        uri: &http::Uri,
100        method: &http::Method,
101        status_code: u16,
102        response_time_ms: u64,
103        user_agent: Option<String>,
104        referer: Option<String>,
105        is_scanner: bool,
106    ) {
107        let endpoint = format!("{} {}", method, uri.path());
108        let path = uri.path().to_string();
109
110        if is_scanner {
111            self.spawn_mark_scanner_task(req_ctx.request.session_id.clone());
112        }
113
114        self.spawn_velocity_scanner_check(req_ctx.request.session_id.clone());
115
116        self.spawn_session_tracking_task(req_ctx.request.session_id.clone());
117
118        detection::spawn_behavioral_detection_task(
119            self.session_repo.clone(),
120            req_ctx.request.session_id.clone(),
121            req_ctx.request.fingerprint_hash.clone(),
122            user_agent.clone(),
123            1,
124        );
125
126        events::spawn_analytics_event_task(
127            self.analytics_repo.clone(),
128            self.route_classifier.clone(),
129            AnalyticsEventParams {
130                req_ctx: req_ctx.clone(),
131                endpoint,
132                path,
133                method: method.to_string(),
134                uri: uri.clone(),
135                status_code,
136                response_time_ms,
137                user_agent,
138                referer,
139            },
140        );
141    }
142
143    fn spawn_session_tracking_task(&self, session_id: SessionId) {
144        let session_repo = self.session_repo.clone();
145
146        tokio::spawn(async move {
147            if let Err(e) = session_repo.update_activity(&session_id).await {
148                tracing::error!(error = %e, "Failed to update session activity");
149            }
150
151            if let Err(e) = session_repo.increment_request_count(&session_id).await {
152                tracing::error!(error = %e, "Failed to increment request count");
153            }
154        });
155    }
156
157    fn spawn_velocity_scanner_check(&self, session_id: SessionId) {
158        let session_repo = self.session_repo.clone();
159
160        tokio::spawn(async move {
161            let (request_count, duration_seconds) = session_repo
162                .get_session_velocity(&session_id)
163                .await
164                .unwrap_or((None, None));
165
166            if let (Some(count), Some(duration)) = (request_count, duration_seconds) {
167                if ScannerDetector::is_high_velocity(count, duration) {
168                    if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
169                        tracing::warn!(
170                            error = %e,
171                            session_id = %session_id,
172                            "Failed to mark high-velocity session as scanner"
173                        );
174                    }
175                }
176            }
177        });
178    }
179
180    fn spawn_mark_scanner_task(&self, session_id: SessionId) {
181        let session_repo = self.session_repo.clone();
182
183        tokio::spawn(async move {
184            if let Err(e) = session_repo.mark_as_scanner(&session_id).await {
185                tracing::warn!(error = %e, session_id = %session_id, "Failed to mark session as scanner");
186            }
187        });
188    }
189}