Skip to main content

systemprompt_api/services/middleware/rate_limit.rs

1use crate::services::middleware::context::{ContextExtractor, ContextMiddleware};
2use axum::extract::Request;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use axum::Router;
6use governor::clock::DefaultClock;
7use governor::state::keyed::DefaultKeyedStateStore;
8use governor::{Quota, RateLimiter};
9use std::num::NonZeroU32;
10use std::sync::Arc;
11use systemprompt_models::api::{ApiError, ErrorCode};
12use systemprompt_models::auth::RateLimitTier;
13use systemprompt_models::config::RateLimitConfig;
14use systemprompt_models::RequestContext;
15use tower_governor::key_extractor::SmartIpKeyExtractor;
16use tracing::warn;
17
18pub trait RouterExt<S> {
19    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self;
20    fn with_auth_middleware<E>(self, middleware: ContextMiddleware<E>) -> Self
21    where
22        E: ContextExtractor + Clone + Send + Sync + 'static;
23}
24
25impl<S> RouterExt<S> for Router<S>
26where
27    S: Clone + Send + Sync + 'static,
28{
29    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self {
30        if rate_config.disabled {
31            return self;
32        }
33
34        let rate_limit_result = tower_governor::governor::GovernorConfigBuilder::default()
35            .per_second(per_second)
36            .burst_size((per_second * rate_config.burst_multiplier) as u32)
37            .key_extractor(SmartIpKeyExtractor)
38            .use_headers()
39            .finish();
40
41        if let Some(rate_limit) = rate_limit_result {
42            self.layer(tower_governor::GovernorLayer::new(rate_limit))
43        } else {
44            warn!("Failed to configure rate limiting - rate limiting disabled for this route");
45            self
46        }
47    }
48
49    fn with_auth_middleware<E>(self, middleware: ContextMiddleware<E>) -> Self
50    where
51        E: ContextExtractor + Clone + Send + Sync + 'static,
52    {
53        self.layer(axum::middleware::from_fn(move |req, next| {
54            let middleware = middleware.clone();
55            async move { middleware.handle(req, next).await }
56        }))
57    }
58}
59
60type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
61
62#[derive(Clone, Debug)]
63pub struct TieredRateLimiter {
64    admin_limiter: Arc<KeyedRateLimiter>,
65    user_limiter: Arc<KeyedRateLimiter>,
66    a2a_limiter: Arc<KeyedRateLimiter>,
67    mcp_limiter: Arc<KeyedRateLimiter>,
68    service_limiter: Arc<KeyedRateLimiter>,
69    anon_limiter: Arc<KeyedRateLimiter>,
70    disabled: bool,
71}
72
73impl TieredRateLimiter {
74    pub fn new(config: &RateLimitConfig, base_per_second: u64) -> Self {
75        let create_limiter = |tier: RateLimitTier| -> Arc<KeyedRateLimiter> {
76            let effective = config.effective_limit(base_per_second, tier);
77            let burst = effective.saturating_mul(config.burst_multiplier);
78            let effective_u32 = u32::try_from(effective).unwrap_or(u32::MAX).max(1);
79            let burst_u32 = u32::try_from(burst).unwrap_or(u32::MAX).max(1);
80            let quota =
81                Quota::per_second(NonZeroU32::new(effective_u32).unwrap_or(NonZeroU32::MIN))
82                    .allow_burst(NonZeroU32::new(burst_u32).unwrap_or(NonZeroU32::MIN));
83            Arc::new(RateLimiter::keyed(quota))
84        };
85
86        Self {
87            admin_limiter: create_limiter(RateLimitTier::Admin),
88            user_limiter: create_limiter(RateLimitTier::User),
89            a2a_limiter: create_limiter(RateLimitTier::A2a),
90            mcp_limiter: create_limiter(RateLimitTier::Mcp),
91            service_limiter: create_limiter(RateLimitTier::Service),
92            anon_limiter: create_limiter(RateLimitTier::Anon),
93            disabled: config.disabled,
94        }
95    }
96
97    pub fn disabled() -> Self {
98        let quota = Quota::per_second(NonZeroU32::MAX);
99        let limiter = Arc::new(RateLimiter::keyed(quota));
100        Self {
101            admin_limiter: Arc::clone(&limiter),
102            user_limiter: Arc::clone(&limiter),
103            a2a_limiter: Arc::clone(&limiter),
104            mcp_limiter: Arc::clone(&limiter),
105            service_limiter: Arc::clone(&limiter),
106            anon_limiter: Arc::clone(&limiter),
107            disabled: true,
108        }
109    }
110
111    fn limiter_for_tier(&self, tier: RateLimitTier) -> &KeyedRateLimiter {
112        match tier {
113            RateLimitTier::Admin => &self.admin_limiter,
114            RateLimitTier::User => &self.user_limiter,
115            RateLimitTier::A2a => &self.a2a_limiter,
116            RateLimitTier::Mcp => &self.mcp_limiter,
117            RateLimitTier::Service => &self.service_limiter,
118            RateLimitTier::Anon => &self.anon_limiter,
119        }
120    }
121
122    pub fn check(&self, tier: RateLimitTier, key: &str) -> bool {
123        if self.disabled {
124            return true;
125        }
126        self.limiter_for_tier(tier)
127            .check_key(&key.to_string())
128            .is_ok()
129    }
130}
131
132pub async fn tiered_rate_limit_middleware(
133    limiter: axum::extract::State<TieredRateLimiter>,
134    request: Request,
135    next: Next,
136) -> Response {
137    if limiter.disabled {
138        return next.run(request).await;
139    }
140
141    let (tier, key) = request
142        .extensions()
143        .get::<RequestContext>()
144        .map(|ctx| {
145            let tier = ctx.rate_limit_tier();
146            let key = ctx.user_id().to_string();
147            (tier, key)
148        })
149        .unwrap_or_else(|| {
150            let ip = request
151                .headers()
152                .get("x-forwarded-for")
153                .and_then(|h| {
154                    h.to_str()
155                        .map_err(|e| {
156                            tracing::trace!(error = %e, "Invalid UTF-8 in x-forwarded-for header");
157                            e
158                        })
159                        .ok()
160                })
161                .and_then(|s| s.split(',').next())
162                .map_or_else(|| "unknown".to_string(), ToString::to_string);
163            (RateLimitTier::Anon, ip)
164        });
165
166    if limiter.check(tier, &key) {
167        next.run(request).await
168    } else {
169        warn!(
170            tier = %tier.as_str(),
171            key = %key,
172            "Rate limit exceeded"
173        );
174        let api_error = ApiError::new(ErrorCode::RateLimited, "Rate limit exceeded");
175        let mut response = api_error.into_response();
176        response
177            .headers_mut()
178            .insert("Retry-After", http::HeaderValue::from_static("1"));
179        if let Ok(tier_value) = http::HeaderValue::from_str(tier.as_str()) {
180            response
181                .headers_mut()
182                .insert("X-Rate-Limit-Tier", tier_value);
183        }
184        response
185    }
186}