Skip to main content

systemprompt_api/services/middleware/
rate_limit.rs

1//! Router extension traits for rate limiting and authenticated route groups.
2//!
3//! `RouterExt::with_auth` attaches authentication and authorization in one
4//! call: it requires an `AuthzPolicy`, so a route group cannot be mounted
5//! authenticated-but-unauthorized — omitting the policy is a compile error.
6
7use crate::services::middleware::authz::{AuthzPolicy, authz_gate};
8use crate::services::middleware::client_addr::resolve_client_ip;
9use crate::services::middleware::context::{
10    A2AContextMiddleware, McpContextMiddleware, PublicContextMiddleware, UserOnlyContextMiddleware,
11};
12use axum::Router;
13use axum::extract::{ConnectInfo, Request};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use governor::clock::DefaultClock;
17use governor::state::keyed::DefaultKeyedStateStore;
18use governor::{Quota, RateLimiter};
19use ipnet::IpNet;
20use std::future::Future;
21use std::net::SocketAddr;
22use std::num::NonZeroU32;
23use std::sync::Arc;
24use systemprompt_models::RequestContext;
25use systemprompt_models::api::{ApiError, ErrorCode};
26use systemprompt_models::auth::RateLimitTier;
27use systemprompt_models::config::RateLimitConfig;
28use tower_governor::key_extractor::SmartIpKeyExtractor;
29use tracing::warn;
30
31/// Builds a [`systemprompt_models::RequestContext`] for a route group.
32///
33/// Implemented by each of the four sibling context middlewares
34/// ([`PublicContextMiddleware`], [`UserOnlyContextMiddleware`],
35/// [`A2AContextMiddleware`], [`McpContextMiddleware`]). Sealed to those four —
36/// third parties cannot stand up a new flavour outside this crate, which keeps
37/// the route-mount surface auditable.
38pub trait ContextLayer: Clone + Send + Sync + 'static {
39    fn handle(self, req: Request, next: Next) -> impl Future<Output = Response> + Send;
40}
41
42impl ContextLayer for PublicContextMiddleware {
43    async fn handle(self, req: Request, next: Next) -> Response {
44        Self::handle(&self, req, next).await
45    }
46}
47
48impl ContextLayer for UserOnlyContextMiddleware {
49    async fn handle(self, req: Request, next: Next) -> Response {
50        Self::handle(&self, req, next).await
51    }
52}
53
54impl ContextLayer for A2AContextMiddleware {
55    async fn handle(self, req: Request, next: Next) -> Response {
56        Self::handle(&self, req, next).await
57    }
58}
59
60impl ContextLayer for McpContextMiddleware {
61    async fn handle(self, req: Request, next: Next) -> Response {
62        Self::handle(&self, req, next).await
63    }
64}
65
66pub trait RouterExt<S> {
67    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self;
68
69    fn with_auth<L: ContextLayer>(self, auth: L, policy: AuthzPolicy) -> Self;
70}
71
72impl<S> RouterExt<S> for Router<S>
73where
74    S: Clone + Send + Sync + 'static,
75{
76    fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self {
77        if rate_config.disabled {
78            return self;
79        }
80
81        let rate_limit_result = tower_governor::governor::GovernorConfigBuilder::default()
82            .per_second(per_second)
83            .burst_size((per_second * rate_config.burst_multiplier) as u32)
84            .key_extractor(SmartIpKeyExtractor)
85            .use_headers()
86            .finish();
87
88        if let Some(rate_limit) = rate_limit_result {
89            self.layer(tower_governor::GovernorLayer::new(rate_limit))
90        } else {
91            warn!("Failed to configure rate limiting - rate limiting disabled for this route");
92            self
93        }
94    }
95
96    fn with_auth<L: ContextLayer>(self, auth: L, policy: AuthzPolicy) -> Self {
97        self.layer(axum::middleware::from_fn(move |req, next| async move {
98            authz_gate(policy, req, next).await
99        }))
100        .layer(axum::middleware::from_fn(move |req, next| {
101            let auth = auth.clone();
102            async move { auth.handle(req, next).await }
103        }))
104    }
105}
106
107type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
108
109#[derive(Clone, Debug)]
110pub struct TieredRateLimiter {
111    admin_limiter: Arc<KeyedRateLimiter>,
112    user_limiter: Arc<KeyedRateLimiter>,
113    a2a_limiter: Arc<KeyedRateLimiter>,
114    mcp_limiter: Arc<KeyedRateLimiter>,
115    service_limiter: Arc<KeyedRateLimiter>,
116    anon_limiter: Arc<KeyedRateLimiter>,
117    disabled: bool,
118    trusted_proxies: Arc<Vec<IpNet>>,
119}
120
121impl TieredRateLimiter {
122    pub fn new(config: &RateLimitConfig, base_per_second: u64) -> Self {
123        Self::with_trusted_proxies(config, base_per_second, Vec::new())
124    }
125
126    pub fn with_trusted_proxies(
127        config: &RateLimitConfig,
128        base_per_second: u64,
129        trusted_proxies: Vec<IpNet>,
130    ) -> Self {
131        let create_limiter = |tier: RateLimitTier| -> Arc<KeyedRateLimiter> {
132            let effective = config.effective_limit(base_per_second, tier);
133            let burst = effective.saturating_mul(config.burst_multiplier);
134            let effective_u32 = u32::try_from(effective).unwrap_or(u32::MAX).max(1);
135            let burst_u32 = u32::try_from(burst).unwrap_or(u32::MAX).max(1);
136            let quota =
137                Quota::per_second(NonZeroU32::new(effective_u32).unwrap_or(NonZeroU32::MIN))
138                    .allow_burst(NonZeroU32::new(burst_u32).unwrap_or(NonZeroU32::MIN));
139            Arc::new(RateLimiter::keyed(quota))
140        };
141
142        Self {
143            admin_limiter: create_limiter(RateLimitTier::Admin),
144            user_limiter: create_limiter(RateLimitTier::User),
145            a2a_limiter: create_limiter(RateLimitTier::A2a),
146            mcp_limiter: create_limiter(RateLimitTier::Mcp),
147            service_limiter: create_limiter(RateLimitTier::Service),
148            anon_limiter: create_limiter(RateLimitTier::Anon),
149            disabled: config.disabled,
150            trusted_proxies: Arc::new(trusted_proxies),
151        }
152    }
153
154    pub fn disabled() -> Self {
155        let quota = Quota::per_second(NonZeroU32::MAX);
156        let limiter = Arc::new(RateLimiter::keyed(quota));
157        Self {
158            admin_limiter: Arc::clone(&limiter),
159            user_limiter: Arc::clone(&limiter),
160            a2a_limiter: Arc::clone(&limiter),
161            mcp_limiter: Arc::clone(&limiter),
162            service_limiter: Arc::clone(&limiter),
163            anon_limiter: Arc::clone(&limiter),
164            disabled: true,
165            trusted_proxies: Arc::new(Vec::new()),
166        }
167    }
168
169    #[must_use]
170    pub fn trusted_proxies(&self) -> &[IpNet] {
171        &self.trusted_proxies
172    }
173
174    fn limiter_for_tier(&self, tier: RateLimitTier) -> &KeyedRateLimiter {
175        match tier {
176            RateLimitTier::Admin => &self.admin_limiter,
177            RateLimitTier::User => &self.user_limiter,
178            RateLimitTier::A2a => &self.a2a_limiter,
179            RateLimitTier::Mcp => &self.mcp_limiter,
180            RateLimitTier::Service => &self.service_limiter,
181            RateLimitTier::Anon => &self.anon_limiter,
182        }
183    }
184
185    pub fn check(&self, tier: RateLimitTier, key: &str) -> bool {
186        if self.disabled {
187            return true;
188        }
189        self.limiter_for_tier(tier)
190            .check_key(&key.to_owned())
191            .is_ok()
192    }
193}
194
195pub async fn tiered_rate_limit_middleware(
196    limiter: axum::extract::State<TieredRateLimiter>,
197    request: Request,
198    next: Next,
199) -> Response {
200    if limiter.disabled {
201        return next.run(request).await;
202    }
203
204    let (tier, key) = request.extensions().get::<RequestContext>().map_or_else(
205        || {
206            let connect_info = request.extensions().get::<ConnectInfo<SocketAddr>>();
207            let ip = resolve_client_ip(request.headers(), connect_info, limiter.trusted_proxies())
208                .map_or_else(|| "unknown".to_owned(), |a| a.to_string());
209            (RateLimitTier::Anon, ip)
210        },
211        |ctx| {
212            let tier = ctx.rate_limit_tier();
213            let key = ctx.user_id().to_string();
214            (tier, key)
215        },
216    );
217
218    if limiter.check(tier, &key) {
219        next.run(request).await
220    } else {
221        warn!(
222            tier = %tier.as_str(),
223            key = %key,
224            "Rate limit exceeded"
225        );
226        let api_error = ApiError::new(ErrorCode::RateLimited, "Rate limit exceeded");
227        let mut response = api_error.into_response();
228        response
229            .headers_mut()
230            .insert("Retry-After", http::HeaderValue::from_static("1"));
231        if let Ok(tier_value) = http::HeaderValue::from_str(tier.as_str()) {
232            response
233                .headers_mut()
234                .insert("X-Rate-Limit-Tier", tier_value);
235        }
236        response
237    }
238}