Skip to main content

systemprompt_api/services/middleware/
rate_limit.rs

//! Router extension traits for rate limiting and authenticated route groups,
//! plus a per-tier request rate limiter and the middleware that enforces it.
//!
//! `RouterExt::with_auth` attaches authentication and authorization in one
//! call: it requires an `AuthzPolicy`, so a route group cannot be mounted
//! authenticated-but-unauthorized — omitting the policy is a compile error.

7use crate::services::middleware::authz::{AuthzPolicy, authz_gate};
8use crate::services::middleware::client_addr::resolve_client_ip;
9use crate::services::middleware::context::{ContextExtractor, ContextMiddleware};
10use crate::services::middleware::jti_revocation::{JtiRevocationState, jti_revocation_middleware};
11use axum::Router;
12use axum::extract::{ConnectInfo, Request};
13use axum::middleware::Next;
14use axum::response::{IntoResponse, Response};
15use governor::clock::DefaultClock;
16use governor::state::keyed::DefaultKeyedStateStore;
17use governor::{Quota, RateLimiter};
18use ipnet::IpNet;
19use std::net::SocketAddr;
20use std::num::NonZeroU32;
21use std::sync::Arc;
22use systemprompt_models::RequestContext;
23use systemprompt_models::api::{ApiError, ErrorCode};
24use systemprompt_models::auth::RateLimitTier;
25use systemprompt_models::config::RateLimitConfig;
26use tower_governor::key_extractor::SmartIpKeyExtractor;
27use tracing::warn;
28
29pub trait RouterExt<S> {
30    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self;
31
32    fn with_auth<E>(self, auth: ContextMiddleware<E>, policy: AuthzPolicy) -> Self
33    where
34        E: ContextExtractor + Clone + Send + Sync + 'static;
35
36    fn with_jti_check(self, jti_state: JtiRevocationState) -> Self;
37}
38
39impl<S> RouterExt<S> for Router<S>
40where
41    S: Clone + Send + Sync + 'static,
42{
43    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self {
44        if rate_config.disabled {
45            return self;
46        }
47
48        let rate_limit_result = tower_governor::governor::GovernorConfigBuilder::default()
49            .per_second(per_second)
50            .burst_size((per_second * rate_config.burst_multiplier) as u32)
51            .key_extractor(SmartIpKeyExtractor)
52            .use_headers()
53            .finish();
54
55        if let Some(rate_limit) = rate_limit_result {
56            self.layer(tower_governor::GovernorLayer::new(rate_limit))
57        } else {
58            warn!("Failed to configure rate limiting - rate limiting disabled for this route");
59            self
60        }
61    }
62
63    fn with_auth<E>(self, auth: ContextMiddleware<E>, policy: AuthzPolicy) -> Self
64    where
65        E: ContextExtractor + Clone + Send + Sync + 'static,
66    {
67        self.layer(axum::middleware::from_fn(move |req, next| async move {
68            authz_gate(policy, req, next).await
69        }))
70        .layer(axum::middleware::from_fn(move |req, next| {
71            let auth = auth.clone();
72            async move { auth.handle(req, next).await }
73        }))
74    }
75
76    fn with_jti_check(self, jti_state: JtiRevocationState) -> Self {
77        self.layer(axum::middleware::from_fn_with_state(
78            jti_state,
79            jti_revocation_middleware,
80        ))
81    }
82}
83
84type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
85
86#[derive(Clone, Debug)]
87pub struct TieredRateLimiter {
88    admin_limiter: Arc<KeyedRateLimiter>,
89    user_limiter: Arc<KeyedRateLimiter>,
90    a2a_limiter: Arc<KeyedRateLimiter>,
91    mcp_limiter: Arc<KeyedRateLimiter>,
92    service_limiter: Arc<KeyedRateLimiter>,
93    anon_limiter: Arc<KeyedRateLimiter>,
94    disabled: bool,
95    trusted_proxies: Arc<Vec<IpNet>>,
96}
97
98impl TieredRateLimiter {
99    pub fn new(config: &RateLimitConfig, base_per_second: u64) -> Self {
100        Self::with_trusted_proxies(config, base_per_second, Vec::new())
101    }
102
103    pub fn with_trusted_proxies(
104        config: &RateLimitConfig,
105        base_per_second: u64,
106        trusted_proxies: Vec<IpNet>,
107    ) -> Self {
108        let create_limiter = |tier: RateLimitTier| -> Arc<KeyedRateLimiter> {
109            let effective = config.effective_limit(base_per_second, tier);
110            let burst = effective.saturating_mul(config.burst_multiplier);
111            let effective_u32 = u32::try_from(effective).unwrap_or(u32::MAX).max(1);
112            let burst_u32 = u32::try_from(burst).unwrap_or(u32::MAX).max(1);
113            let quota =
114                Quota::per_second(NonZeroU32::new(effective_u32).unwrap_or(NonZeroU32::MIN))
115                    .allow_burst(NonZeroU32::new(burst_u32).unwrap_or(NonZeroU32::MIN));
116            Arc::new(RateLimiter::keyed(quota))
117        };
118
119        Self {
120            admin_limiter: create_limiter(RateLimitTier::Admin),
121            user_limiter: create_limiter(RateLimitTier::User),
122            a2a_limiter: create_limiter(RateLimitTier::A2a),
123            mcp_limiter: create_limiter(RateLimitTier::Mcp),
124            service_limiter: create_limiter(RateLimitTier::Service),
125            anon_limiter: create_limiter(RateLimitTier::Anon),
126            disabled: config.disabled,
127            trusted_proxies: Arc::new(trusted_proxies),
128        }
129    }
130
131    pub fn disabled() -> Self {
132        let quota = Quota::per_second(NonZeroU32::MAX);
133        let limiter = Arc::new(RateLimiter::keyed(quota));
134        Self {
135            admin_limiter: Arc::clone(&limiter),
136            user_limiter: Arc::clone(&limiter),
137            a2a_limiter: Arc::clone(&limiter),
138            mcp_limiter: Arc::clone(&limiter),
139            service_limiter: Arc::clone(&limiter),
140            anon_limiter: Arc::clone(&limiter),
141            disabled: true,
142            trusted_proxies: Arc::new(Vec::new()),
143        }
144    }
145
146    #[must_use]
147    pub fn trusted_proxies(&self) -> &[IpNet] {
148        &self.trusted_proxies
149    }
150
151    fn limiter_for_tier(&self, tier: RateLimitTier) -> &KeyedRateLimiter {
152        match tier {
153            RateLimitTier::Admin => &self.admin_limiter,
154            RateLimitTier::User => &self.user_limiter,
155            RateLimitTier::A2a => &self.a2a_limiter,
156            RateLimitTier::Mcp => &self.mcp_limiter,
157            RateLimitTier::Service => &self.service_limiter,
158            RateLimitTier::Anon => &self.anon_limiter,
159        }
160    }
161
162    pub fn check(&self, tier: RateLimitTier, key: &str) -> bool {
163        if self.disabled {
164            return true;
165        }
166        self.limiter_for_tier(tier)
167            .check_key(&key.to_owned())
168            .is_ok()
169    }
170}
171
172pub async fn tiered_rate_limit_middleware(
173    limiter: axum::extract::State<TieredRateLimiter>,
174    request: Request,
175    next: Next,
176) -> Response {
177    if limiter.disabled {
178        return next.run(request).await;
179    }
180
181    let (tier, key) = request.extensions().get::<RequestContext>().map_or_else(
182        || {
183            let connect_info = request.extensions().get::<ConnectInfo<SocketAddr>>();
184            let ip = resolve_client_ip(request.headers(), connect_info, limiter.trusted_proxies())
185                .map_or_else(|| "unknown".to_owned(), |a| a.to_string());
186            (RateLimitTier::Anon, ip)
187        },
188        |ctx| {
189            let tier = ctx.rate_limit_tier();
190            let key = ctx.user_id().to_string();
191            (tier, key)
192        },
193    );
194
195    if limiter.check(tier, &key) {
196        next.run(request).await
197    } else {
198        warn!(
199            tier = %tier.as_str(),
200            key = %key,
201            "Rate limit exceeded"
202        );
203        let api_error = ApiError::new(ErrorCode::RateLimited, "Rate limit exceeded");
204        let mut response = api_error.into_response();
205        response
206            .headers_mut()
207            .insert("Retry-After", http::HeaderValue::from_static("1"));
208        if let Ok(tier_value) = http::HeaderValue::from_str(tier.as_str()) {
209            response
210                .headers_mut()
211                .insert("X-Rate-Limit-Tier", tier_value);
212        }
213        response
214    }
215}