vex_api/
auth.rs

1//! JWT-based authentication
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::error::ApiError;
9
10/// JWT claims for API authentication
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Claims {
13    /// Subject (user/agent ID)
14    pub sub: String,
15    /// Expiration time (Unix timestamp)
16    pub exp: i64,
17    /// Issued at (Unix timestamp)
18    pub iat: i64,
19    /// Issuer
20    pub iss: String,
21    /// Role (user, agent, admin)
22    pub role: String,
23    /// Tenant ID (for multi-tenancy)
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tenant_id: Option<String>,
26    /// Custom claims
27    #[serde(flatten)]
28    pub extra: std::collections::HashMap<String, serde_json::Value>,
29}
30
31impl Claims {
32    /// Create new claims for a user
33    pub fn for_user(user_id: &str, role: &str, expires_in: Duration) -> Self {
34        let now = Utc::now();
35        Self {
36            sub: user_id.to_string(),
37            exp: (now + expires_in).timestamp(),
38            iat: now.timestamp(),
39            iss: "vex-api".to_string(),
40            role: role.to_string(),
41            tenant_id: None,
42            extra: std::collections::HashMap::new(),
43        }
44    }
45
46    /// Create claims for an agent
47    pub fn for_agent(agent_id: Uuid, expires_in: Duration) -> Self {
48        Self::for_user(&agent_id.to_string(), "agent", expires_in)
49    }
50
51    /// Check if claims are expired
52    pub fn is_expired(&self) -> bool {
53        Utc::now().timestamp() > self.exp
54    }
55
56    /// Check if claims have a specific role
57    pub fn has_role(&self, role: &str) -> bool {
58        self.role == role || self.role == "admin"
59    }
60}
61
62/// JWT authentication handler
63#[derive(Clone)]
64pub struct JwtAuth {
65    encoding_key: EncodingKey,
66    decoding_key: DecodingKey,
67    validation: Validation,
68}
69
70impl JwtAuth {
71    /// Create new JWT auth with secret
72    pub fn new(secret: &str) -> Self {
73        let encoding_key = EncodingKey::from_secret(secret.as_bytes());
74        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
75
76        let mut validation = Validation::default();
77        validation.set_issuer(&["vex-api"]);
78        validation.validate_exp = true;
79
80        Self {
81            encoding_key,
82            decoding_key,
83            validation,
84        }
85    }
86
87    /// Create from environment variable (required in production)
88    pub fn from_env() -> Result<Self, ApiError> {
89        let secret = std::env::var("VEX_JWT_SECRET").map_err(|_| {
90            ApiError::Internal(
91                "VEX_JWT_SECRET environment variable is required. \
92                     Generate with: openssl rand -base64 32"
93                    .to_string(),
94            )
95        })?;
96
97        if secret.len() < 32 {
98            return Err(ApiError::Internal(
99                "VEX_JWT_SECRET must be at least 32 characters for security".to_string(),
100            ));
101        }
102
103        Ok(Self::new(&secret))
104    }
105
106    /// Generate a token for claims
107    pub fn encode(&self, claims: &Claims) -> Result<String, ApiError> {
108        encode(&Header::default(), claims, &self.encoding_key)
109            .map_err(|e| ApiError::Internal(format!("JWT encoding error: {}", e)))
110    }
111
112    /// Validate and decode a token
113    pub fn decode(&self, token: &str) -> Result<Claims, ApiError> {
114        decode::<Claims>(token, &self.decoding_key, &self.validation)
115            .map(|data| data.claims)
116            .map_err(|e| match e.kind() {
117                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
118                    ApiError::Unauthorized("Token expired".to_string())
119                }
120                jsonwebtoken::errors::ErrorKind::InvalidToken => {
121                    ApiError::Unauthorized("Invalid token".to_string())
122                }
123                _ => ApiError::Unauthorized(format!("Token validation failed: {}", e)),
124            })
125    }
126
127    /// Extract token from Authorization header
128    pub fn extract_from_header(header: &str) -> Result<&str, ApiError> {
129        header.strip_prefix("Bearer ").ok_or_else(|| {
130            ApiError::Unauthorized("Invalid Authorization header format".to_string())
131        })
132    }
133}
134
/// An API key presented by a client, together with the permissions attached to it.
///
/// `Debug` is implemented by hand so the secret `key` never ends up in logs or panic
/// messages; every other field is printed normally.
#[derive(Clone)]
pub struct ApiKey {
    /// The raw secret key string presented by the client.
    pub key: String,
    /// Name identifying the key.
    pub name: String,
    /// Roles granted to requests authenticated with this key.
    pub roles: Vec<String>,
    /// Optional rate limit for this key.
    /// NOTE(review): the unit (per second / per minute) is defined by the rate limiter — confirm.
    pub rate_limit: Option<u32>,
}

impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKey")
            .field("key", &"<redacted>")
            .field("name", &self.name)
            .field("roles", &self.roles)
            .field("rate_limit", &self.rate_limit)
            .finish()
    }
}
143
144impl ApiKey {
145    /// Validate an API key (placeholder - connect to your key store)
146    pub async fn validate(key: &str) -> Result<Self, ApiError> {
147        // In production, this would check against a database
148        if key.starts_with("vex_") && key.len() > 20 {
149            Ok(ApiKey {
150                key: key.to_string(),
151                name: "default".to_string(),
152                roles: vec!["user".to_string()],
153                rate_limit: Some(100),
154            })
155        } else {
156            Err(ApiError::Unauthorized("Invalid API key".to_string()))
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret shared by the tests; `JwtAuth::new` itself performs no length check.
    const SECRET: &str = "test-secret-key-32-bytes-long!!";

    #[test]
    fn test_jwt_encode_decode() {
        let auth = JwtAuth::new(SECRET);
        let claims = Claims::for_user("user123", "user", Duration::hours(1));

        let token = auth.encode(&claims).unwrap();
        let decoded = auth.decode(&token).unwrap();

        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.role, "user");
        assert_eq!(decoded.iss, "vex-api");
        assert!(decoded.tenant_id.is_none());
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_extra_claims_round_trip() {
        let auth = JwtAuth::new(SECRET);
        let mut claims = Claims::for_user("user123", "user", Duration::hours(1));
        claims
            .extra
            .insert("scope".to_string(), serde_json::Value::from("read"));

        let decoded = auth.decode(&auth.encode(&claims).unwrap()).unwrap();

        assert_eq!(
            decoded.extra.get("scope"),
            Some(&serde_json::Value::from("read"))
        );
    }

    #[test]
    fn test_expired_token() {
        let auth = JwtAuth::new(SECRET);
        // -300s is well beyond jsonwebtoken's default 60s leeway.
        let claims = Claims::for_user("user123", "user", Duration::seconds(-300));
        assert!(claims.is_expired());

        let token = auth.encode(&claims).unwrap();
        let result = auth.decode(&token);

        assert!(
            matches!(&result, Err(ApiError::Unauthorized(msg)) if msg == "Token expired"),
            "Expected Unauthorized(\"Token expired\"), got: {:?}",
            result
        );
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let claims = Claims::for_user("user123", "user", Duration::hours(1));
        let token = JwtAuth::new(SECRET).encode(&claims).unwrap();

        let other = JwtAuth::new("a-completely-different-secret-value!");
        assert!(matches!(
            other.decode(&token),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_wrong_issuer_rejected() {
        let auth = JwtAuth::new(SECRET);
        let mut claims = Claims::for_user("user123", "user", Duration::hours(1));
        claims.iss = "someone-else".to_string();

        let token = auth.encode(&claims).unwrap();
        assert!(matches!(auth.decode(&token), Err(ApiError::Unauthorized(_))));
    }

    #[test]
    fn test_role_check() {
        let admin = Claims::for_user("user123", "admin", Duration::hours(1));
        assert!(admin.has_role("admin"));
        assert!(admin.has_role("user")); // Admin has all roles

        let user = Claims::for_user("user456", "user", Duration::hours(1));
        assert!(user.has_role("user"));
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_agent_claims() {
        let id = Uuid::nil();
        let claims = Claims::for_agent(id, Duration::hours(1));

        assert_eq!(claims.sub, id.to_string());
        assert_eq!(claims.role, "agent");
        assert!(!claims.has_role("user"));
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtAuth::extract_from_header("Bearer abc.def.ghi").unwrap(),
            "abc.def.ghi"
        );
        assert!(JwtAuth::extract_from_header("Basic dXNlcjpwYXNz").is_err());
        assert!(JwtAuth::extract_from_header("abc.def.ghi").is_err());
    }
}