things3_cli/mcp/middleware/
rate_limit.rs1use super::{McpMiddleware, MiddlewareContext, MiddlewareResult};
4use crate::mcp::{CallToolRequest, CallToolResult, McpResult};
5use governor::clock::DefaultClock;
6use governor::{state::keyed::DefaultKeyedStateStore, Quota, RateLimiter};
7use nonzero_ext::nonzero;
8use serde_json::Value;
9use std::sync::Arc;
10
11pub struct RateLimitMiddleware {
12 rate_limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
13 default_limit: u32,
14 #[allow(dead_code)]
15 burst_limit: u32,
16}
17
18impl RateLimitMiddleware {
19 #[must_use]
21 pub fn new(requests_per_minute: u32, burst_limit: u32) -> Self {
22 let quota = Quota::per_minute(nonzero!(60u32)); let rate_limiter = Arc::new(RateLimiter::keyed(quota));
24
25 Self {
26 rate_limiter,
27 default_limit: requests_per_minute,
28 burst_limit,
29 }
30 }
31
32 #[must_use]
34 pub fn with_limits(requests_per_minute: u32, burst_limit: u32) -> Self {
35 Self::new(requests_per_minute, burst_limit)
36 }
37
38 #[allow(clippy::should_implement_trait)]
40 #[must_use]
41 pub fn default() -> Self {
42 Self::new(60, 10)
43 }
44
45 fn extract_client_id(request: &CallToolRequest, context: &MiddlewareContext) -> String {
47 if let Some(auth_key_id) = context.get_metadata("auth_key_id").and_then(|v| v.as_str()) {
49 return format!("api_key:{auth_key_id}");
50 }
51
52 if let Some(auth_user_id) = context
53 .get_metadata("auth_user_id")
54 .and_then(|v| v.as_str())
55 {
56 return format!("jwt:{auth_user_id}");
57 }
58
59 if let Some(args) = &request.arguments {
61 if let Some(client_id) = args.get("client_id").and_then(|v| v.as_str()) {
62 return format!("client:{client_id}");
63 }
64 }
65
66 format!("request:{}", context.request_id)
68 }
69
70 fn check_rate_limit(&self, client_id: &str) -> bool {
72 self.rate_limiter.check_key(&client_id.to_string()).is_ok()
73 }
74
75 fn get_remaining_requests(&self, _client_id: &str) -> u32 {
77 self.default_limit
80 }
81}
82
83#[async_trait::async_trait]
84impl McpMiddleware for RateLimitMiddleware {
85 fn name(&self) -> &'static str {
86 "rate_limiting"
87 }
88
89 fn priority(&self) -> i32 {
90 20 }
92
93 async fn before_request(
94 &self,
95 request: &CallToolRequest,
96 context: &mut MiddlewareContext,
97 ) -> McpResult<MiddlewareResult> {
98 let client_id = Self::extract_client_id(request, context);
99
100 if !self.check_rate_limit(&client_id) {
101 let error_result = CallToolResult {
102 content: vec![crate::mcp::Content::Text {
103 text: format!(
104 "Rate limit exceeded. Limit: {} requests per minute. Please try again later.",
105 self.default_limit
106 ),
107 }],
108 is_error: true,
109 };
110
111 context.set_metadata("rate_limited".to_string(), Value::Bool(true));
112 context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
113
114 return Ok(MiddlewareResult::Stop(error_result));
115 }
116
117 let remaining = self.get_remaining_requests(&client_id);
118 context.set_metadata(
119 "rate_limit_remaining".to_string(),
120 Value::Number(serde_json::Number::from(remaining)),
121 );
122 context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
123
124 Ok(MiddlewareResult::Continue)
125 }
126}