systemprompt_api/services/middleware/
rate_limit.rs1use crate::services::middleware::authz::{AuthzPolicy, authz_gate};
8use crate::services::middleware::client_addr::resolve_client_ip;
9use crate::services::middleware::context::{
10 A2AContextMiddleware, McpContextMiddleware, PublicContextMiddleware, UserOnlyContextMiddleware,
11};
12use axum::Router;
13use axum::extract::{ConnectInfo, Request};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use governor::clock::DefaultClock;
17use governor::state::keyed::DefaultKeyedStateStore;
18use governor::{Quota, RateLimiter};
19use ipnet::IpNet;
20use std::future::Future;
21use std::net::SocketAddr;
22use std::num::NonZeroU32;
23use std::sync::Arc;
24use systemprompt_models::RequestContext;
25use systemprompt_models::api::{ApiError, ErrorCode};
26use systemprompt_models::auth::RateLimitTier;
27use systemprompt_models::config::RateLimitConfig;
28use tower_governor::key_extractor::SmartIpKeyExtractor;
29use tracing::warn;
30
31pub trait ContextLayer: Clone + Send + Sync + 'static {
39 fn handle(self, req: Request, next: Next) -> impl Future<Output = Response> + Send;
40}
41
42impl ContextLayer for PublicContextMiddleware {
43 async fn handle(self, req: Request, next: Next) -> Response {
44 Self::handle(&self, req, next).await
45 }
46}
47
48impl ContextLayer for UserOnlyContextMiddleware {
49 async fn handle(self, req: Request, next: Next) -> Response {
50 Self::handle(&self, req, next).await
51 }
52}
53
54impl ContextLayer for A2AContextMiddleware {
55 async fn handle(self, req: Request, next: Next) -> Response {
56 Self::handle(&self, req, next).await
57 }
58}
59
60impl ContextLayer for McpContextMiddleware {
61 async fn handle(self, req: Request, next: Next) -> Response {
62 Self::handle(&self, req, next).await
63 }
64}
65
66pub trait RouterExt<S> {
67 fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self;
68
69 fn with_auth<L: ContextLayer>(self, auth: L, policy: AuthzPolicy) -> Self;
70}
71
72impl<S> RouterExt<S> for Router<S>
73where
74 S: Clone + Send + Sync + 'static,
75{
76 fn with_rate_limit(self, rate_config: &RateLimitConfig, per_second: u64) -> Self {
77 if rate_config.disabled {
78 return self;
79 }
80
81 let rate_limit_result = tower_governor::governor::GovernorConfigBuilder::default()
82 .per_second(per_second)
83 .burst_size((per_second * rate_config.burst_multiplier) as u32)
84 .key_extractor(SmartIpKeyExtractor)
85 .use_headers()
86 .finish();
87
88 if let Some(rate_limit) = rate_limit_result {
89 self.layer(tower_governor::GovernorLayer::new(rate_limit))
90 } else {
91 warn!("Failed to configure rate limiting - rate limiting disabled for this route");
92 self
93 }
94 }
95
96 fn with_auth<L: ContextLayer>(self, auth: L, policy: AuthzPolicy) -> Self {
97 self.layer(axum::middleware::from_fn(move |req, next| async move {
98 authz_gate(policy, req, next).await
99 }))
100 .layer(axum::middleware::from_fn(move |req, next| {
101 let auth = auth.clone();
102 async move { auth.handle(req, next).await }
103 }))
104 }
105}
106
107type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
108
109#[derive(Clone, Debug)]
110pub struct TieredRateLimiter {
111 admin_limiter: Arc<KeyedRateLimiter>,
112 user_limiter: Arc<KeyedRateLimiter>,
113 a2a_limiter: Arc<KeyedRateLimiter>,
114 mcp_limiter: Arc<KeyedRateLimiter>,
115 service_limiter: Arc<KeyedRateLimiter>,
116 anon_limiter: Arc<KeyedRateLimiter>,
117 disabled: bool,
118 trusted_proxies: Arc<Vec<IpNet>>,
119}
120
121impl TieredRateLimiter {
122 pub fn new(config: &RateLimitConfig, base_per_second: u64) -> Self {
123 Self::with_trusted_proxies(config, base_per_second, Vec::new())
124 }
125
126 pub fn with_trusted_proxies(
127 config: &RateLimitConfig,
128 base_per_second: u64,
129 trusted_proxies: Vec<IpNet>,
130 ) -> Self {
131 let create_limiter = |tier: RateLimitTier| -> Arc<KeyedRateLimiter> {
132 let effective = config.effective_limit(base_per_second, tier);
133 let burst = effective.saturating_mul(config.burst_multiplier);
134 let effective_u32 = u32::try_from(effective).unwrap_or(u32::MAX).max(1);
135 let burst_u32 = u32::try_from(burst).unwrap_or(u32::MAX).max(1);
136 let quota =
137 Quota::per_second(NonZeroU32::new(effective_u32).unwrap_or(NonZeroU32::MIN))
138 .allow_burst(NonZeroU32::new(burst_u32).unwrap_or(NonZeroU32::MIN));
139 Arc::new(RateLimiter::keyed(quota))
140 };
141
142 Self {
143 admin_limiter: create_limiter(RateLimitTier::Admin),
144 user_limiter: create_limiter(RateLimitTier::User),
145 a2a_limiter: create_limiter(RateLimitTier::A2a),
146 mcp_limiter: create_limiter(RateLimitTier::Mcp),
147 service_limiter: create_limiter(RateLimitTier::Service),
148 anon_limiter: create_limiter(RateLimitTier::Anon),
149 disabled: config.disabled,
150 trusted_proxies: Arc::new(trusted_proxies),
151 }
152 }
153
154 pub fn disabled() -> Self {
155 let quota = Quota::per_second(NonZeroU32::MAX);
156 let limiter = Arc::new(RateLimiter::keyed(quota));
157 Self {
158 admin_limiter: Arc::clone(&limiter),
159 user_limiter: Arc::clone(&limiter),
160 a2a_limiter: Arc::clone(&limiter),
161 mcp_limiter: Arc::clone(&limiter),
162 service_limiter: Arc::clone(&limiter),
163 anon_limiter: Arc::clone(&limiter),
164 disabled: true,
165 trusted_proxies: Arc::new(Vec::new()),
166 }
167 }
168
169 #[must_use]
170 pub fn trusted_proxies(&self) -> &[IpNet] {
171 &self.trusted_proxies
172 }
173
174 fn limiter_for_tier(&self, tier: RateLimitTier) -> &KeyedRateLimiter {
175 match tier {
176 RateLimitTier::Admin => &self.admin_limiter,
177 RateLimitTier::User => &self.user_limiter,
178 RateLimitTier::A2a => &self.a2a_limiter,
179 RateLimitTier::Mcp => &self.mcp_limiter,
180 RateLimitTier::Service => &self.service_limiter,
181 RateLimitTier::Anon => &self.anon_limiter,
182 }
183 }
184
185 pub fn check(&self, tier: RateLimitTier, key: &str) -> bool {
186 if self.disabled {
187 return true;
188 }
189 self.limiter_for_tier(tier)
190 .check_key(&key.to_owned())
191 .is_ok()
192 }
193}
194
195pub async fn tiered_rate_limit_middleware(
196 limiter: axum::extract::State<TieredRateLimiter>,
197 request: Request,
198 next: Next,
199) -> Response {
200 if limiter.disabled {
201 return next.run(request).await;
202 }
203
204 let (tier, key) = request.extensions().get::<RequestContext>().map_or_else(
205 || {
206 let connect_info = request.extensions().get::<ConnectInfo<SocketAddr>>();
207 let ip = resolve_client_ip(request.headers(), connect_info, limiter.trusted_proxies())
208 .map_or_else(|| "unknown".to_owned(), |a| a.to_string());
209 (RateLimitTier::Anon, ip)
210 },
211 |ctx| {
212 let tier = ctx.rate_limit_tier();
213 let key = ctx.user_id().to_string();
214 (tier, key)
215 },
216 );
217
218 if limiter.check(tier, &key) {
219 next.run(request).await
220 } else {
221 warn!(
222 tier = %tier.as_str(),
223 key = %key,
224 "Rate limit exceeded"
225 );
226 let api_error = ApiError::new(ErrorCode::RateLimited, "Rate limit exceeded");
227 let mut response = api_error.into_response();
228 response
229 .headers_mut()
230 .insert("Retry-After", http::HeaderValue::from_static("1"));
231 if let Ok(tier_value) = http::HeaderValue::from_str(tier.as_str()) {
232 response
233 .headers_mut()
234 .insert("X-Rate-Limit-Tier", tier_value);
235 }
236 response
237 }
238}