Skip to main content

things3_cli/mcp/middleware/
rate_limit.rs

1//! Rate limiting middleware
2
3use super::{McpMiddleware, MiddlewareContext, MiddlewareResult};
4use crate::mcp::{CallToolRequest, CallToolResult, McpResult};
5use governor::clock::DefaultClock;
6use governor::{state::keyed::DefaultKeyedStateStore, Quota, RateLimiter};
7use nonzero_ext::nonzero;
8use serde_json::Value;
9use std::sync::Arc;
10
11pub struct RateLimitMiddleware {
12    rate_limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
13    default_limit: u32,
14    #[allow(dead_code)]
15    burst_limit: u32,
16}
17
18impl RateLimitMiddleware {
19    /// Create a new rate limiting middleware
20    #[must_use]
21    pub fn new(requests_per_minute: u32, burst_limit: u32) -> Self {
22        let quota = Quota::per_minute(nonzero!(60u32)); // Use a constant for now
23        let rate_limiter = Arc::new(RateLimiter::keyed(quota));
24
25        Self {
26            rate_limiter,
27            default_limit: requests_per_minute,
28            burst_limit,
29        }
30    }
31
32    /// Create with custom limits
33    #[must_use]
34    pub fn with_limits(requests_per_minute: u32, burst_limit: u32) -> Self {
35        Self::new(requests_per_minute, burst_limit)
36    }
37
38    /// Create with default limits (60 requests per minute, burst of 10)
39    #[allow(clippy::should_implement_trait)]
40    #[must_use]
41    pub fn default() -> Self {
42        Self::new(60, 10)
43    }
44
45    /// Extract client identifier from request
46    fn extract_client_id(request: &CallToolRequest, context: &MiddlewareContext) -> String {
47        // Try to get from authentication context first
48        if let Some(auth_key_id) = context.get_metadata("auth_key_id").and_then(|v| v.as_str()) {
49            return format!("api_key:{auth_key_id}");
50        }
51
52        if let Some(auth_user_id) = context
53            .get_metadata("auth_user_id")
54            .and_then(|v| v.as_str())
55        {
56            return format!("jwt:{auth_user_id}");
57        }
58
59        // Fallback to request-based identifier
60        if let Some(args) = &request.arguments {
61            if let Some(client_id) = args.get("client_id").and_then(|v| v.as_str()) {
62                return format!("client:{client_id}");
63            }
64        }
65
66        // Use request ID as fallback
67        format!("request:{}", context.request_id)
68    }
69
70    /// Check if request is within rate limits
71    fn check_rate_limit(&self, client_id: &str) -> bool {
72        self.rate_limiter.check_key(&client_id.to_string()).is_ok()
73    }
74
75    /// Get remaining requests for client
76    fn get_remaining_requests(&self, _client_id: &str) -> u32 {
77        // This is a simplified implementation
78        // In a real implementation, you'd want to track remaining requests more precisely
79        self.default_limit
80    }
81}
82
83#[async_trait::async_trait]
84impl McpMiddleware for RateLimitMiddleware {
85    fn name(&self) -> &'static str {
86        "rate_limiting"
87    }
88
89    fn priority(&self) -> i32 {
90        20 // Run after authentication but before other middleware
91    }
92
93    async fn before_request(
94        &self,
95        request: &CallToolRequest,
96        context: &mut MiddlewareContext,
97    ) -> McpResult<MiddlewareResult> {
98        let client_id = Self::extract_client_id(request, context);
99
100        if !self.check_rate_limit(&client_id) {
101            let error_result = CallToolResult {
102                content: vec![crate::mcp::Content::Text {
103                    text: format!(
104                        "Rate limit exceeded. Limit: {} requests per minute. Please try again later.",
105                        self.default_limit
106                    ),
107                }],
108                is_error: true,
109            };
110
111            context.set_metadata("rate_limited".to_string(), Value::Bool(true));
112            context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
113
114            return Ok(MiddlewareResult::Stop(error_result));
115        }
116
117        let remaining = self.get_remaining_requests(&client_id);
118        context.set_metadata(
119            "rate_limit_remaining".to_string(),
120            Value::Number(serde_json::Number::from(remaining)),
121        );
122        context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
123
124        Ok(MiddlewareResult::Continue)
125    }
126}