systemprompt_api/services/middleware/
throttle.rs1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::{IntoResponse, Response};
4use std::sync::Arc;
5
6use systemprompt_analytics::{SessionRepository, ThrottleLevel};
7use systemprompt_database::DbPool;
8use systemprompt_models::api::{ApiError, ErrorCode};
9use systemprompt_models::RequestContext;
10
11#[derive(Debug, Clone)]
12pub struct ThrottleMiddleware {
13 session_repo: Arc<SessionRepository>,
14}
15
16impl ThrottleMiddleware {
17 pub fn new(db_pool: DbPool) -> Self {
18 Self {
19 session_repo: Arc::new(SessionRepository::new(db_pool)),
20 }
21 }
22
23 pub async fn check_throttle(&self, request: Request, next: Next) -> Result<Response, ApiError> {
24 let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
25 return Ok(next.run(request).await);
26 };
27
28 if !req_ctx.request.is_tracked {
29 return Ok(next.run(request).await);
30 }
31
32 let throttle_level = self
33 .session_repo
34 .get_throttle_level(&req_ctx.request.session_id)
35 .await
36 .unwrap_or_else(|e| {
37 tracing::warn!(error = %e, session_id = %req_ctx.request.session_id, "Failed to get throttle level");
38 0
39 });
40
41 let level = ThrottleLevel::from(throttle_level);
42
43 if !level.allows_requests() {
44 let api_error = ApiError::new(
45 ErrorCode::RateLimited,
46 "Request blocked due to suspicious activity",
47 );
48 let mut response = api_error.into_response();
49 response
50 .headers_mut()
51 .insert("Retry-After", http::HeaderValue::from_static("3600"));
52 response.headers_mut().insert(
53 "X-Throttle-Level",
54 http::HeaderValue::from_static("blocked"),
55 );
56 response.headers_mut().insert(
57 "X-Throttle-Reason",
58 http::HeaderValue::from_static("behavioral_bot_detection"),
59 );
60 return Ok(response);
61 }
62
63 let mut response = next.run(request).await;
64
65 if throttle_level > 0 {
66 let level_str = match throttle_level {
67 1 => "warning",
68 2 => "severe",
69 _ => "unknown",
70 };
71
72 if let Ok(header_value) = level_str.parse() {
73 response
74 .headers_mut()
75 .insert("X-Throttle-Level", header_value);
76 }
77
78 let multiplier = level.rate_multiplier();
79 if let Ok(header_value) = format!("{multiplier}").parse() {
80 response
81 .headers_mut()
82 .insert("X-Rate-Multiplier", header_value);
83 }
84 }
85
86 Ok(response)
87 }
88}
89
90pub async fn check_throttle_level(
91 middleware: axum::extract::State<ThrottleMiddleware>,
92 request: Request,
93 next: Next,
94) -> Result<Response, ApiError> {
95 middleware.check_throttle(request, next).await
96}