1pub mod jwt;
18pub mod opaque;
19pub mod provider;
20
21use std::path::Path;
22
23use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
24use chrono::{DateTime, Duration, Utc};
25use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
26use rand::{TryRngCore, rngs::OsRng};
27use serde::{Deserialize, Serialize};
28
29use crate::{
30 Error,
31 error::{SessionError, ValidationError},
32 user::UserId,
33};
34
35pub use jwt::JwtSessionProvider;
37pub use opaque::OpaqueSessionProvider;
38pub use provider::SessionProvider;
39
40fn generate_random_string(length: usize) -> String {
44 if length < 32 {
45 panic!("Length must be at least 32");
46 }
47 let mut bytes = vec![0u8; length];
48 OsRng.try_fill_bytes(&mut bytes).unwrap();
49 BASE64_URL_SAFE_NO_PAD.encode(bytes)
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54#[serde(untagged)]
55pub enum SessionToken {
56 Opaque(String),
59 Jwt(String),
62}
63
64impl SessionToken {
65 pub fn new(token: &str) -> Self {
67 if token.chars().filter(|&c| c == '.').count() == 2 {
69 SessionToken::Jwt(token.to_string())
70 } else {
71 SessionToken::Opaque(token.to_string())
72 }
73 }
74
75 pub fn new_random() -> Self {
77 SessionToken::Opaque(generate_random_string(32))
78 }
79
80 pub fn new_jwt(claims: &JwtClaims, config: &JwtConfig) -> Result<Self, Error> {
82 let header = Header::new(config.jwt_algorithm());
83
84 let encoding_key = config.get_encoding_key()?;
85
86 let token = encode(&header, claims, &encoding_key)
87 .map_err(|e| SessionError::InvalidToken(format!("Failed to encode JWT: {e}")))?;
88
89 Ok(SessionToken::Jwt(token))
90 }
91
92 pub fn verify_jwt(&self, config: &JwtConfig) -> Result<JwtClaims, Error> {
94 match self {
95 SessionToken::Jwt(token) => {
96 let decoding_key = config.get_decoding_key()?;
97 let validation = config.get_validation();
98
99 let token_data =
100 decode::<JwtClaims>(token, &decoding_key, &validation).map_err(|e| {
101 SessionError::InvalidToken(format!("JWT validation failed: {e}"))
102 })?;
103
104 Ok(token_data.claims)
105 }
106 SessionToken::Opaque(_) => Err(Error::Session(SessionError::InvalidToken(
107 "Not a JWT token".to_string(),
108 ))),
109 }
110 }
111
112 pub fn new_jwt_rs256(claims: &JwtClaims, private_key: &[u8]) -> Result<Self, Error> {
114 let config = JwtConfig::new_rs256(private_key.to_vec(), vec![]);
115 Self::new_jwt(claims, &config)
116 }
117
118 pub fn verify_jwt_rs256(&self, public_key: &[u8]) -> Result<JwtClaims, Error> {
120 let config = JwtConfig::new_rs256(vec![], public_key.to_vec());
121 self.verify_jwt(&config)
122 }
123
124 pub fn new_jwt_hs256(claims: &JwtClaims, secret_key: &[u8]) -> Result<Self, Error> {
126 let config = JwtConfig::new_hs256(secret_key.to_vec());
127 Self::new_jwt(claims, &config)
128 }
129
130 pub fn verify_jwt_hs256(&self, secret_key: &[u8]) -> Result<JwtClaims, Error> {
132 let config = JwtConfig::new_hs256(secret_key.to_vec());
133 self.verify_jwt(&config)
134 }
135
136 pub fn into_inner(self) -> String {
138 match self {
139 SessionToken::Opaque(token) => token,
140 SessionToken::Jwt(token) => token,
141 }
142 }
143
144 pub fn as_str(&self) -> &str {
146 match self {
147 SessionToken::Opaque(token) => token,
148 SessionToken::Jwt(token) => token,
149 }
150 }
151}
152
153impl Default for SessionToken {
154 fn default() -> Self {
155 Self::new_random()
156 }
157}
158
159impl From<String> for SessionToken {
160 fn from(s: String) -> Self {
161 Self::new(&s)
162 }
163}
164
165impl From<&str> for SessionToken {
166 fn from(s: &str) -> Self {
167 Self::new(s)
168 }
169}
170
171impl std::fmt::Display for SessionToken {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 match self {
175 SessionToken::Opaque(token) => write!(f, "{token}"),
176 SessionToken::Jwt(token) => write!(f, "{token}"),
177 }
178 }
179}
180
181#[derive(Debug, Serialize, Deserialize)]
183pub struct JwtClaims {
184 pub sub: String,
186 pub iat: i64,
188 pub exp: i64,
190 #[serde(skip_serializing_if = "Option::is_none")]
192 pub iss: Option<String>,
193 #[serde(skip_serializing_if = "Option::is_none")]
195 pub metadata: Option<JwtMetadata>,
196}
197
198#[derive(Debug, Serialize, Deserialize)]
200pub struct JwtMetadata {
201 #[serde(skip_serializing_if = "Option::is_none")]
203 pub user_agent: Option<String>,
204 #[serde(skip_serializing_if = "Option::is_none")]
206 pub ip_address: Option<String>,
207}
208
/// Signing algorithm together with the key material it needs.
///
/// `Debug` is implemented by hand so secret key bytes never reach logs or panic
/// messages, for example through `{:?}` on a `JwtConfig`.
#[derive(Clone)]
pub enum JwtAlgorithm {
    /// RSASSA-PKCS1-v1_5 with SHA-256, using PEM-encoded RSA keys.
    RS256 {
        /// PEM-encoded RSA private key used for signing (may be empty for verify-only use).
        private_key: Vec<u8>,
        /// PEM-encoded RSA public key used for verification (may be empty for sign-only use).
        public_key: Vec<u8>,
    },
    /// HMAC with SHA-256.
    HS256 {
        /// Shared secret used both to sign and to verify.
        secret_key: Vec<u8>,
    },
}

impl std::fmt::Debug for JwtAlgorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JwtAlgorithm::RS256 {
                private_key,
                public_key,
            } => f
                .debug_struct("RS256")
                .field("private_key", &RedactedBytes(private_key.len()))
                // Public keys are not secret; show the PEM text for easier debugging.
                .field("public_key", &String::from_utf8_lossy(public_key))
                .finish(),
            JwtAlgorithm::HS256 { secret_key } => f
                .debug_struct("HS256")
                .field("secret_key", &RedactedBytes(secret_key.len()))
                .finish(),
        }
    }
}

/// `Debug` stand-in that reveals only how many secret bytes were hidden.
struct RedactedBytes(usize);

impl std::fmt::Debug for RedactedBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "<redacted {} bytes>", self.0)
    }
}
225
226#[derive(Debug, Clone)]
228pub struct JwtConfig {
229 pub algorithm: JwtAlgorithm,
231 pub issuer: Option<String>,
233 pub include_metadata: bool,
235}
236
237impl JwtConfig {
238 pub fn new_rs256(private_key: Vec<u8>, public_key: Vec<u8>) -> Self {
240 Self {
241 algorithm: JwtAlgorithm::RS256 {
242 private_key,
243 public_key,
244 },
245 issuer: None,
246 include_metadata: false,
247 }
248 }
249
250 pub fn new_hs256(secret_key: Vec<u8>) -> Self {
252 Self {
253 algorithm: JwtAlgorithm::HS256 { secret_key },
254 issuer: None,
255 include_metadata: false,
256 }
257 }
258
259 pub fn from_rs256_pem_files(
261 private_key_path: impl AsRef<Path>,
262 public_key_path: impl AsRef<Path>,
263 ) -> Result<Self, Error> {
264 use std::fs::read;
265
266 let private_key = read(private_key_path).map_err(|e| {
267 ValidationError::InvalidField(format!("Failed to read private key file: {e}"))
268 })?;
269
270 let public_key = read(public_key_path).map_err(|e| {
271 ValidationError::InvalidField(format!("Failed to read public key file: {e}"))
272 })?;
273
274 Ok(Self::new_rs256(private_key, public_key))
275 }
276
277 #[cfg(test)]
279 pub fn new_random_hs256() -> Self {
280 use rand::TryRngCore;
281
282 let mut secret_key = vec![0u8; 32];
283 rand::rng().try_fill_bytes(&mut secret_key).unwrap();
284 Self::new_hs256(secret_key)
285 }
286
287 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
289 self.issuer = Some(issuer.into());
290 self
291 }
292
293 pub fn with_metadata(mut self, include_metadata: bool) -> Self {
295 self.include_metadata = include_metadata;
296 self
297 }
298
299 pub fn jwt_algorithm(&self) -> Algorithm {
301 match &self.algorithm {
302 JwtAlgorithm::RS256 { .. } => Algorithm::RS256,
303 JwtAlgorithm::HS256 { .. } => Algorithm::HS256,
304 }
305 }
306
307 pub fn get_encoding_key(&self) -> Result<EncodingKey, Error> {
310 match &self.algorithm {
311 JwtAlgorithm::RS256 { private_key, .. } => EncodingKey::from_rsa_pem(private_key)
312 .map_err(|e| {
313 ValidationError::InvalidField(format!("Invalid RSA private key: {e}")).into()
314 }),
315 JwtAlgorithm::HS256 { secret_key } => Ok(EncodingKey::from_secret(secret_key)),
316 }
317 }
318
319 pub fn get_decoding_key(&self) -> Result<DecodingKey, Error> {
321 match &self.algorithm {
322 JwtAlgorithm::RS256 { public_key, .. } => DecodingKey::from_rsa_pem(public_key)
323 .map_err(|e| {
324 ValidationError::InvalidField(format!("Invalid RSA public key: {e}")).into()
325 }),
326 JwtAlgorithm::HS256 { secret_key } => Ok(DecodingKey::from_secret(secret_key)),
327 }
328 }
329
330 pub fn get_validation(&self) -> Validation {
332 Validation::new(self.jwt_algorithm())
333 }
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct Session {
338 pub token: SessionToken,
340
341 pub user_id: UserId,
343
344 pub user_agent: Option<String>,
346
347 pub ip_address: Option<String>,
349
350 pub created_at: DateTime<Utc>,
352
353 pub updated_at: DateTime<Utc>,
355
356 pub expires_at: DateTime<Utc>,
358}
359
360impl Session {
361 pub fn builder() -> SessionBuilder {
362 SessionBuilder::default()
363 }
364
365 pub fn is_expired(&self) -> bool {
366 Utc::now() > self.expires_at
367 }
368
369 pub fn to_jwt_claims(&self, issuer: Option<String>, include_metadata: bool) -> JwtClaims {
371 let metadata = if include_metadata {
372 Some(JwtMetadata {
373 user_agent: self.user_agent.clone(),
374 ip_address: self.ip_address.clone(),
375 })
376 } else {
377 None
378 };
379
380 JwtClaims {
381 sub: self.user_id.to_string(),
382 iat: self.created_at.timestamp(),
383 exp: self.expires_at.timestamp(),
384 iss: issuer,
385 metadata,
386 }
387 }
388
389 pub fn from_jwt_claims(token: SessionToken, claims: &JwtClaims) -> Self {
391 let now = Utc::now();
392 let created_at = DateTime::from_timestamp(claims.iat, 0).unwrap_or(now);
393 let expires_at = DateTime::from_timestamp(claims.exp, 0).unwrap_or(now);
394
395 let (user_agent, ip_address) = if let Some(metadata) = &claims.metadata {
396 (metadata.user_agent.clone(), metadata.ip_address.clone())
397 } else {
398 (None, None)
399 };
400
401 Self {
402 token,
403 user_id: UserId::new(&claims.sub),
404 user_agent,
405 ip_address,
406 created_at,
407 updated_at: now,
408 expires_at,
409 }
410 }
411}
412
413#[derive(Default)]
414pub struct SessionBuilder {
415 token: Option<SessionToken>,
416 user_id: Option<UserId>,
417 user_agent: Option<String>,
418 ip_address: Option<String>,
419 created_at: Option<DateTime<Utc>>,
420 updated_at: Option<DateTime<Utc>>,
421 expires_at: Option<DateTime<Utc>>,
422}
423
424impl SessionBuilder {
425 pub fn token(mut self, token: SessionToken) -> Self {
426 self.token = Some(token);
427 self
428 }
429
430 pub fn user_id(mut self, user_id: UserId) -> Self {
431 self.user_id = Some(user_id);
432 self
433 }
434
435 pub fn user_agent(mut self, user_agent: Option<String>) -> Self {
436 self.user_agent = user_agent;
437 self
438 }
439
440 pub fn ip_address(mut self, ip_address: Option<String>) -> Self {
441 self.ip_address = ip_address;
442 self
443 }
444
445 pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
446 self.created_at = Some(created_at);
447 self
448 }
449
450 pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
451 self.updated_at = Some(updated_at);
452 self
453 }
454
455 pub fn expires_at(mut self, expires_at: DateTime<Utc>) -> Self {
456 self.expires_at = Some(expires_at);
457 self
458 }
459
460 pub fn build(self) -> Result<Session, Error> {
461 let now = Utc::now();
462 Ok(Session {
463 token: self.token.unwrap_or_default(),
464 user_id: self.user_id.ok_or(ValidationError::MissingField(
465 "User ID is required".to_string(),
466 ))?,
467 user_agent: self.user_agent,
468 ip_address: self.ip_address,
469 created_at: self.created_at.unwrap_or(now),
470 updated_at: self.updated_at.unwrap_or(now),
471 expires_at: self.expires_at.unwrap_or(now + Duration::days(30)),
472 })
473 }
474}
475
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    /// Fixed HMAC secret for the HS256 tests. Test-only; never use in production.
    const TEST_HS256_SECRET: &[u8] = b"this_is_a_test_secret_key_for_hs256_jwt_tokens_not_for_prod";

    /// Fixed PEM (PKCS#8) RSA private key for RS256 tests. The round-trip test expects
    /// it to pair with `TEST_RS256_PUBLIC_KEY`. Test-only key material.
    const TEST_RS256_PRIVATE_KEY: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDBsFIR164UGIOZ
R2nT57RQ8AloqAmJXh5KdoKZjHi5uSRALSASp1Dk0tDjiiwqvfWiUItcVqZRqsx4
VuzjpkdoeWvwBoJ91K+DjFEAG7RjbNoaITgY8Ec5QjulpLTh9WDUeqUu4ZxPp9rF
H+S3uJK2sD1K2KOGRVcT0a+rIyXDOXr14J7XGbB5W7j2EvkKXZinzKcdMpsL4NBu
8ArJ8qV6lLBeKB+IbKrV0yUQGFAjTA8eoaSNaHJAZD0kubEdXEprB1SZpvaL3lZM
AcqS6ZATo8IfiXj7H7RSHLf3ORYxQTX4T01gSfmSfgEOdTySdCSuFmDrsjcR2nWe
Ly0QWM4jAgMBAAECggEAG9wzueWhtbn0TVB54aVjCP9grcFPTzHkE9w/GzzFmBq6
+FDlW6QzMm7mkCGYX8o03RT5Lsjh9z5PrKxS5R35CIc/+5Bxew25n1JIIRwFvbAd
y9i6ZnqYFsg2/IkYDFE3jT4E/keCgeyy6bGVkchcBijh8B8ASo3fzCCDGbqeXG8V
9WEhN+xrEwJ/5s3IYY0JSVrL4BzoQT/R9/+IsvUQw9aOECDXpFsRLjoze3JVXzYa
LklDJWe1z3i+4mR/Gwx1GLRL64bJFz0u8zUVSkY5T3SZLr7HGjlrtc/7DIctyx5w
h80nRDohVih69z1AViXSIzYRvJ3tIq8Gp5EvYjieZQKBgQDi1Y5hvn8+KO9+9mPK
lx/P92M1pUfSuALILctFWyFbY7XKYApJud0Nme81ASaNofINpka7tWOEBk8H0lyy
W9uELDYHtVxKU0Ch1Q0joeKb3vcF0wMBMdOiOef+AH4R9ZqF8Mbhc/lwb86vl1BL
1zFQZVpjg0Un57PMKefwl/yS5wKBgQDal8DTj1UaOGjsx667nUE1x6ILdRlHMIe1
lf1VqCkP8ykFMe3iDJE1/rW/ct8uO+ZEf/8nbjeCHcnrtdF14HEPdspCSGvXW87W
65Lsx0O7gdMKZEnN7BarTikpWJU3COcgQHGFsqjZ+07ujQWj8dPrNTd9dsYYFky8
OKtmXJQ/ZQKBgA5G/NBAKkgiUXi/T2an/nObkZ4FyjCELoClCT9TThUvgHi9dMhR
L420m67NZLzTbaXYSml0MFBWCVFntzfuujFmivwPOUDgXpgRDeOpQ9clwIyYTH8d
wMFcPbLqGwVMXS6DCjGUmCWwk+TPdFlhsRPrXTYYRBkP52w5UwT8vAQPAoGAZEMu
4trfggNVvSVp9AwRGQXUQcUYLxsHZDbD2EIlc3do3UUlg4WYJVgLLSEXVTGMUOcU
tZVMSJY5Q7BFvvePZDRsWTK2pDUsDlBHN+u+GYdWsXGGmLktPK3BG4HSD0g6GwT0
DQsBf9pRPgHZEHWfakciiJ2uBuZTlBG6LF1ScjECgYEA4DPQopjh/kS9j5NyUMDA
5Pvz2mppg0NR7RQjDGET3Lh4/lDgfFyJOlsRLF+kUgAOb4s3tPg+5hujTq2FpotK
JFQKh2GE6V1BMi+qJ9ipj0ESBv7rqPYC8ShUSr/SbkRU8jg2tOcvw+7KNtaMk6rv
wl6BPaq7Rv4JOPgimQGP3d4=
-----END PRIVATE KEY-----";

    /// Fixed PEM RSA public key matching `TEST_RS256_PRIVATE_KEY`. Test-only key material.
    const TEST_RS256_PUBLIC_KEY: &[u8] = b"-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwbBSEdeuFBiDmUdp0+e0
UPAJaKgJiV4eSnaCmYx4ubkkQC0gEqdQ5NLQ44osKr31olCLXFamUarMeFbs46ZH
aHlr8AaCfdSvg4xRABu0Y2zaGiE4GPBHOUI7paS04fVg1HqlLuGcT6faxR/kt7iS
trA9StijhkVXE9GvqyMlwzl69eCe1xmweVu49hL5Cl2Yp8ynHTKbC+DQbvAKyfKl
epSwXigfiGyq1dMlEBhQI0wPHqGkjWhyQGQ9JLmxHVxKawdUmab2i95WTAHKkumQ
E6PCH4l4+x+0Uhy39zkWMUE1+E9NYEn5kn4BDnU8knQkrhZg67I3Edp1ni8tEFjO
IwIDAQAB
-----END PUBLIC KEY-----";

    /// RS256 sign/verify round trip. Despite the name, no keys are generated: the
    /// fixed test key pair above is used.
    #[test]
    fn test_random_rs256_key_generation() {
        let config = JwtConfig::new_rs256(
            TEST_RS256_PRIVATE_KEY.to_vec(),
            TEST_RS256_PUBLIC_KEY.to_vec(),
        );

        let user_id = UserId::new_random();
        let session = Session::builder()
            .user_id(user_id.clone())
            .expires_at(Utc::now() + Duration::days(1))
            .build()
            .unwrap();

        // No issuer, no metadata: the minimal claim set.
        let claims = session.to_jwt_claims(None, false);

        // Sign with the private key...
        let token = SessionToken::new_jwt(&claims, &config).unwrap();

        // ...and verify with the public key from the same config.
        let verified_claims = token.verify_jwt(&config).unwrap();

        // The subject must survive the round trip unchanged.
        assert_eq!(verified_claims.sub, user_id.to_string());
    }

    /// `new_random` yields an opaque token whose `Display` output is the raw string.
    #[test]
    fn test_session_token_simple() {
        let id = SessionToken::new_random();
        match &id {
            SessionToken::Opaque(token) => {
                assert_eq!(id.to_string(), token.to_string());
            }
            _ => panic!("Expected simple token"),
        }
    }

    /// A built session with a future expiry is not expired.
    #[test]
    fn test_session_builder() {
        let session = Session::builder()
            .user_id(UserId::new_random())
            .user_agent(Some("test".to_string()))
            .ip_address(Some("127.0.0.1".to_string()))
            .expires_at(Utc::now() + Duration::days(30))
            .build()
            .unwrap();

        assert!(!session.is_expired());
    }

    /// `new_hs256` stores the secret as given and reports `Algorithm::HS256`.
    #[test]
    fn test_jwt_config_hs256() {
        let config = JwtConfig::new_hs256(TEST_HS256_SECRET.to_vec());

        match &config.algorithm {
            JwtAlgorithm::HS256 { secret_key } => {
                assert_eq!(secret_key, &TEST_HS256_SECRET.to_vec());
            }
            _ => panic!("Expected HS256 algorithm"),
        }

        assert_eq!(config.jwt_algorithm(), Algorithm::HS256);
    }

    /// The test-only random HS256 helper produces a 32-byte secret.
    #[test]
    fn test_jwt_config_random_hs256() {
        let config = JwtConfig::new_random_hs256();

        match &config.algorithm {
            JwtAlgorithm::HS256 { secret_key } => {
                assert_eq!(secret_key.len(), 32);
            }
            _ => panic!("Expected HS256 algorithm"),
        }
    }

    /// HS256 round trip with issuer and metadata, via both the config-based API and
    /// the `*_hs256` convenience wrappers.
    #[test]
    fn test_jwt_token_creation_and_verification_hs256() {
        let config = JwtConfig::new_hs256(TEST_HS256_SECRET.to_vec())
            .with_issuer("test-issuer-hs256")
            .with_metadata(true);

        let user_id = UserId::new_random();
        let session = Session::builder()
            .user_id(user_id.clone())
            .user_agent(Some("test-agent-hs256".to_string()))
            .ip_address(Some("127.0.0.2".to_string()))
            .expires_at(Utc::now() + Duration::days(1))
            .build()
            .unwrap();

        // Build claims from the config's own issuer/metadata settings.
        let claims = session.to_jwt_claims(config.issuer.clone(), config.include_metadata);

        let token = SessionToken::new_jwt(&claims, &config).unwrap();

        let verified_claims = token.verify_jwt(&config).unwrap();

        // Every claim, including the embedded metadata, must round-trip intact.
        assert_eq!(verified_claims.sub, user_id.to_string());
        assert_eq!(verified_claims.iss, Some("test-issuer-hs256".to_string()));
        assert!(verified_claims.metadata.is_some());
        let metadata = verified_claims.metadata.unwrap();
        assert_eq!(metadata.user_agent, Some("test-agent-hs256".to_string()));
        assert_eq!(metadata.ip_address, Some("127.0.0.2".to_string()));

        // The raw-secret convenience wrappers build their own issuer-less config.
        let token2 = SessionToken::new_jwt_hs256(&claims, TEST_HS256_SECRET).unwrap();
        let verified_claims2 = token2.verify_jwt_hs256(TEST_HS256_SECRET).unwrap();
        assert_eq!(verified_claims2.sub, user_id.to_string());
    }
}