1use chrono::{Duration, Utc};
4use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7use zeroize::Zeroizing;
8
9use crate::error::ApiError;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Claims {
14 pub sub: String,
16 pub exp: i64,
18 pub iat: i64,
20 pub iss: String,
22 pub role: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tenant_id: Option<String>,
27 #[serde(flatten)]
29 pub extra: std::collections::HashMap<String, serde_json::Value>,
30}
31
32impl Claims {
33 pub fn for_user(user_id: &str, role: &str, expires_in: Duration) -> Self {
35 let now = Utc::now();
36 Self {
37 sub: user_id.to_string(),
38 exp: (now + expires_in).timestamp(),
39 iat: now.timestamp(),
40 iss: "vex-api".to_string(),
41 role: role.to_string(),
42 tenant_id: None,
43 extra: std::collections::HashMap::new(),
44 }
45 }
46
47 pub fn for_agent(agent_id: Uuid, expires_in: Duration) -> Self {
49 Self::for_user(&agent_id.to_string(), "agent", expires_in)
50 }
51
52 pub fn is_expired(&self) -> bool {
54 Utc::now().timestamp() > self.exp
55 }
56
57 pub fn has_role(&self, role: &str) -> bool {
59 self.role == role || self.role == "admin"
60 }
61}
62
63#[derive(Clone)]
65pub struct JwtAuth {
66 encoding_key: EncodingKey,
67 decoding_key: DecodingKey,
68 validation: Validation,
69}
70
71impl JwtAuth {
72 pub fn new(secret: &str) -> Self {
74 let encoding_key = EncodingKey::from_secret(secret.as_bytes());
75 let decoding_key = DecodingKey::from_secret(secret.as_bytes());
76
77 let mut validation = Validation::default();
78 validation.set_issuer(&["vex-api"]);
79 validation.validate_exp = true;
80
81 Self {
82 encoding_key,
83 decoding_key,
84 validation,
85 }
86 }
87
88 pub fn from_env() -> Result<Self, ApiError> {
91 let secret: Zeroizing<String> =
93 Zeroizing::new(std::env::var("VEX_JWT_SECRET").map_err(|_| {
94 ApiError::Internal(
95 "VEX_JWT_SECRET environment variable is required. \
96 Generate with: openssl rand -base64 32"
97 .to_string(),
98 )
99 })?);
100
101 if secret.len() < 32 {
102 return Err(ApiError::Internal(
103 "VEX_JWT_SECRET must be at least 32 characters for security".to_string(),
104 ));
105 }
106
107 Ok(Self::new(&secret))
109 }
110
111 pub fn encode(&self, claims: &Claims) -> Result<String, ApiError> {
113 encode(&Header::default(), claims, &self.encoding_key)
114 .map_err(|e| ApiError::Internal(format!("JWT encoding error: {}", e)))
115 }
116
117 pub fn decode(&self, token: &str) -> Result<Claims, ApiError> {
119 decode::<Claims>(token, &self.decoding_key, &self.validation)
120 .map(|data| data.claims)
121 .map_err(|e| match e.kind() {
122 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
123 ApiError::Unauthorized("Token expired".to_string())
124 }
125 jsonwebtoken::errors::ErrorKind::InvalidToken => {
126 ApiError::Unauthorized("Invalid token".to_string())
127 }
128 _ => ApiError::Unauthorized(format!("Token validation failed: {}", e)),
129 })
130 }
131
132 pub fn extract_from_header(header: &str) -> Result<&str, ApiError> {
134 header.strip_prefix("Bearer ").ok_or_else(|| {
135 ApiError::Unauthorized("Invalid Authorization header format".to_string())
136 })
137 }
138}
139
140#[derive(Debug, Clone)]
142pub struct ApiKey {
143 pub key_id: uuid::Uuid,
144 pub user_id: String,
145 pub name: String,
146 pub scopes: Vec<String>,
147 pub rate_limit: Option<u32>,
148}
149
150impl ApiKey {
151 pub async fn validate<S: vex_persist::ApiKeyStore>(
154 key: &str,
155 store: &S,
156 ) -> Result<Self, ApiError> {
157 let record = vex_persist::validate_api_key(store, key)
159 .await
160 .map_err(|e| match e {
161 vex_persist::ApiKeyError::NotFound => {
162 ApiError::Unauthorized("Invalid API key".to_string())
163 }
164 vex_persist::ApiKeyError::Expired => {
165 ApiError::Unauthorized("API key expired".to_string())
166 }
167 vex_persist::ApiKeyError::Revoked => {
168 ApiError::Unauthorized("API key revoked".to_string())
169 }
170 vex_persist::ApiKeyError::InvalidFormat => {
171 ApiError::Unauthorized("Invalid API key format".to_string())
172 }
173 vex_persist::ApiKeyError::Storage(msg) => {
174 ApiError::Internal(format!("Key validation error: {}", msg))
175 }
176 })?;
177
178 let rate_limit = if record.scopes.contains(&"enterprise".to_string()) {
180 Some(10000)
181 } else if record.scopes.contains(&"pro".to_string()) {
182 Some(1000)
183 } else {
184 Some(100) };
186
187 Ok(ApiKey {
188 key_id: record.id,
189 user_id: record.user_id,
190 name: record.name,
191 scopes: record.scopes,
192 rate_limit,
193 })
194 }
195
196 pub fn has_scope(&self, scope: &str) -> bool {
198 self.scopes.iter().any(|s| s == scope || s == "*")
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Exactly 32 bytes: the minimum length `JwtAuth::from_env` accepts.
    const TEST_SECRET: &str = "test-secret-key-that-is-32-bytes";

    #[test]
    fn test_jwt_encode_decode() {
        let auth = JwtAuth::new(TEST_SECRET);
        let claims = Claims::for_user("user123", "user", Duration::hours(1));

        let token = auth.encode(&claims).unwrap();
        let decoded = auth.decode(&token).unwrap();

        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.role, "user");
        assert_eq!(decoded.iss, "vex-api");
        assert_eq!(decoded.tenant_id, None);
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_expired_token() {
        let auth = JwtAuth::new(TEST_SECRET);
        // Far enough in the past to fall outside any default validation leeway.
        let claims = Claims::for_user("user123", "user", Duration::seconds(-300));
        assert!(claims.is_expired());

        let token = auth.encode(&claims).unwrap();
        let result = auth.decode(&token);

        assert!(
            matches!(&result, Err(ApiError::Unauthorized(msg)) if msg == "Token expired"),
            "Expected Unauthorized(\"Token expired\"), got: {:?}",
            result
        );
    }

    #[test]
    fn test_token_from_other_secret_is_rejected() {
        let issuer = JwtAuth::new(TEST_SECRET);
        let verifier = JwtAuth::new("another-secret-key-also-32-bytes");
        let claims = Claims::for_user("user123", "user", Duration::hours(1));

        let token = issuer.encode(&claims).unwrap();

        assert!(matches!(
            verifier.decode(&token),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_garbage_token_is_rejected() {
        let auth = JwtAuth::new(TEST_SECRET);
        assert!(matches!(
            auth.decode("not-a-jwt"),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_role_check() {
        let admin = Claims::for_user("user123", "admin", Duration::hours(1));
        assert!(admin.has_role("admin"));
        assert!(admin.has_role("user"));

        let user = Claims::for_user("user456", "user", Duration::hours(1));
        assert!(user.has_role("user"));
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_agent_claims() {
        let agent_id = Uuid::nil();
        let claims = Claims::for_agent(agent_id, Duration::minutes(5));

        assert_eq!(claims.sub, agent_id.to_string());
        assert_eq!(claims.role, "agent");
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtAuth::extract_from_header("Bearer abc.def.ghi").unwrap(),
            "abc.def.ghi"
        );
        assert!(JwtAuth::extract_from_header("Basic dXNlcjpwYXNz").is_err());
        assert!(JwtAuth::extract_from_header("abc.def.ghi").is_err());
    }
}