Skip to main content

things3_cli/mcp/middleware/
auth.rs

1//! Authentication middleware
2
3use super::{McpMiddleware, MiddlewareContext, MiddlewareResult};
4use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
5#[allow(unused_imports)]
6use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11pub struct AuthenticationMiddleware {
12    api_keys: HashMap<String, ApiKeyInfo>,
13    jwt_secret: String,
14    #[allow(dead_code)]
15    oauth_config: Option<OAuthConfig>,
16    require_auth: bool,
17}
18
19#[derive(Debug, Clone)]
20pub struct ApiKeyInfo {
21    pub key_id: String,
22    pub permissions: Vec<String>,
23    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
24}
25
/// OAuth 2.0 client configuration.
///
/// `Debug` is implemented by hand rather than derived so that
/// `client_secret` is redacted and never ends up in logs or panic messages.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret; shown as `<redacted>` in `Debug` output.
    pub client_secret: String,
    /// Token endpoint for the OAuth 2.0 flow.
    pub token_endpoint: String,
    /// Scopes to request.
    pub scope: Vec<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("token_endpoint", &self.token_endpoint)
            .field("scope", &self.scope)
            .finish()
    }
}
33
34#[derive(Debug, Serialize, Deserialize)]
35pub struct JwtClaims {
36    pub sub: String, // Subject (user ID)
37    pub exp: usize,  // Expiration time
38    pub iat: usize,  // Issued at
39    pub permissions: Vec<String>,
40}
41
42impl AuthenticationMiddleware {
43    /// Create a new authentication middleware
44    #[must_use]
45    pub fn new(api_keys: HashMap<String, ApiKeyInfo>, jwt_secret: String) -> Self {
46        Self {
47            api_keys,
48            jwt_secret,
49            oauth_config: None,
50            require_auth: true,
51        }
52    }
53
54    /// Create with OAuth 2.0 support
55    #[must_use]
56    pub fn with_oauth(
57        api_keys: HashMap<String, ApiKeyInfo>,
58        jwt_secret: String,
59        oauth_config: OAuthConfig,
60    ) -> Self {
61        Self {
62            api_keys,
63            jwt_secret,
64            oauth_config: Some(oauth_config),
65            require_auth: true,
66        }
67    }
68
69    /// Create without requiring authentication (for testing)
70    #[must_use]
71    pub fn permissive() -> Self {
72        Self {
73            api_keys: HashMap::new(),
74            jwt_secret: "test-secret".to_string(),
75            oauth_config: None,
76            require_auth: false,
77        }
78    }
79
80    /// Extract API key from request headers or arguments
81    fn extract_api_key(request: &CallToolRequest) -> Option<String> {
82        // Check if API key is in request arguments
83        if let Some(args) = &request.arguments {
84            if let Some(api_key) = args.get("api_key").and_then(|v| v.as_str()) {
85                return Some(api_key.to_string());
86            }
87        }
88        None
89    }
90
91    /// Extract JWT token from request headers or arguments
92    fn extract_jwt_token(request: &CallToolRequest) -> Option<String> {
93        // Check if JWT token is in request arguments
94        if let Some(args) = &request.arguments {
95            if let Some(token) = args.get("jwt_token").and_then(|v| v.as_str()) {
96                return Some(token.to_string());
97            }
98        }
99        None
100    }
101
102    /// Validate API key
103    fn validate_api_key(&self, api_key: &str) -> McpResult<ApiKeyInfo> {
104        let info = self
105            .api_keys
106            .get(api_key)
107            .cloned()
108            .ok_or_else(|| McpError::validation_error("Invalid API key"))?;
109        if let Some(exp) = &info.expires_at {
110            if *exp < chrono::Utc::now() {
111                return Err(McpError::validation_error("API key has expired"));
112            }
113        }
114        Ok(info)
115    }
116
117    /// Validate JWT token
118    fn validate_jwt_token(&self, token: &str) -> McpResult<JwtClaims> {
119        let validation = Validation::new(Algorithm::HS256);
120        let key = DecodingKey::from_secret(self.jwt_secret.as_ref());
121
122        let token_data = decode::<JwtClaims>(token, &key, &validation)
123            .map_err(|_| McpError::validation_error("Invalid JWT token"))?;
124
125        // Check if token is expired
126        let now = chrono::Utc::now().timestamp().try_into().unwrap_or(0);
127        if token_data.claims.exp < now {
128            return Err(McpError::validation_error("JWT token has expired"));
129        }
130
131        Ok(token_data.claims)
132    }
133
134    /// Generate JWT token for testing
135    ///
136    /// # Panics
137    /// Panics if JWT encoding fails
138    #[cfg(test)]
139    #[must_use]
140    pub fn generate_test_jwt(&self, user_id: &str, permissions: Vec<String>) -> String {
141        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
142        let now = chrono::Utc::now().timestamp() as usize;
143        let claims = JwtClaims {
144            sub: user_id.to_string(),
145            exp: now + 3600, // 1 hour
146            iat: now,
147            permissions,
148        };
149
150        let header = Header::new(Algorithm::HS256);
151        let key = EncodingKey::from_secret(self.jwt_secret.as_ref());
152        encode(&header, &claims, &key).unwrap()
153    }
154}
155
156#[async_trait::async_trait]
157impl McpMiddleware for AuthenticationMiddleware {
158    fn name(&self) -> &'static str {
159        "authentication"
160    }
161
162    fn priority(&self) -> i32 {
163        10 // High priority to run early
164    }
165
166    async fn before_request(
167        &self,
168        request: &CallToolRequest,
169        context: &mut MiddlewareContext,
170    ) -> McpResult<MiddlewareResult> {
171        if !self.require_auth {
172            context.set_metadata("auth_required".to_string(), Value::Bool(false));
173            return Ok(MiddlewareResult::Continue);
174        }
175
176        // Try API key authentication first
177        if let Some(api_key) = Self::extract_api_key(request) {
178            if let Ok(api_key_info) = self.validate_api_key(&api_key) {
179                context.set_metadata(
180                    "auth_type".to_string(),
181                    Value::String("api_key".to_string()),
182                );
183                context.set_metadata(
184                    "auth_key_id".to_string(),
185                    Value::String(api_key_info.key_id),
186                );
187                context.set_metadata(
188                    "auth_permissions".to_string(),
189                    serde_json::to_value(api_key_info.permissions).unwrap_or(Value::Array(vec![])),
190                );
191                context.set_metadata("auth_required".to_string(), Value::Bool(true));
192                return Ok(MiddlewareResult::Continue);
193            }
194            // API key failed, try JWT
195        }
196
197        // Try JWT authentication
198        if let Some(jwt_token) = Self::extract_jwt_token(request) {
199            if let Ok(claims) = self.validate_jwt_token(&jwt_token) {
200                context.set_metadata("auth_type".to_string(), Value::String("jwt".to_string()));
201                context.set_metadata("auth_user_id".to_string(), Value::String(claims.sub));
202                context.set_metadata(
203                    "auth_permissions".to_string(),
204                    serde_json::to_value(claims.permissions).unwrap_or(Value::Array(vec![])),
205                );
206                context.set_metadata("auth_required".to_string(), Value::Bool(true));
207                return Ok(MiddlewareResult::Continue);
208            }
209            // JWT failed
210        }
211
212        // No valid authentication found
213        let error_result = CallToolResult {
214            content: vec![crate::mcp::Content::Text {
215                text: "Authentication required. Please provide a valid API key or JWT token."
216                    .to_string(),
217            }],
218            is_error: true,
219        };
220
221        Ok(MiddlewareResult::Stop(error_result))
222    }
223}