Skip to main content

things3_cli/mcp/middleware/
config.rs

1//! Middleware configuration types
2
3use super::auth::{ApiKeyInfo, OAuthConfig};
4use super::logging::LogLevel;
5use super::{
6    AuthenticationMiddleware, LoggingMiddleware, MiddlewareChain, PerformanceMiddleware,
7    RateLimitMiddleware, ValidationMiddleware,
8};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SecurityConfig {
15    /// Authentication configuration
16    pub authentication: AuthenticationConfig,
17    /// Rate limiting configuration
18    pub rate_limiting: RateLimitingConfig,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct AuthenticationConfig {
23    /// Enable authentication middleware
24    pub enabled: bool,
25    /// Require authentication for all requests
26    pub require_auth: bool,
27    /// JWT secret for token validation
28    pub jwt_secret: String,
29    /// API keys configuration
30    pub api_keys: Vec<ApiKeyConfig>,
31    /// OAuth 2.0 configuration
32    pub oauth: Option<OAuth2Config>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ApiKeyConfig {
37    /// API key value
38    pub key: String,
39    /// Key identifier
40    pub key_id: String,
41    /// Permissions for this key
42    pub permissions: Vec<String>,
43    /// Optional expiration date
44    pub expires_at: Option<String>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuth2Config {
49    /// OAuth client ID
50    pub client_id: String,
51    /// OAuth client secret
52    pub client_secret: String,
53    /// Token endpoint URL
54    pub token_endpoint: String,
55    /// Required scopes
56    pub scopes: Vec<String>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct RateLimitingConfig {
61    /// Enable rate limiting middleware
62    pub enabled: bool,
63    /// Requests per minute limit
64    pub requests_per_minute: u32,
65    /// Burst limit for short bursts
66    pub burst_limit: u32,
67    /// Custom limits per client type
68    pub custom_limits: Option<HashMap<String, u32>>,
69}
70
71impl Default for SecurityConfig {
72    fn default() -> Self {
73        Self {
74            authentication: AuthenticationConfig {
75                enabled: true,
76                require_auth: false, // Start with auth disabled for easier development
77                jwt_secret: "your-secret-key-change-this-in-production".to_string(),
78                api_keys: vec![],
79                oauth: None,
80            },
81            rate_limiting: RateLimitingConfig {
82                enabled: true,
83                requests_per_minute: 60,
84                burst_limit: 10,
85                custom_limits: None,
86            },
87        }
88    }
89}
90
91/// Middleware configuration
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MiddlewareConfig {
94    /// Logging configuration
95    pub logging: LoggingConfig,
96    /// Validation configuration
97    pub validation: ValidationConfig,
98    /// Performance monitoring configuration
99    pub performance: PerformanceConfig,
100    /// Security configuration
101    pub security: SecurityConfig,
102}
103
104/// Logging middleware configuration
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct LoggingConfig {
107    /// Enable logging middleware
108    pub enabled: bool,
109    /// Log level for logging middleware
110    pub level: String,
111}
112
113/// Validation middleware configuration
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ValidationConfig {
116    /// Enable validation middleware
117    pub enabled: bool,
118    /// Use strict validation mode
119    pub strict_mode: bool,
120}
121
122/// Performance monitoring configuration
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct PerformanceConfig {
125    /// Enable performance monitoring
126    pub enabled: bool,
127    /// Slow request threshold in milliseconds
128    pub slow_request_threshold_ms: u64,
129}
130
131impl Default for MiddlewareConfig {
132    fn default() -> Self {
133        Self {
134            logging: LoggingConfig {
135                enabled: true,
136                level: "info".to_string(),
137            },
138            validation: ValidationConfig {
139                enabled: true,
140                strict_mode: false,
141            },
142            performance: PerformanceConfig {
143                enabled: true,
144                slow_request_threshold_ms: 1000,
145            },
146            security: SecurityConfig::default(),
147        }
148    }
149}
150
151impl MiddlewareConfig {
152    /// Create a new middleware configuration
153    #[must_use]
154    pub fn new() -> Self {
155        Self::default()
156    }
157
158    /// Build a middleware chain from this configuration
159    #[must_use]
160    pub fn build_chain(self) -> MiddlewareChain {
161        let mut chain = MiddlewareChain::new();
162
163        // Security middleware (highest priority)
164        if self.security.authentication.enabled {
165            let api_keys: HashMap<String, ApiKeyInfo> = self
166                .security
167                .authentication
168                .api_keys
169                .into_iter()
170                .map(|config| {
171                    let expires_at = config.expires_at.and_then(|date_str| {
172                        chrono::DateTime::parse_from_rfc3339(&date_str)
173                            .ok()
174                            .map(|dt| dt.with_timezone(&chrono::Utc))
175                    });
176
177                    let api_key_info = ApiKeyInfo {
178                        key_id: config.key_id,
179                        permissions: config.permissions,
180                        expires_at,
181                    };
182
183                    (config.key, api_key_info)
184                })
185                .collect();
186
187            let auth_middleware = if self.security.authentication.require_auth {
188                if let Some(oauth_config) = self.security.authentication.oauth {
189                    let oauth = OAuthConfig {
190                        client_id: oauth_config.client_id,
191                        client_secret: oauth_config.client_secret,
192                        token_endpoint: oauth_config.token_endpoint,
193                        scope: oauth_config.scopes,
194                    };
195                    AuthenticationMiddleware::with_oauth(
196                        api_keys,
197                        self.security.authentication.jwt_secret,
198                        oauth,
199                    )
200                } else {
201                    AuthenticationMiddleware::new(api_keys, self.security.authentication.jwt_secret)
202                }
203            } else {
204                AuthenticationMiddleware::permissive()
205            };
206
207            chain = chain.add_middleware(auth_middleware);
208        }
209
210        if self.security.rate_limiting.enabled {
211            let rate_limit_middleware = RateLimitMiddleware::with_limits(
212                self.security.rate_limiting.requests_per_minute,
213                self.security.rate_limiting.burst_limit,
214            );
215            chain = chain.add_middleware(rate_limit_middleware);
216        }
217
218        if self.logging.enabled {
219            let log_level = match self.logging.level.to_lowercase().as_str() {
220                "debug" => LogLevel::Debug,
221                "warn" => LogLevel::Warn,
222                "error" => LogLevel::Error,
223                _ => LogLevel::Info,
224            };
225            chain = chain.add_middleware(LoggingMiddleware::new(log_level));
226        }
227
228        if self.validation.enabled {
229            chain = chain.add_middleware(ValidationMiddleware::new(self.validation.strict_mode));
230        }
231
232        if self.performance.enabled {
233            let threshold = Duration::from_millis(self.performance.slow_request_threshold_ms);
234            chain = chain.add_middleware(PerformanceMiddleware::with_threshold(threshold));
235        }
236
237        chain
238    }
239}