Skip to main content

systemprompt_api/services/middleware/
throttle.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::{IntoResponse, Response};
4use std::sync::Arc;
5
6use systemprompt_analytics::{SessionRepository, ThrottleLevel};
7use systemprompt_database::DbPool;
8use systemprompt_models::api::{ApiError, ErrorCode};
9use systemprompt_models::RequestContext;
10
11#[derive(Debug, Clone)]
12pub struct ThrottleMiddleware {
13    session_repo: Arc<SessionRepository>,
14}
15
16impl ThrottleMiddleware {
17    pub fn new(db_pool: DbPool) -> Self {
18        Self {
19            session_repo: Arc::new(SessionRepository::new(db_pool)),
20        }
21    }
22
23    pub async fn check_throttle(&self, request: Request, next: Next) -> Result<Response, ApiError> {
24        let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
25            return Ok(next.run(request).await);
26        };
27
28        if !req_ctx.request.is_tracked {
29            return Ok(next.run(request).await);
30        }
31
32        let throttle_level = self
33            .session_repo
34            .get_throttle_level(&req_ctx.request.session_id)
35            .await
36            .unwrap_or_else(|e| {
37                tracing::warn!(error = %e, session_id = %req_ctx.request.session_id, "Failed to get throttle level");
38                0
39            });
40
41        let level = ThrottleLevel::from(throttle_level);
42
43        if !level.allows_requests() {
44            let api_error = ApiError::new(
45                ErrorCode::RateLimited,
46                "Request blocked due to suspicious activity",
47            );
48            let mut response = api_error.into_response();
49            response
50                .headers_mut()
51                .insert("Retry-After", http::HeaderValue::from_static("3600"));
52            response.headers_mut().insert(
53                "X-Throttle-Level",
54                http::HeaderValue::from_static("blocked"),
55            );
56            response.headers_mut().insert(
57                "X-Throttle-Reason",
58                http::HeaderValue::from_static("behavioral_bot_detection"),
59            );
60            return Ok(response);
61        }
62
63        let mut response = next.run(request).await;
64
65        if throttle_level > 0 {
66            let level_str = match throttle_level {
67                1 => "warning",
68                2 => "severe",
69                _ => "unknown",
70            };
71
72            if let Ok(header_value) = level_str.parse() {
73                response
74                    .headers_mut()
75                    .insert("X-Throttle-Level", header_value);
76            }
77
78            let multiplier = level.rate_multiplier();
79            if let Ok(header_value) = format!("{multiplier}").parse() {
80                response
81                    .headers_mut()
82                    .insert("X-Rate-Multiplier", header_value);
83            }
84        }
85
86        Ok(response)
87    }
88}
89
90pub async fn check_throttle_level(
91    middleware: axum::extract::State<ThrottleMiddleware>,
92    request: Request,
93    next: Next,
94) -> Result<Response, ApiError> {
95    middleware.check_throttle(request, next).await
96}