things3_cli/mcp/middleware/
auth.rs1use super::{McpMiddleware, MiddlewareContext, MiddlewareResult};
4use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
5#[allow(unused_imports)]
6use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11pub struct AuthenticationMiddleware {
12 api_keys: HashMap<String, ApiKeyInfo>,
13 jwt_secret: String,
14 #[allow(dead_code)]
15 oauth_config: Option<OAuthConfig>,
16 require_auth: bool,
17}
18
19#[derive(Debug, Clone)]
20pub struct ApiKeyInfo {
21 pub key_id: String,
22 pub permissions: Vec<String>,
23 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
24}
25
/// OAuth client settings.
///
/// `Debug` is implemented by hand rather than derived so that `client_secret`
/// never ends up in logs, panic messages or error output that format this
/// value with `{:?}`.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Token endpoint address (presumably a URL; not used in this file).
    pub token_endpoint: String,
    /// Requested scopes.
    pub scope: Vec<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            // Never print the credential itself.
            .field("client_secret", &"<redacted>")
            .field("token_endpoint", &self.token_endpoint)
            .field("scope", &self.scope)
            .finish()
    }
}
33
34#[derive(Debug, Serialize, Deserialize)]
35pub struct JwtClaims {
36 pub sub: String, pub exp: usize, pub iat: usize, pub permissions: Vec<String>,
40}
41
42impl AuthenticationMiddleware {
43 #[must_use]
45 pub fn new(api_keys: HashMap<String, ApiKeyInfo>, jwt_secret: String) -> Self {
46 Self {
47 api_keys,
48 jwt_secret,
49 oauth_config: None,
50 require_auth: true,
51 }
52 }
53
54 #[must_use]
56 pub fn with_oauth(
57 api_keys: HashMap<String, ApiKeyInfo>,
58 jwt_secret: String,
59 oauth_config: OAuthConfig,
60 ) -> Self {
61 Self {
62 api_keys,
63 jwt_secret,
64 oauth_config: Some(oauth_config),
65 require_auth: true,
66 }
67 }
68
69 #[must_use]
71 pub fn permissive() -> Self {
72 Self {
73 api_keys: HashMap::new(),
74 jwt_secret: "test-secret".to_string(),
75 oauth_config: None,
76 require_auth: false,
77 }
78 }
79
80 fn extract_api_key(request: &CallToolRequest) -> Option<String> {
82 if let Some(args) = &request.arguments {
84 if let Some(api_key) = args.get("api_key").and_then(|v| v.as_str()) {
85 return Some(api_key.to_string());
86 }
87 }
88 None
89 }
90
91 fn extract_jwt_token(request: &CallToolRequest) -> Option<String> {
93 if let Some(args) = &request.arguments {
95 if let Some(token) = args.get("jwt_token").and_then(|v| v.as_str()) {
96 return Some(token.to_string());
97 }
98 }
99 None
100 }
101
102 fn validate_api_key(&self, api_key: &str) -> McpResult<ApiKeyInfo> {
104 let info = self
105 .api_keys
106 .get(api_key)
107 .cloned()
108 .ok_or_else(|| McpError::validation_error("Invalid API key"))?;
109 if let Some(exp) = &info.expires_at {
110 if *exp < chrono::Utc::now() {
111 return Err(McpError::validation_error("API key has expired"));
112 }
113 }
114 Ok(info)
115 }
116
117 fn validate_jwt_token(&self, token: &str) -> McpResult<JwtClaims> {
119 let validation = Validation::new(Algorithm::HS256);
120 let key = DecodingKey::from_secret(self.jwt_secret.as_ref());
121
122 let token_data = decode::<JwtClaims>(token, &key, &validation)
123 .map_err(|_| McpError::validation_error("Invalid JWT token"))?;
124
125 let now = chrono::Utc::now().timestamp().try_into().unwrap_or(0);
127 if token_data.claims.exp < now {
128 return Err(McpError::validation_error("JWT token has expired"));
129 }
130
131 Ok(token_data.claims)
132 }
133
134 #[cfg(test)]
139 #[must_use]
140 pub fn generate_test_jwt(&self, user_id: &str, permissions: Vec<String>) -> String {
141 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
142 let now = chrono::Utc::now().timestamp() as usize;
143 let claims = JwtClaims {
144 sub: user_id.to_string(),
145 exp: now + 3600, iat: now,
147 permissions,
148 };
149
150 let header = Header::new(Algorithm::HS256);
151 let key = EncodingKey::from_secret(self.jwt_secret.as_ref());
152 encode(&header, &claims, &key).unwrap()
153 }
154}
155
156#[async_trait::async_trait]
157impl McpMiddleware for AuthenticationMiddleware {
158 fn name(&self) -> &'static str {
159 "authentication"
160 }
161
162 fn priority(&self) -> i32 {
163 10 }
165
166 async fn before_request(
167 &self,
168 request: &CallToolRequest,
169 context: &mut MiddlewareContext,
170 ) -> McpResult<MiddlewareResult> {
171 if !self.require_auth {
172 context.set_metadata("auth_required".to_string(), Value::Bool(false));
173 return Ok(MiddlewareResult::Continue);
174 }
175
176 if let Some(api_key) = Self::extract_api_key(request) {
178 if let Ok(api_key_info) = self.validate_api_key(&api_key) {
179 context.set_metadata(
180 "auth_type".to_string(),
181 Value::String("api_key".to_string()),
182 );
183 context.set_metadata(
184 "auth_key_id".to_string(),
185 Value::String(api_key_info.key_id),
186 );
187 context.set_metadata(
188 "auth_permissions".to_string(),
189 serde_json::to_value(api_key_info.permissions).unwrap_or(Value::Array(vec![])),
190 );
191 context.set_metadata("auth_required".to_string(), Value::Bool(true));
192 return Ok(MiddlewareResult::Continue);
193 }
194 }
196
197 if let Some(jwt_token) = Self::extract_jwt_token(request) {
199 if let Ok(claims) = self.validate_jwt_token(&jwt_token) {
200 context.set_metadata("auth_type".to_string(), Value::String("jwt".to_string()));
201 context.set_metadata("auth_user_id".to_string(), Value::String(claims.sub));
202 context.set_metadata(
203 "auth_permissions".to_string(),
204 serde_json::to_value(claims.permissions).unwrap_or(Value::Array(vec![])),
205 );
206 context.set_metadata("auth_required".to_string(), Value::Bool(true));
207 return Ok(MiddlewareResult::Continue);
208 }
209 }
211
212 let error_result = CallToolResult {
214 content: vec![crate::mcp::Content::Text {
215 text: "Authentication required. Please provide a valid API key or JWT token."
216 .to_string(),
217 }],
218 is_error: true,
219 };
220
221 Ok(MiddlewareResult::Stop(error_result))
222 }
223}