vex_api/
auth.rs

1//! JWT-based authentication
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7use zeroize::Zeroizing;
8
9use crate::error::ApiError;
10
11/// JWT claims for API authentication
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Claims {
14    /// Subject (user/agent ID)
15    pub sub: String,
16    /// Expiration time (Unix timestamp)
17    pub exp: i64,
18    /// Issued at (Unix timestamp)
19    pub iat: i64,
20    /// Issuer
21    pub iss: String,
22    /// Role (user, agent, admin)
23    pub role: String,
24    /// Tenant ID (for multi-tenancy)
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tenant_id: Option<String>,
27    /// Custom claims
28    #[serde(flatten)]
29    pub extra: std::collections::HashMap<String, serde_json::Value>,
30}
31
32impl Claims {
33    /// Create new claims for a user
34    pub fn for_user(user_id: &str, role: &str, expires_in: Duration) -> Self {
35        let now = Utc::now();
36        Self {
37            sub: user_id.to_string(),
38            exp: (now + expires_in).timestamp(),
39            iat: now.timestamp(),
40            iss: "vex-api".to_string(),
41            role: role.to_string(),
42            tenant_id: None,
43            extra: std::collections::HashMap::new(),
44        }
45    }
46
47    /// Create claims for an agent
48    pub fn for_agent(agent_id: Uuid, expires_in: Duration) -> Self {
49        Self::for_user(&agent_id.to_string(), "agent", expires_in)
50    }
51
52    /// Check if claims are expired
53    pub fn is_expired(&self) -> bool {
54        Utc::now().timestamp() > self.exp
55    }
56
57    /// Check if claims have a specific role
58    pub fn has_role(&self, role: &str) -> bool {
59        self.role == role || self.role == "admin"
60    }
61}
62
63/// JWT authentication handler
64#[derive(Clone)]
65pub struct JwtAuth {
66    encoding_key: EncodingKey,
67    decoding_key: DecodingKey,
68    validation: Validation,
69}
70
71impl JwtAuth {
72    /// Create new JWT auth with secret
73    pub fn new(secret: &str) -> Self {
74        let encoding_key = EncodingKey::from_secret(secret.as_bytes());
75        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
76
77        let mut validation = Validation::default();
78        validation.set_issuer(&["vex-api"]);
79        validation.validate_exp = true;
80
81        Self {
82            encoding_key,
83            decoding_key,
84            validation,
85        }
86    }
87
88    /// Create from environment variable (required in production)
89    /// Uses Zeroizing to securely clear the secret from memory after key creation
90    pub fn from_env() -> Result<Self, ApiError> {
91        // Wrap secret in Zeroizing to ensure it's cleared from memory when dropped
92        let secret: Zeroizing<String> =
93            Zeroizing::new(std::env::var("VEX_JWT_SECRET").map_err(|_| {
94                ApiError::Internal(
95                    "VEX_JWT_SECRET environment variable is required. \
96                     Generate with: openssl rand -base64 32"
97                        .to_string(),
98                )
99            })?);
100
101        if secret.len() < 32 {
102            return Err(ApiError::Internal(
103                "VEX_JWT_SECRET must be at least 32 characters for security".to_string(),
104            ));
105        }
106
107        // Keys are created, then secret is automatically zeroed when Zeroizing drops
108        Ok(Self::new(&secret))
109    }
110
111    /// Generate a token for claims
112    pub fn encode(&self, claims: &Claims) -> Result<String, ApiError> {
113        encode(&Header::default(), claims, &self.encoding_key)
114            .map_err(|e| ApiError::Internal(format!("JWT encoding error: {}", e)))
115    }
116
117    /// Validate and decode a token
118    pub fn decode(&self, token: &str) -> Result<Claims, ApiError> {
119        decode::<Claims>(token, &self.decoding_key, &self.validation)
120            .map(|data| data.claims)
121            .map_err(|e| match e.kind() {
122                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
123                    ApiError::Unauthorized("Token expired".to_string())
124                }
125                jsonwebtoken::errors::ErrorKind::InvalidToken => {
126                    ApiError::Unauthorized("Invalid token".to_string())
127                }
128                _ => ApiError::Unauthorized(format!("Token validation failed: {}", e)),
129            })
130    }
131
132    /// Extract token from Authorization header
133    pub fn extract_from_header(header: &str) -> Result<&str, ApiError> {
134        header.strip_prefix("Bearer ").ok_or_else(|| {
135            ApiError::Unauthorized("Invalid Authorization header format".to_string())
136        })
137    }
138}
139
140/// API key for simplified authentication
141#[derive(Debug, Clone)]
142pub struct ApiKey {
143    pub key_id: uuid::Uuid,
144    pub user_id: String,
145    pub name: String,
146    pub scopes: Vec<String>,
147    pub rate_limit: Option<u32>,
148}
149
150impl ApiKey {
151    /// Validate an API key against a database-backed key store
152    /// Uses Argon2id verification with constant-time comparison
153    pub async fn validate<S: vex_persist::ApiKeyStore>(
154        key: &str,
155        store: &S,
156    ) -> Result<Self, ApiError> {
157        // Use the proper database-backed validation
158        let record = vex_persist::validate_api_key(store, key)
159            .await
160            .map_err(|e| match e {
161                vex_persist::ApiKeyError::NotFound => {
162                    ApiError::Unauthorized("Invalid API key".to_string())
163                }
164                vex_persist::ApiKeyError::Expired => {
165                    ApiError::Unauthorized("API key expired".to_string())
166                }
167                vex_persist::ApiKeyError::Revoked => {
168                    ApiError::Unauthorized("API key revoked".to_string())
169                }
170                vex_persist::ApiKeyError::InvalidFormat => {
171                    ApiError::Unauthorized("Invalid API key format".to_string())
172                }
173                vex_persist::ApiKeyError::Storage(msg) => {
174                    ApiError::Internal(format!("Key validation error: {}", msg))
175                }
176            })?;
177
178        // Determine rate limit based on scopes
179        let rate_limit = if record.scopes.contains(&"enterprise".to_string()) {
180            Some(10000)
181        } else if record.scopes.contains(&"pro".to_string()) {
182            Some(1000)
183        } else {
184            Some(100) // Free tier default
185        };
186
187        Ok(ApiKey {
188            key_id: record.id,
189            user_id: record.user_id,
190            name: record.name,
191            scopes: record.scopes,
192            rate_limit,
193        })
194    }
195
196    /// Check if this API key has a specific scope
197    pub fn has_scope(&self, scope: &str) -> bool {
198        self.scopes.iter().any(|s| s == scope || s == "*")
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Exactly 32 bytes, matching the minimum length `JwtAuth::from_env` enforces.
    const TEST_SECRET: &str = "test-secret-key-32-bytes-long!!!";

    #[test]
    fn test_secret_meets_minimum_length() {
        assert_eq!(TEST_SECRET.len(), 32);
    }

    #[test]
    fn test_jwt_encode_decode() {
        let auth = JwtAuth::new(TEST_SECRET);
        let claims = Claims::for_user("user123", "user", Duration::hours(1));

        let token = auth.encode(&claims).unwrap();
        let decoded = auth.decode(&token).unwrap();

        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.role, "user");
        assert_eq!(decoded.iss, "vex-api");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_expired_token() {
        let auth = JwtAuth::new(TEST_SECRET);
        // -300s comfortably exceeds jsonwebtoken's default 60s leeway.
        let claims = Claims::for_user("user123", "user", Duration::seconds(-300));

        let token = auth.encode(&claims).unwrap();
        let result = auth.decode(&token);

        assert!(
            matches!(&result, Err(ApiError::Unauthorized(msg)) if msg == "Token expired"),
            "Expected Unauthorized(\"Token expired\"), got: {:?}",
            result
        );
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let issuer = JwtAuth::new(TEST_SECRET);
        let verifier = JwtAuth::new("a-completely-different-secret-32b");
        let claims = Claims::for_user("user123", "user", Duration::hours(1));

        let token = issuer.encode(&claims).unwrap();
        assert!(matches!(
            verifier.decode(&token),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let auth = JwtAuth::new(TEST_SECRET);
        let mut claims = Claims::for_user("user123", "user", Duration::hours(1));
        claims.iss = "someone-else".to_string();

        let token = auth.encode(&claims).unwrap();
        assert!(matches!(auth.decode(&token), Err(ApiError::Unauthorized(_))));
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtAuth::extract_from_header("Bearer abc.def.ghi").unwrap(),
            "abc.def.ghi"
        );
        assert!(matches!(
            JwtAuth::extract_from_header("Basic dXNlcjpwYXNz"),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_role_check() {
        let admin = Claims::for_user("user123", "admin", Duration::hours(1));
        assert!(admin.has_role("admin"));
        assert!(admin.has_role("user")); // Admin satisfies every role

        let user = Claims::for_user("user456", "user", Duration::hours(1));
        assert!(user.has_role("user"));
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_scope_check() {
        let key = ApiKey {
            key_id: Uuid::nil(),
            user_id: "user123".to_string(),
            name: "ci".to_string(),
            scopes: vec!["read".to_string()],
            rate_limit: Some(100),
        };
        assert!(key.has_scope("read"));
        assert!(!key.has_scope("write"));

        let wildcard = ApiKey {
            scopes: vec!["*".to_string()],
            ..key.clone()
        };
        assert!(wildcard.has_scope("write"));
    }
}