systemprompt_api/services/middleware/
rate_limit.rs1use crate::services::middleware::authz::{AuthzPolicy, authz_gate};
8use crate::services::middleware::client_addr::resolve_client_ip;
9use crate::services::middleware::context::{ContextExtractor, ContextMiddleware};
10use crate::services::middleware::jti_revocation::{JtiRevocationState, jti_revocation_middleware};
11use axum::Router;
12use axum::extract::{ConnectInfo, Request};
13use axum::middleware::Next;
14use axum::response::{IntoResponse, Response};
15use governor::clock::DefaultClock;
16use governor::state::keyed::DefaultKeyedStateStore;
17use governor::{Quota, RateLimiter};
18use ipnet::IpNet;
19use std::net::SocketAddr;
20use std::num::NonZeroU32;
21use std::sync::Arc;
22use systemprompt_models::RequestContext;
23use systemprompt_models::api::{ApiError, ErrorCode};
24use systemprompt_models::auth::RateLimitTier;
25use systemprompt_models::config::RateLimitConfig;
26use tower_governor::key_extractor::SmartIpKeyExtractor;
27use tracing::warn;
28
29pub trait RouterExt<S> {
30 fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self;
31
32 fn with_auth<E>(self, auth: ContextMiddleware<E>, policy: AuthzPolicy) -> Self
33 where
34 E: ContextExtractor + Clone + Send + Sync + 'static;
35
36 fn with_jti_check(self, jti_state: JtiRevocationState) -> Self;
37}
38
39impl<S> RouterExt<S> for Router<S>
40where
41 S: Clone + Send + Sync + 'static,
42{
43 fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self {
44 if rate_config.disabled {
45 return self;
46 }
47
48 let rate_limit_result = tower_governor::governor::GovernorConfigBuilder::default()
49 .per_second(per_second)
50 .burst_size((per_second * rate_config.burst_multiplier) as u32)
51 .key_extractor(SmartIpKeyExtractor)
52 .use_headers()
53 .finish();
54
55 if let Some(rate_limit) = rate_limit_result {
56 self.layer(tower_governor::GovernorLayer::new(rate_limit))
57 } else {
58 warn!("Failed to configure rate limiting - rate limiting disabled for this route");
59 self
60 }
61 }
62
63 fn with_auth<E>(self, auth: ContextMiddleware<E>, policy: AuthzPolicy) -> Self
64 where
65 E: ContextExtractor + Clone + Send + Sync + 'static,
66 {
67 self.layer(axum::middleware::from_fn(move |req, next| async move {
68 authz_gate(policy, req, next).await
69 }))
70 .layer(axum::middleware::from_fn(move |req, next| {
71 let auth = auth.clone();
72 async move { auth.handle(req, next).await }
73 }))
74 }
75
76 fn with_jti_check(self, jti_state: JtiRevocationState) -> Self {
77 self.layer(axum::middleware::from_fn_with_state(
78 jti_state,
79 jti_revocation_middleware,
80 ))
81 }
82}
83
84type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
85
86#[derive(Clone, Debug)]
87pub struct TieredRateLimiter {
88 admin_limiter: Arc<KeyedRateLimiter>,
89 user_limiter: Arc<KeyedRateLimiter>,
90 a2a_limiter: Arc<KeyedRateLimiter>,
91 mcp_limiter: Arc<KeyedRateLimiter>,
92 service_limiter: Arc<KeyedRateLimiter>,
93 anon_limiter: Arc<KeyedRateLimiter>,
94 disabled: bool,
95 trusted_proxies: Arc<Vec<IpNet>>,
96}
97
98impl TieredRateLimiter {
99 pub fn new(config: &RateLimitConfig, base_per_second: u64) -> Self {
100 Self::with_trusted_proxies(config, base_per_second, Vec::new())
101 }
102
103 pub fn with_trusted_proxies(
104 config: &RateLimitConfig,
105 base_per_second: u64,
106 trusted_proxies: Vec<IpNet>,
107 ) -> Self {
108 let create_limiter = |tier: RateLimitTier| -> Arc<KeyedRateLimiter> {
109 let effective = config.effective_limit(base_per_second, tier);
110 let burst = effective.saturating_mul(config.burst_multiplier);
111 let effective_u32 = u32::try_from(effective).unwrap_or(u32::MAX).max(1);
112 let burst_u32 = u32::try_from(burst).unwrap_or(u32::MAX).max(1);
113 let quota =
114 Quota::per_second(NonZeroU32::new(effective_u32).unwrap_or(NonZeroU32::MIN))
115 .allow_burst(NonZeroU32::new(burst_u32).unwrap_or(NonZeroU32::MIN));
116 Arc::new(RateLimiter::keyed(quota))
117 };
118
119 Self {
120 admin_limiter: create_limiter(RateLimitTier::Admin),
121 user_limiter: create_limiter(RateLimitTier::User),
122 a2a_limiter: create_limiter(RateLimitTier::A2a),
123 mcp_limiter: create_limiter(RateLimitTier::Mcp),
124 service_limiter: create_limiter(RateLimitTier::Service),
125 anon_limiter: create_limiter(RateLimitTier::Anon),
126 disabled: config.disabled,
127 trusted_proxies: Arc::new(trusted_proxies),
128 }
129 }
130
131 pub fn disabled() -> Self {
132 let quota = Quota::per_second(NonZeroU32::MAX);
133 let limiter = Arc::new(RateLimiter::keyed(quota));
134 Self {
135 admin_limiter: Arc::clone(&limiter),
136 user_limiter: Arc::clone(&limiter),
137 a2a_limiter: Arc::clone(&limiter),
138 mcp_limiter: Arc::clone(&limiter),
139 service_limiter: Arc::clone(&limiter),
140 anon_limiter: Arc::clone(&limiter),
141 disabled: true,
142 trusted_proxies: Arc::new(Vec::new()),
143 }
144 }
145
146 #[must_use]
147 pub fn trusted_proxies(&self) -> &[IpNet] {
148 &self.trusted_proxies
149 }
150
151 fn limiter_for_tier(&self, tier: RateLimitTier) -> &KeyedRateLimiter {
152 match tier {
153 RateLimitTier::Admin => &self.admin_limiter,
154 RateLimitTier::User => &self.user_limiter,
155 RateLimitTier::A2a => &self.a2a_limiter,
156 RateLimitTier::Mcp => &self.mcp_limiter,
157 RateLimitTier::Service => &self.service_limiter,
158 RateLimitTier::Anon => &self.anon_limiter,
159 }
160 }
161
162 pub fn check(&self, tier: RateLimitTier, key: &str) -> bool {
163 if self.disabled {
164 return true;
165 }
166 self.limiter_for_tier(tier)
167 .check_key(&key.to_owned())
168 .is_ok()
169 }
170}
171
172pub async fn tiered_rate_limit_middleware(
173 limiter: axum::extract::State<TieredRateLimiter>,
174 request: Request,
175 next: Next,
176) -> Response {
177 if limiter.disabled {
178 return next.run(request).await;
179 }
180
181 let (tier, key) = request.extensions().get::<RequestContext>().map_or_else(
182 || {
183 let connect_info = request.extensions().get::<ConnectInfo<SocketAddr>>();
184 let ip = resolve_client_ip(request.headers(), connect_info, limiter.trusted_proxies())
185 .map_or_else(|| "unknown".to_owned(), |a| a.to_string());
186 (RateLimitTier::Anon, ip)
187 },
188 |ctx| {
189 let tier = ctx.rate_limit_tier();
190 let key = ctx.user_id().to_string();
191 (tier, key)
192 },
193 );
194
195 if limiter.check(tier, &key) {
196 next.run(request).await
197 } else {
198 warn!(
199 tier = %tier.as_str(),
200 key = %key,
201 "Rate limit exceeded"
202 );
203 let api_error = ApiError::new(ErrorCode::RateLimited, "Rate limit exceeded");
204 let mut response = api_error.into_response();
205 response
206 .headers_mut()
207 .insert("Retry-After", http::HeaderValue::from_static("1"));
208 if let Ok(tier_value) = http::HeaderValue::from_str(tier.as_str()) {
209 response
210 .headers_mut()
211 .insert("X-Rate-Limit-Tier", tier_value);
212 }
213 response
214 }
215}