ruvector_security/
auth.rs

1//! Authentication middleware and token validation
2//!
3//! Provides bearer token authentication for MCP endpoints.
4
5use crate::error::{SecurityError, SecurityResult};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use sha2::Sha256;
9use std::sync::Arc;
10use subtle::ConstantTimeEq;
11
12type HmacSha256 = Hmac<Sha256>;
13
14/// Authentication mode
15#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
16#[serde(rename_all = "lowercase")]
17pub enum AuthMode {
18    /// No authentication (development only)
19    #[default]
20    None,
21    /// Bearer token authentication
22    Bearer,
23    /// Mutual TLS (not yet implemented)
24    Mtls,
25}
26
27/// Authentication configuration
28#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
29pub struct AuthConfig {
30    /// Authentication mode
31    pub mode: AuthMode,
32    /// Bearer token (for Bearer mode)
33    #[serde(skip_serializing)]
34    pub token: Option<String>,
35    /// Secret key for HMAC validation
36    #[serde(skip_serializing)]
37    pub secret_key: Option<String>,
38    /// Token expiry in seconds (0 = no expiry)
39    pub token_expiry_secs: u64,
40    /// Allow localhost without auth (development)
41    pub allow_localhost: bool,
42}
43
44impl Default for AuthConfig {
45    fn default() -> Self {
46        Self {
47            mode: AuthMode::None,
48            token: None,
49            secret_key: None,
50            token_expiry_secs: 0,
51            allow_localhost: true,
52        }
53    }
54}
55
56/// Token validator trait
57pub trait TokenValidator: Send + Sync {
58    /// Validate a token and return Ok if valid
59    fn validate(&self, token: &str) -> SecurityResult<()>;
60}
61
/// Validates bearer tokens against a single expected value.
///
/// Holds only an HMAC-SHA256 digest of the expected token and the key used to
/// compute it, never the plaintext token. Candidates are digested with the same
/// key and compared in constant time.
#[derive(Clone)]
pub struct BearerTokenValidator {
    /// HMAC-SHA256 digest of the expected token under `hmac_key`.
    token_hash: Vec<u8>,
    /// Key used for both the stored digest and every candidate digest.
    hmac_key: Vec<u8>,
}
70
71impl BearerTokenValidator {
72    /// Create a new bearer token validator
73    pub fn new(expected_token: &str) -> Self {
74        // Generate a random key for HMAC
75        let mut hmac_key = vec![0u8; 32];
76        rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut hmac_key);
77
78        // Hash the expected token
79        let mut mac = HmacSha256::new_from_slice(&hmac_key).expect("HMAC can take key of any size");
80        mac.update(expected_token.as_bytes());
81        let token_hash = mac.finalize().into_bytes().to_vec();
82
83        Self {
84            token_hash,
85            hmac_key,
86        }
87    }
88
89    /// Create from environment variable
90    pub fn from_env(env_var: &str) -> Option<Self> {
91        std::env::var(env_var).ok().map(|token| Self::new(&token))
92    }
93}
94
95impl TokenValidator for BearerTokenValidator {
96    fn validate(&self, token: &str) -> SecurityResult<()> {
97        // Hash the provided token
98        let mut mac =
99            HmacSha256::new_from_slice(&self.hmac_key).expect("HMAC can take key of any size");
100        mac.update(token.as_bytes());
101        let token_hash = mac.finalize().into_bytes();
102
103        // Constant-time comparison to prevent timing attacks
104        if token_hash.ct_eq(&self.token_hash).into() {
105            Ok(())
106        } else {
107            Err(SecurityError::InvalidToken)
108        }
109    }
110}
111
112/// Authentication middleware
113#[derive(Clone)]
114pub struct AuthMiddleware {
115    validator: Option<Arc<dyn TokenValidator>>,
116    config: AuthConfig,
117}
118
119impl AuthMiddleware {
120    /// Create new auth middleware with configuration
121    pub fn new(config: AuthConfig) -> Self {
122        let validator: Option<Arc<dyn TokenValidator>> = match &config.mode {
123            AuthMode::None => None,
124            AuthMode::Bearer => config
125                .token
126                .as_ref()
127                .map(|t| Arc::new(BearerTokenValidator::new(t)) as Arc<dyn TokenValidator>),
128            AuthMode::Mtls => {
129                tracing::warn!("mTLS authentication not yet implemented, falling back to None");
130                None
131            }
132        };
133
134        Self { validator, config }
135    }
136
137    /// Create middleware that requires no authentication (development)
138    pub fn none() -> Self {
139        Self::new(AuthConfig::default())
140    }
141
142    /// Create middleware with bearer token
143    pub fn bearer(token: &str) -> Self {
144        Self::new(AuthConfig {
145            mode: AuthMode::Bearer,
146            token: Some(token.to_string()),
147            ..Default::default()
148        })
149    }
150
151    /// Create from environment variable
152    pub fn from_env(env_var: &str) -> Self {
153        match std::env::var(env_var) {
154            Ok(token) if !token.is_empty() => Self::bearer(&token),
155            _ => Self::none(),
156        }
157    }
158
159    /// Validate a request's authorization header
160    pub fn validate_header(&self, auth_header: Option<&str>) -> SecurityResult<()> {
161        // If no authentication required, allow all
162        if self.config.mode == AuthMode::None {
163            return Ok(());
164        }
165
166        // Get validator or fail
167        let validator = self
168            .validator
169            .as_ref()
170            .ok_or(SecurityError::AuthenticationRequired)?;
171
172        // Parse authorization header
173        let header = auth_header.ok_or(SecurityError::AuthenticationRequired)?;
174
175        // Extract bearer token
176        let token = header
177            .strip_prefix("Bearer ")
178            .or_else(|| header.strip_prefix("bearer "))
179            .ok_or(SecurityError::InvalidToken)?;
180
181        validator.validate(token)
182    }
183
184    /// Check if a remote address should bypass authentication
185    pub fn is_localhost_allowed(&self, remote_addr: &str) -> bool {
186        if !self.config.allow_localhost {
187            return false;
188        }
189
190        remote_addr.starts_with("127.0.0.1")
191            || remote_addr.starts_with("::1")
192            || remote_addr.starts_with("localhost")
193    }
194
195    /// Get the authentication mode
196    pub fn mode(&self) -> &AuthMode {
197        &self.config.mode
198    }
199}
200
201impl Default for AuthMiddleware {
202    fn default() -> Self {
203        Self::none()
204    }
205}
206
207/// Generate a secure random token
208pub fn generate_token() -> String {
209    let mut bytes = [0u8; 32];
210    rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut bytes);
211    BASE64.encode(bytes)
212}
213
214/// Generate a token with a specific prefix for identification
215pub fn generate_prefixed_token(prefix: &str) -> String {
216    format!("{}_{}", prefix, generate_token())
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bearer_token_validation() {
        const SECRET: &str = "my_secret_token";
        let validator = BearerTokenValidator::new(SECRET);

        assert!(validator.validate(SECRET).is_ok(), "matching token must validate");
        assert!(
            validator.validate("wrong_token").is_err(),
            "mismatched token must be rejected"
        );
    }

    #[test]
    fn test_auth_middleware_bearer() {
        const SECRET: &str = "test_token_12345";
        let middleware = AuthMiddleware::bearer(SECRET);

        let good_header = format!("Bearer {}", SECRET);
        assert!(middleware.validate_header(Some(good_header.as_str())).is_ok());
        assert!(middleware.validate_header(Some("Bearer wrong_token")).is_err());
        assert!(middleware.validate_header(None).is_err(), "missing header must fail");
    }

    #[test]
    fn test_auth_middleware_none() {
        let middleware = AuthMiddleware::none();

        for header in [None, Some("anything")] {
            assert!(
                middleware.validate_header(header).is_ok(),
                "no-auth mode must accept {:?}",
                header
            );
        }
    }

    #[test]
    fn test_generate_token() {
        let first = generate_token();
        let second = generate_token();

        assert_ne!(first, second, "tokens must be unique");
        assert!(first.len() >= 40, "token too short: {}", first.len());
    }

    #[test]
    fn test_localhost_bypass() {
        let middleware = AuthMiddleware::new(AuthConfig {
            mode: AuthMode::Bearer,
            token: Some("secret".to_string()),
            allow_localhost: true,
            ..Default::default()
        });

        let cases = [
            ("127.0.0.1:8080", true),
            ("::1:8080", true),
            ("192.168.1.1:8080", false),
        ];
        for (addr, expected) in cases {
            assert_eq!(
                middleware.is_localhost_allowed(addr),
                expected,
                "unexpected bypass decision for {}",
                addr
            );
        }
    }
}