1use std::collections::HashMap;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::{Mutex, OnceLock};
32use std::time::{Duration as StdDuration, Instant};
33
34use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
35use argon2::Argon2;
36use chrono::{DateTime, Duration, Utc};
37use rand::rngs::OsRng;
38use rand::RngCore;
39use sqlx::Row as _;
40
41use crate::context::Context;
42use crate::error::Error;
43use crate::http::{Request, Response};
44use crate::middleware::Next;
45use crate::orm::Db;
46
/// Name of the cookie that carries the session id; read by [`authenticate`].
pub const SESSION_COOKIE: &str = "rustio_session";

/// Lifetime of a freshly created session, in days (applied in `session::create`).
pub const SESSION_TTL_DAYS: i64 = 7;

/// Number of random bytes behind each generated token. Tokens are hex-encoded,
/// so their string length is twice this value.
const SESSION_TOKEN_BYTES: usize = 32;

/// Role string for administrators; one of the two roles `user::create` accepts.
pub const ROLE_ADMIN: &str = "admin";

/// Role string for regular users; the other role `user::create` accepts, and the
/// same value as the `rustio_users.role` column default.
pub const ROLE_USER: &str = "user";

/// Boxed `Send` future; the return type of the middleware closure built by [`authenticate`].
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
70
/// A row of `rustio_users`: an account that can log in.
///
/// `Debug` is implemented by hand, not derived, so that `password_hash` is never
/// written to logs or panic messages through `{:?}`.
#[derive(Clone, PartialEq)]
pub struct User {
    /// Primary key (`rustio_users.id`).
    pub id: i64,
    /// Email address as stored in `rustio_users.email`.
    pub email: String,
    /// PHC-format password hash (as produced by `password::hash`).
    pub password_hash: String,
    /// Inactive users are refused when a session is resolved to an identity.
    pub is_active: bool,
    /// Role string; interpreted by `User::is_admin`.
    pub role: String,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("id", &self.id)
            .field("email", &self.email)
            // Never print the hash itself.
            .field("password_hash", &"<redacted>")
            .field("is_active", &self.is_active)
            .field("role", &self.role)
            .finish()
    }
}
83
84impl User {
85 pub fn is_admin(&self) -> bool {
95 crate::admin::rbac::Role::from_role_string(&self.role).is_some()
96 }
97}
98
/// Per-request view of an authenticated user. [`authenticate`] inserts it into the
/// request `Context`, and [`identity`] / [`require_auth`] / [`require_admin`] read it.
#[derive(Debug, Clone, PartialEq)]
pub struct Identity {
    /// Primary key of the user (`rustio_users.id`).
    pub user_id: i64,
    /// The user's email address as loaded from the database.
    pub email: String,
    /// Snapshot of `User::is_admin()` taken when the identity was built.
    pub is_admin: bool,
}
107
108impl From<&User> for Identity {
109 fn from(u: &User) -> Self {
110 Self {
111 user_id: u.id,
112 email: u.email.clone(),
113 is_admin: u.is_admin(),
114 }
115 }
116}
117
118#[non_exhaustive]
124#[derive(Debug, Clone, PartialEq)]
125pub struct Session {
126 pub id: String,
129 pub user_id: i64,
130 pub expires_at: DateTime<Utc>,
131 pub csrf_token: String,
136}
137
138pub mod password {
143 use super::*;
148
149 pub fn hash(password: &str) -> Result<String, Error> {
155 if password.is_empty() {
156 return Err(Error::BadRequest("password must not be empty".into()));
157 }
158 let salt = SaltString::generate(&mut OsRng);
159 let hash = Argon2::default()
160 .hash_password(password.as_bytes(), &salt)
161 .map_err(|e| Error::Internal(format!("password hashing failed: {e}")))?;
162 Ok(hash.to_string())
163 }
164
165 pub fn verify(password: &str, stored: &str) -> bool {
171 let Ok(parsed) = PasswordHash::new(stored) else {
175 return false;
176 };
177 Argon2::default()
178 .verify_password(password.as_bytes(), &parsed)
179 .is_ok()
180 }
181}
182
183fn generate_token() -> String {
189 use std::fmt::Write;
190 let mut buf = [0u8; SESSION_TOKEN_BYTES];
191 OsRng.fill_bytes(&mut buf);
192 let mut out = String::with_capacity(SESSION_TOKEN_BYTES * 2);
193 for b in buf {
194 let _ = write!(out, "{b:02x}");
195 }
196 out
197}
198
/// The current session's CSRF token, inserted into the request `Context` by
/// [`authenticate`].
///
/// `Debug` is hand-written so the token value is not leaked into logs.
#[derive(Clone, PartialEq)]
pub struct CsrfToken(pub String);

impl std::fmt::Debug for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("CsrfToken").field(&"<redacted>").finish()
    }
}
213
214pub mod csrf {
215 pub fn generate_token() -> String {
230 super::generate_token()
234 }
235
236 pub fn verify_token(expected: &str, provided: &str) -> bool {
245 if expected.is_empty() || provided.is_empty() {
246 return false;
247 }
248 if expected.len() != provided.len() {
249 return false;
250 }
251 let mut diff: u8 = 0;
252 for (a, b) in expected.bytes().zip(provided.bytes()) {
253 diff |= a ^ b;
254 }
255 diff == 0
256 }
257}
258
259pub mod session {
260 use super::*;
266
267 pub async fn create(db: &Db, user_id: i64) -> Result<Session, Error> {
275 let id = generate_token();
276 let csrf_token = csrf::generate_token();
277 let expires_at = Utc::now() + Duration::days(SESSION_TTL_DAYS);
278 sqlx::query(
279 "INSERT INTO rustio_sessions (id, user_id, expires_at, csrf_token)
280 VALUES (?, ?, ?, ?)",
281 )
282 .bind(&id)
283 .bind(user_id)
284 .bind(expires_at)
285 .bind(&csrf_token)
286 .execute(db.pool())
287 .await?;
288 Ok(Session {
289 id,
290 user_id,
291 expires_at,
292 csrf_token,
293 })
294 }
295
296 pub async fn find_valid(db: &Db, id: &str) -> Result<Option<Session>, Error> {
305 let row = sqlx::query(
306 "SELECT id, user_id, expires_at, csrf_token
307 FROM rustio_sessions WHERE id = ?",
308 )
309 .bind(id)
310 .fetch_optional(db.pool())
311 .await?;
312 let Some(r) = row else {
313 return Ok(None);
314 };
315 let expires_at: DateTime<Utc> = r.try_get("expires_at")?;
316 if expires_at <= Utc::now() {
317 let _ = delete(db, id).await;
318 return Ok(None);
319 }
320 Ok(Some(Session {
321 id: r.try_get("id")?,
322 user_id: r.try_get("user_id")?,
323 expires_at,
324 csrf_token: r.try_get("csrf_token")?,
325 }))
326 }
327
328 pub async fn delete(db: &Db, id: &str) -> Result<(), Error> {
331 sqlx::query("DELETE FROM rustio_sessions WHERE id = ?")
332 .bind(id)
333 .execute(db.pool())
334 .await?;
335 Ok(())
336 }
337
338 pub async fn sweep_expired(db: &Db) -> Result<u64, Error> {
341 let result = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at <= ?")
342 .bind(Utc::now())
343 .execute(db.pool())
344 .await?;
345 Ok(result.rows_affected())
346 }
347}
348
349pub mod user {
354 use super::*;
358
359 pub async fn create(db: &Db, email: &str, password: &str, role: &str) -> Result<User, Error> {
367 let email = normalise_email(email);
368 validate_email(&email)?;
369 if role != ROLE_ADMIN && role != ROLE_USER {
370 return Err(Error::BadRequest(format!(
371 "role must be `{ROLE_ADMIN}` or `{ROLE_USER}`, got `{role}`"
372 )));
373 }
374 let hash = password::hash(password)?;
375 let result = sqlx::query(
376 "INSERT INTO rustio_users (email, password_hash, is_active, role)
377 VALUES (?, ?, 1, ?)",
378 )
379 .bind(&email)
380 .bind(&hash)
381 .bind(role)
382 .execute(db.pool())
383 .await
384 .map_err(|e| match &e {
385 sqlx::Error::Database(de) if de.is_unique_violation() => {
386 Error::BadRequest(format!("a user with email `{email}` already exists"))
387 }
388 _ => Error::from(e),
389 })?;
390 Ok(User {
391 id: result.last_insert_rowid(),
392 email,
393 password_hash: hash,
394 is_active: true,
395 role: role.to_string(),
396 })
397 }
398
399 pub async fn find_by_email(db: &Db, email: &str) -> Result<Option<User>, Error> {
400 let email = normalise_email(email);
401 let row = sqlx::query(
402 "SELECT id, email, password_hash, is_active, role
403 FROM rustio_users WHERE email = ?",
404 )
405 .bind(&email)
406 .fetch_optional(db.pool())
407 .await?;
408 match row {
409 Some(r) => Ok(Some(user_from_row(&r)?)),
410 None => Ok(None),
411 }
412 }
413
414 pub async fn find_by_id(db: &Db, id: i64) -> Result<Option<User>, Error> {
415 let row = sqlx::query(
416 "SELECT id, email, password_hash, is_active, role
417 FROM rustio_users WHERE id = ?",
418 )
419 .bind(id)
420 .fetch_optional(db.pool())
421 .await?;
422 match row {
423 Some(r) => Ok(Some(user_from_row(&r)?)),
424 None => Ok(None),
425 }
426 }
427
428 pub async fn set_password(db: &Db, id: i64, password: &str) -> Result<(), Error> {
436 let hash = password::hash(password)?;
437 let mut tx = db.pool().begin().await?;
438 sqlx::query("UPDATE rustio_users SET password_hash = ? WHERE id = ?")
439 .bind(&hash)
440 .bind(id)
441 .execute(&mut *tx)
442 .await?;
443 sqlx::query("DELETE FROM rustio_sessions WHERE user_id = ?")
444 .bind(id)
445 .execute(&mut *tx)
446 .await?;
447 tx.commit().await?;
448 Ok(())
449 }
450
451 pub async fn set_active(db: &Db, id: i64, is_active: bool) -> Result<(), Error> {
452 sqlx::query("UPDATE rustio_users SET is_active = ? WHERE id = ?")
453 .bind(is_active)
454 .bind(id)
455 .execute(db.pool())
456 .await?;
457 Ok(())
458 }
459
460 pub async fn count(db: &Db) -> Result<i64, Error> {
461 let row = sqlx::query("SELECT COUNT(*) FROM rustio_users")
462 .fetch_one(db.pool())
463 .await?;
464 Ok(row.try_get(0)?)
465 }
466
467 fn user_from_row(r: &sqlx::sqlite::SqliteRow) -> Result<User, Error> {
468 Ok(User {
469 id: r.try_get("id")?,
470 email: r.try_get("email")?,
471 password_hash: r.try_get("password_hash")?,
472 is_active: r.try_get("is_active")?,
473 role: r.try_get("role")?,
474 })
475 }
476}
477
/// Canonical form of an email address for storage and lookup: surrounding
/// whitespace removed, then lower-cased (Unicode-aware).
pub fn normalise_email(email: &str) -> String {
    let without_padding = email.trim();
    str::to_lowercase(without_padding)
}
486
487pub fn validate_email(email: &str) -> Result<(), Error> {
493 if email.is_empty() {
494 return Err(Error::BadRequest("email must not be empty".into()));
495 }
496 let Some((local, domain)) = email.split_once('@') else {
497 return Err(Error::BadRequest(format!(
498 "`{email}` is not a valid email (missing @)"
499 )));
500 };
501 if local.is_empty() || domain.is_empty() || !domain.contains('.') {
502 return Err(Error::BadRequest(format!("`{email}` is not a valid email")));
503 }
504 Ok(())
505}
506
507pub fn dummy_password_hash() -> &'static str {
521 static DUMMY: OnceLock<String> = OnceLock::new();
522 DUMMY.get_or_init(|| {
523 password::hash("timing-attack-filler-not-a-real-password").expect("dummy hash must succeed")
524 })
525}
526
/// Per-key failure bookkeeping for [`LoginRateLimiter`].
struct FailureEntry {
    /// Failures recorded since the entry was created or last went stale.
    count: u32,
    /// End of the current window, refreshed on every recorded failure.
    /// - While `count < max_failures`, this is when the recorded failures go stale.
    /// - Once `count >= max_failures`, it is when the lockout ends.
    locked_until: Instant,
}

/// In-memory brute-force limiter for login attempts, keyed by an arbitrary string
/// (see [`LoginRateLimiter::compose_key`]).
///
/// Locking and unlocking:
/// - A key is locked once `max_failures` failures have been recorded with no gap
///   longer than `lockout` between consecutive failures.
/// - While locked, [`check`](Self::check) returns the remaining lockout.
/// - The lock lifts `lockout` after the most recent failure.
/// - [`record_success`](Self::record_success) clears the key immediately.
///
/// Memory is bounded: once the table reaches an internal threshold, stale entries
/// are swept before a new key is inserted. An attacker spraying distinct keys
/// therefore only keeps memory alive for entries still inside their window.
///
/// State is process-local and lost on restart.
pub struct LoginRateLimiter {
    failures: Mutex<HashMap<String, FailureEntry>>,
    max_failures: u32,
    lockout: StdDuration,
}

impl LoginRateLimiter {
    /// Default number of failures that triggers a lockout.
    pub const MAX_FAILURES: u32 = 5;
    /// Default lockout duration; also the inactivity window after which failures decay.
    pub const LOCKOUT: StdDuration = StdDuration::from_secs(60);
    /// Table size at which `record_failure` sweeps out stale entries before adding a new key.
    const PRUNE_THRESHOLD: usize = 1024;

    /// Limiter with [`Self::MAX_FAILURES`] and [`Self::LOCKOUT`].
    pub fn new() -> Self {
        Self::with_params(Self::MAX_FAILURES, Self::LOCKOUT)
    }

    /// Limiter with a custom threshold and lockout/decay window.
    pub fn with_params(max_failures: u32, lockout: StdDuration) -> Self {
        Self {
            failures: Mutex::new(HashMap::new()),
            max_failures,
            lockout,
        }
    }

    /// Process-wide shared instance with default parameters, created on first use.
    pub fn global() -> &'static Self {
        static INSTANCE: OnceLock<LoginRateLimiter> = OnceLock::new();
        INSTANCE.get_or_init(LoginRateLimiter::new)
    }

    /// Locks the table, recovering from poisoning.
    ///
    /// Every critical section is a single map insert/update/remove/retain, so the
    /// map stays usable. A poisoned global limiter would otherwise make every later
    /// login attempt panic.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, FailureEntry>> {
        self.failures
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns `Err(remaining)` while `key` is locked out, otherwise `Ok(())`.
    ///
    /// An expired lockout entry is removed here, so the key starts fresh.
    pub fn check(&self, key: &str) -> Result<(), StdDuration> {
        let mut map = self.lock();
        match map.get(key) {
            Some(entry) if entry.count >= self.max_failures => {
                let now = Instant::now();
                if entry.locked_until > now {
                    Err(entry.locked_until - now)
                } else {
                    map.remove(key);
                    Ok(())
                }
            }
            _ => Ok(()),
        }
    }

    /// Records one failed login for `key`.
    ///
    /// - Failures older than the window are forgotten first.
    /// - Reaching `max_failures` starts (or extends) a lockout of `lockout`.
    pub fn record_failure(&self, key: &str) {
        let now = Instant::now();
        // Computed before locking so an overflowing `lockout` panics without
        // poisoning the mutex.
        let window_end = now + self.lockout;
        let mut map = self.lock();

        // Only sweep when a *new* key would grow an already large table. Updating an
        // existing key never grows the table, so it does not pay the O(n) cost.
        if map.len() >= Self::PRUNE_THRESHOLD && !map.contains_key(key) {
            map.retain(|_, e| e.locked_until > now);
        }

        let entry = map.entry(key.to_string()).or_insert(FailureEntry {
            count: 0,
            locked_until: now,
        });
        if entry.locked_until <= now {
            // The previous window (or lockout) has elapsed: earlier failures no longer count.
            entry.count = 0;
        }
        entry.count = entry.count.saturating_add(1);
        entry.locked_until = window_end;
    }

    /// Clears all failure state for `key` (call after a successful login).
    pub fn record_success(&self, key: &str) {
        self.lock().remove(key);
    }

    /// Builds a limiter key from an email and an optional client IP:
    /// `email:<email>` or `email:<email>|ip:<ip>`.
    ///
    /// NOTE(review): the email is embedded verbatim. Callers presumably pass a
    /// normalised/validated email; otherwise an email containing `|ip:` could collide
    /// with a different (email, ip) pair.
    pub fn compose_key(email: &str, ip: Option<&str>) -> String {
        match ip {
            Some(ip) => format!("email:{email}|ip:{ip}"),
            None => format!("email:{email}"),
        }
    }
}

impl Default for LoginRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}
657
658pub async fn resolve_identity_with_session(
670 db: &Db,
671 token: Option<&str>,
672) -> Option<(Identity, Session)> {
673 let token = token?;
674 let sess = session::find_valid(db, token).await.ok().flatten()?;
675 let user = user::find_by_id(db, sess.user_id).await.ok().flatten()?;
676 if !user.is_active {
677 return None;
678 }
679 Some((Identity::from(&user), sess))
680}
681
682pub async fn resolve_identity(db: &Db, token: Option<&str>) -> Option<Identity> {
690 resolve_identity_with_session(db, token)
691 .await
692 .map(|(identity, _)| identity)
693}
694
695pub fn authenticate(
709 db: Db,
710) -> impl Fn(Request, Next) -> BoxFuture<Result<Response, Error>> + Send + Sync + Clone + 'static {
711 move |mut req, next| {
712 let db = db.clone();
713 Box::pin(async move {
714 let token = req.cookie(SESSION_COOKIE);
715 if let Some((identity, session)) =
716 resolve_identity_with_session(&db, token.as_deref()).await
717 {
718 req.ctx_mut().insert(CsrfToken(session.csrf_token));
723 req.ctx_mut().insert(identity);
724 }
725 next.run(req).await
726 })
727 }
728}
729
730pub async fn ensure_core_tables(db: &Db) -> Result<(), Error> {
747 db.execute(
748 "CREATE TABLE IF NOT EXISTS rustio_users (
749 id INTEGER PRIMARY KEY AUTOINCREMENT,
750 email TEXT NOT NULL UNIQUE,
751 password_hash TEXT NOT NULL,
752 is_active INTEGER NOT NULL DEFAULT 1,
753 role TEXT NOT NULL DEFAULT 'user',
754 created_at TEXT NOT NULL DEFAULT (datetime('now'))
755 )",
756 )
757 .await?;
758 db.execute(
759 "CREATE TABLE IF NOT EXISTS rustio_sessions (
760 id TEXT PRIMARY KEY,
761 user_id INTEGER NOT NULL,
762 expires_at TEXT NOT NULL,
763 csrf_token TEXT NOT NULL DEFAULT '',
764 created_at TEXT NOT NULL DEFAULT (datetime('now')),
765 FOREIGN KEY (user_id) REFERENCES rustio_users(id) ON DELETE CASCADE
766 )",
767 )
768 .await?;
769
770 let cols: Vec<String> =
775 sqlx::query_scalar::<_, String>("SELECT name FROM pragma_table_info('rustio_sessions')")
776 .fetch_all(db.pool())
777 .await?;
778 if !cols.iter().any(|c| c == "csrf_token") {
779 db.execute("ALTER TABLE rustio_sessions ADD COLUMN csrf_token TEXT NOT NULL DEFAULT ''")
780 .await?;
781 }
782
783 db.execute(
789 "CREATE TABLE IF NOT EXISTS rustio_admin_actions (
790 id INTEGER PRIMARY KEY AUTOINCREMENT,
791 user_id INTEGER NOT NULL,
792 action_type TEXT NOT NULL,
793 model_name TEXT NOT NULL,
794 object_id INTEGER NOT NULL,
795 timestamp TEXT NOT NULL,
796 ip_address TEXT NULL,
797 summary TEXT NOT NULL,
798 FOREIGN KEY (user_id) REFERENCES rustio_users(id) ON DELETE CASCADE
799 )",
800 )
801 .await?;
802 db.execute(
803 "CREATE INDEX IF NOT EXISTS idx_rustio_admin_actions_model_object
804 ON rustio_admin_actions(model_name, object_id)",
805 )
806 .await?;
807 db.execute(
808 "CREATE INDEX IF NOT EXISTS idx_rustio_admin_actions_timestamp
809 ON rustio_admin_actions(timestamp DESC)",
810 )
811 .await?;
812 Ok(())
813}
814
/// Returns `true` when the `RUSTIO_ENV` environment variable equals `production`
/// or `prod`, compared ASCII case-insensitively.
///
/// An unset variable, or one that is not valid Unicode, returns `false`.
pub fn in_production() -> bool {
    match std::env::var("RUSTIO_ENV") {
        Ok(raw) => matches!(raw.to_ascii_lowercase().as_str(), "production" | "prod"),
        Err(_) => false,
    }
}
832
833pub fn bearer_token(req: &Request) -> Option<&str> {
837 req.headers()
838 .get("authorization")
839 .and_then(|v| v.to_str().ok())
840 .and_then(|s| s.strip_prefix("Bearer "))
841}
842
843pub fn identity(ctx: &Context) -> Option<&Identity> {
844 ctx.get::<Identity>()
845}
846
847pub fn require_auth(ctx: &Context) -> Result<&Identity, Error> {
848 identity(ctx).ok_or(Error::Unauthorized)
849}
850
851pub fn require_admin(ctx: &Context) -> Result<&Identity, Error> {
852 let id = require_auth(ctx)?;
853 if !id.is_admin {
854 return Err(Error::Forbidden);
855 }
856 Ok(id)
857}
858
// Unit tests for the auth module. Database-backed tests run against an in-memory
// `Db` prepared by `setup()`.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: an identity flagged as admin.
    fn admin_identity() -> Identity {
        Identity {
            user_id: 1,
            email: "admin@example.com".into(),
            is_admin: true,
        }
    }

    // Fixture: an identity without admin rights.
    fn user_identity() -> Identity {
        Identity {
            user_id: 2,
            email: "user@example.com".into(),
            is_admin: false,
        }
    }

    // ---- Context accessors (identity / require_auth / require_admin) ----

    #[test]
    fn identity_returns_none_when_absent() {
        let ctx = Context::new();
        assert!(identity(&ctx).is_none());
    }

    #[test]
    fn identity_returns_reference_when_attached() {
        let mut ctx = Context::new();
        ctx.insert(user_identity());
        assert_eq!(
            identity(&ctx).map(|i| i.email.as_str()),
            Some("user@example.com")
        );
    }

    #[test]
    fn require_auth_missing_returns_unauthorized() {
        let ctx = Context::new();
        assert!(matches!(require_auth(&ctx), Err(Error::Unauthorized)));
    }

    #[test]
    fn require_admin_non_admin_returns_forbidden() {
        let mut ctx = Context::new();
        ctx.insert(user_identity());
        assert!(matches!(require_admin(&ctx), Err(Error::Forbidden)));
    }

    #[test]
    fn require_admin_admin_returns_identity() {
        let mut ctx = Context::new();
        ctx.insert(admin_identity());
        let id = require_admin(&ctx).unwrap();
        assert!(id.is_admin);
    }

    // ---- Password hashing ----

    #[test]
    fn hash_then_verify_succeeds() {
        let h = password::hash("correct horse battery staple").unwrap();
        assert!(password::verify("correct horse battery staple", &h));
    }

    #[test]
    fn verify_wrong_password_fails() {
        let h = password::hash("real").unwrap();
        assert!(!password::verify("fake", &h));
    }

    // Malformed stored hashes must yield `false`, never a panic.
    #[test]
    fn verify_invalid_hash_returns_false_without_panic() {
        assert!(!password::verify("anything", ""));
        assert!(!password::verify("anything", "not a phc string"));
        assert!(!password::verify("anything", "$argon2id$v=19$m=1"));
    }

    #[test]
    fn hash_rejects_empty_password() {
        assert!(matches!(password::hash(""), Err(Error::BadRequest(_))));
    }

    // Each hash uses a fresh random salt, yet both still verify.
    #[test]
    fn hash_is_salted_so_same_input_produces_different_hash() {
        let a = password::hash("same").unwrap();
        let b = password::hash("same").unwrap();
        assert_ne!(a, b, "identical inputs must produce different hashes");
        assert!(password::verify("same", &a));
        assert!(password::verify("same", &b));
    }

    // ---- Email helpers ----

    #[test]
    fn normalise_email_trims_and_lowercases() {
        assert_eq!(
            normalise_email(" Alice@EXAMPLE.com "),
            "alice@example.com"
        );
    }

    #[test]
    fn validate_email_accepts_reasonable_forms() {
        assert!(validate_email("a@b.co").is_ok());
        assert!(validate_email("alice.smith+tag@example.co.uk").is_ok());
    }

    #[test]
    fn validate_email_rejects_bad_forms() {
        assert!(validate_email("").is_err());
        assert!(validate_email("no-at-sign").is_err());
        assert!(validate_email("@no-local").is_err());
        assert!(validate_email("no-domain@").is_err());
        assert!(validate_email("no-dot@localhost").is_err());
    }

    // ---- Token generation ----

    #[test]
    fn generate_token_is_stable_length_and_hex() {
        let t = generate_token();
        assert_eq!(t.len(), SESSION_TOKEN_BYTES * 2);
        assert!(t.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn generate_token_does_not_repeat() {
        let a = generate_token();
        let b = generate_token();
        assert_ne!(a, b);
    }

    // Fresh in-memory database with the core tables created.
    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        ensure_core_tables(&db).await.unwrap();
        db
    }

    // ---- Users ----

    // Creation normalises the email; lookup by a differently-cased email finds the same row.
    #[tokio::test]
    async fn user_create_round_trips() {
        let db = setup().await;
        let u = user::create(&db, "Admin@Example.com", "hunter2", ROLE_ADMIN)
            .await
            .unwrap();
        assert_eq!(u.email, "admin@example.com");
        assert!(u.is_admin());
        assert!(u.is_active);

        let lookup = user::find_by_email(&db, "ADMIN@example.com")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.id, u.id);
        assert!(password::verify("hunter2", &lookup.password_hash));
    }

    // `User::is_admin` accepts any role string the RBAC module recognises,
    // not only "admin".
    #[test]
    fn is_admin_recognises_0_10_role_strings() {
        assert!(user_with_role("admin").is_admin());
        assert!(!user_with_role("user").is_admin());
        assert!(!user_with_role("").is_admin());
        assert!(user_with_role("superadmin").is_admin());
        assert!(user_with_role("restricted_admin").is_admin());
        assert!(user_with_role("editor").is_admin());
        assert!(user_with_role("viewer").is_admin());
        assert!(!user_with_role("nobody").is_admin());
    }

    // In-memory `User` with the given role (not persisted).
    fn user_with_role(role: &str) -> User {
        User {
            id: 1,
            email: "t@example.com".into(),
            password_hash: "x".into(),
            is_active: true,
            role: role.into(),
        }
    }

    #[tokio::test]
    async fn user_create_rejects_duplicate_email() {
        let db = setup().await;
        user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let err = user::create(&db, "a@b.co", "pw2", ROLE_USER).await;
        assert!(matches!(err, Err(Error::BadRequest(_))));
    }

    #[tokio::test]
    async fn user_create_rejects_unknown_role() {
        let db = setup().await;
        let err = user::create(&db, "a@b.co", "pw", "emperor").await;
        assert!(matches!(err, Err(Error::BadRequest(_))));
    }

    #[tokio::test]
    async fn set_password_changes_verifiable_hash() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "old", ROLE_USER).await.unwrap();
        user::set_password(&db, u.id, "new").await.unwrap();
        let reloaded = user::find_by_id(&db, u.id).await.unwrap().unwrap();
        assert!(!password::verify("old", &reloaded.password_hash));
        assert!(password::verify("new", &reloaded.password_hash));
    }

    #[tokio::test]
    async fn set_active_toggles_flag() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        user::set_active(&db, u.id, false).await.unwrap();
        let reloaded = user::find_by_id(&db, u.id).await.unwrap().unwrap();
        assert!(!reloaded.is_active);
    }

    // ---- Sessions ----

    #[tokio::test]
    async fn session_create_and_find_returns_live_session() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let s = session::create(&db, u.id).await.unwrap();
        let found = session::find_valid(&db, &s.id).await.unwrap().unwrap();
        assert_eq!(found.user_id, u.id);
        assert_eq!(found.id, s.id);
        assert!(found.expires_at > Utc::now());
    }

    #[tokio::test]
    async fn session_lookup_rejects_unknown_token() {
        let db = setup().await;
        let out = session::find_valid(&db, "deadbeef").await.unwrap();
        assert!(out.is_none());
    }

    // Inserts a row directly with an expiry in the past (csrf_token takes its default).
    #[tokio::test]
    async fn session_lookup_rejects_expired_session() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let token = generate_token();
        sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
            .bind(&token)
            .bind(u.id)
            .bind(Utc::now() - Duration::seconds(1))
            .execute(db.pool())
            .await
            .unwrap();

        let out = session::find_valid(&db, &token).await.unwrap();
        assert!(out.is_none(), "expired sessions must not validate");
    }

    #[tokio::test]
    async fn session_delete_invalidates_lookup() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let s = session::create(&db, u.id).await.unwrap();
        session::delete(&db, &s.id).await.unwrap();
        assert!(session::find_valid(&db, &s.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn sweep_expired_removes_only_expired() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let live = session::create(&db, u.id).await.unwrap();
        let dead_token = generate_token();
        sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
            .bind(&dead_token)
            .bind(u.id)
            .bind(Utc::now() - Duration::seconds(1))
            .execute(db.pool())
            .await
            .unwrap();

        let removed = session::sweep_expired(&db).await.unwrap();
        assert_eq!(removed, 1);
        assert!(session::find_valid(&db, &live.id).await.unwrap().is_some());
        assert!(session::find_valid(&db, &dead_token)
            .await
            .unwrap()
            .is_none());
    }

    // Relies on SQLite foreign-key enforcement being enabled on the `Db` connection
    // (see the assertion message).
    #[tokio::test]
    async fn deleting_user_cascades_to_sessions() {
        let db = setup().await;
        let u = user::create(&db, "a@b.co", "pw", ROLE_USER).await.unwrap();
        let s = session::create(&db, u.id).await.unwrap();
        assert!(session::find_valid(&db, &s.id).await.unwrap().is_some());

        sqlx::query("DELETE FROM rustio_users WHERE id = ?")
            .bind(u.id)
            .execute(db.pool())
            .await
            .unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(
            count, 0,
            "FK cascade must have removed the orphan session; is PRAGMA foreign_keys on?"
        );
    }

    // `setup()` already ran the schema once; two more runs must not fail.
    #[tokio::test]
    async fn ensure_core_tables_is_idempotent() {
        let db = setup().await;
        ensure_core_tables(&db).await.unwrap();
        ensure_core_tables(&db).await.unwrap();
        assert_eq!(user::count(&db).await.unwrap(), 0);
    }

    // Persists one user with a fixed email/password and the given role.
    async fn seeded_user(db: &Db, role: &str) -> User {
        user::create(db, "u@example.com", "pw", role).await.unwrap()
    }

    // ---- Identity resolution ----

    #[tokio::test]
    async fn resolve_identity_none_cookie_returns_none() {
        let db = setup().await;
        assert!(resolve_identity(&db, None).await.is_none());
    }

    #[tokio::test]
    async fn resolve_identity_unknown_token_returns_none() {
        let db = setup().await;
        assert!(resolve_identity(&db, Some("not-a-real-token"))
            .await
            .is_none());
    }

    #[tokio::test]
    async fn resolve_identity_expired_session_returns_none() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let token = generate_token();
        sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
            .bind(&token)
            .bind(u.id)
            .bind(Utc::now() - Duration::seconds(1))
            .execute(db.pool())
            .await
            .unwrap();
        assert!(resolve_identity(&db, Some(&token)).await.is_none());
    }

    #[tokio::test]
    async fn resolve_identity_inactive_user_returns_none() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        user::set_active(&db, u.id, false).await.unwrap();
        let s = session::create(&db, u.id).await.unwrap();
        assert!(
            resolve_identity(&db, Some(&s.id)).await.is_none(),
            "inactive users must not resolve to an Identity"
        );
    }

    // The user row is removed after the session exists; resolution must fail.
    #[tokio::test]
    async fn resolve_identity_deleted_user_returns_none() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s = session::create(&db, u.id).await.unwrap();
        sqlx::query("DELETE FROM rustio_users WHERE id = ?")
            .bind(u.id)
            .execute(db.pool())
            .await
            .unwrap();
        assert!(resolve_identity(&db, Some(&s.id)).await.is_none());
    }

    #[tokio::test]
    async fn resolve_identity_valid_admin_session_attaches_admin_identity() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_ADMIN).await;
        let s = session::create(&db, u.id).await.unwrap();
        let id = resolve_identity(&db, Some(&s.id)).await.unwrap();
        assert_eq!(id.user_id, u.id);
        assert!(id.is_admin);
    }

    #[tokio::test]
    async fn resolve_identity_valid_user_session_attaches_non_admin_identity() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s = session::create(&db, u.id).await.unwrap();
        let id = resolve_identity(&db, Some(&s.id)).await.unwrap();
        assert_eq!(id.user_id, u.id);
        assert!(!id.is_admin);
    }

    // `user::set_password` must delete every session belonging to the user.
    #[tokio::test]
    async fn changing_password_invalidates_all_user_sessions() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s1 = session::create(&db, u.id).await.unwrap();
        let s2 = session::create(&db, u.id).await.unwrap();
        assert!(session::find_valid(&db, &s1.id).await.unwrap().is_some());
        assert!(session::find_valid(&db, &s2.id).await.unwrap().is_some());

        user::set_password(&db, u.id, "new password").await.unwrap();

        let remaining: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions WHERE user_id = ?")
                .bind(u.id)
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert_eq!(
            remaining, 0,
            "password change must wipe every live session for the user"
        );
        assert!(session::find_valid(&db, &s1.id).await.unwrap().is_none());
        assert!(session::find_valid(&db, &s2.id).await.unwrap().is_none());
    }

    // `find_valid` deletes an expired row it encounters, not just ignores it.
    #[tokio::test]
    async fn find_valid_cleans_up_expired_row_inline() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let token = generate_token();
        sqlx::query("INSERT INTO rustio_sessions (id, user_id, expires_at) VALUES (?, ?, ?)")
            .bind(&token)
            .bind(u.id)
            .bind(Utc::now() - Duration::seconds(1))
            .execute(db.pool())
            .await
            .unwrap();

        assert!(session::find_valid(&db, &token).await.unwrap().is_none());

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM rustio_sessions WHERE id = ?")
            .bind(&token)
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(count, 0, "find_valid must purge expired rows inline");
    }

    // ---- Login rate limiter ----

    // Below the threshold (3), checks still pass.
    #[test]
    fn rate_limiter_allows_up_to_threshold() {
        let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
        assert!(limiter.check("alice@example.com").is_ok());
        limiter.record_failure("alice@example.com");
        limiter.record_failure("alice@example.com");
        assert!(limiter.check("alice@example.com").is_ok());
    }

    #[test]
    fn rate_limiter_locks_out_at_threshold() {
        let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
        for _ in 0..3 {
            limiter.record_failure("alice@example.com");
        }
        let result = limiter.check("alice@example.com");
        assert!(result.is_err(), "3rd failure must trip the lockout");
        let remaining = result.unwrap_err();
        assert!(remaining > StdDuration::ZERO);
        assert!(remaining <= StdDuration::from_secs(60));
    }

    #[test]
    fn rate_limiter_resets_on_successful_login() {
        let limiter = LoginRateLimiter::with_params(3, StdDuration::from_secs(60));
        for _ in 0..3 {
            limiter.record_failure("alice@example.com");
        }
        assert!(limiter.check("alice@example.com").is_err());

        limiter.record_success("alice@example.com");
        assert!(
            limiter.check("alice@example.com").is_ok(),
            "a successful login must clear the lockout counter"
        );
    }

    // Uses a short (50ms) lockout and an 80ms sleep to observe expiry.
    #[tokio::test]
    async fn rate_limiter_lockout_expires_after_duration() {
        let limiter = LoginRateLimiter::with_params(3, StdDuration::from_millis(50));
        for _ in 0..3 {
            limiter.record_failure("bob@example.com");
        }
        assert!(limiter.check("bob@example.com").is_err());

        tokio::time::sleep(StdDuration::from_millis(80)).await;

        assert!(
            limiter.check("bob@example.com").is_ok(),
            "lockout must lift after the configured duration"
        );
    }

    // Pins the exact key format produced by `compose_key`.
    #[test]
    fn compose_key_email_only_is_stable() {
        let k = LoginRateLimiter::compose_key("alice@example.com", None);
        assert_eq!(k, "email:alice@example.com");
    }

    #[test]
    fn compose_key_with_ip_is_distinct_from_email_only() {
        let a = LoginRateLimiter::compose_key("alice@example.com", None);
        let b = LoginRateLimiter::compose_key("alice@example.com", Some("203.0.113.5"));
        assert_ne!(a, b);
        assert_eq!(b, "email:alice@example.com|ip:203.0.113.5");
    }

    // Same email from two IPs must be counted separately.
    #[test]
    fn compose_key_distinct_ips_produce_distinct_keys() {
        let a = LoginRateLimiter::compose_key("a@b.co", Some("10.0.0.1"));
        let b = LoginRateLimiter::compose_key("a@b.co", Some("10.0.0.2"));
        assert_ne!(a, b);
    }

    // ---- CSRF ----

    // 32 random bytes hex-encoded give 64 characters.
    #[test]
    fn csrf_generate_returns_hex_of_expected_length() {
        let t = csrf::generate_token();
        assert_eq!(t.len(), 64);
        assert!(t.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn csrf_generate_produces_unique_tokens() {
        let a = csrf::generate_token();
        let b = csrf::generate_token();
        assert_ne!(a, b);
    }

    #[test]
    fn csrf_verify_matching_returns_true() {
        let t = csrf::generate_token();
        assert!(csrf::verify_token(&t, &t));
    }

    #[test]
    fn csrf_verify_mismatched_returns_false() {
        let t = csrf::generate_token();
        let other = csrf::generate_token();
        assert!(!csrf::verify_token(&t, &other));
    }

    #[test]
    fn csrf_verify_empty_either_side_returns_false() {
        let t = csrf::generate_token();
        assert!(!csrf::verify_token("", &t));
        assert!(!csrf::verify_token(&t, ""));
        assert!(!csrf::verify_token("", ""));
    }

    // A shared prefix must not be enough: lengths have to match.
    #[test]
    fn csrf_verify_rejects_different_lengths() {
        assert!(!csrf::verify_token("abc", "abcd"));
        assert!(!csrf::verify_token("abcd", "abc"));
    }

    // Only the last character differs ('f' -> '0'); the lengths are equal.
    #[test]
    fn csrf_verify_rejects_single_byte_difference() {
        let a = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef";
        let mut b = String::from(a);
        b.pop();
        b.push('0');
        assert!(!csrf::verify_token(a, &b));
    }

    #[tokio::test]
    async fn session_create_generates_unique_csrf_per_session() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s1 = session::create(&db, u.id).await.unwrap();
        let s2 = session::create(&db, u.id).await.unwrap();
        assert_eq!(s1.csrf_token.len(), 64);
        assert_ne!(
            s1.csrf_token, s2.csrf_token,
            "each session must get an independent CSRF token"
        );
        assert_ne!(
            s1.csrf_token, s1.id,
            "session id and csrf token must not be the same value"
        );
    }

    #[tokio::test]
    async fn session_find_valid_returns_csrf_token() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s = session::create(&db, u.id).await.unwrap();
        let found = session::find_valid(&db, &s.id).await.unwrap().unwrap();
        assert_eq!(found.csrf_token, s.csrf_token);
    }

    // The session half of the tuple carries the stored CSRF token.
    #[tokio::test]
    async fn resolve_identity_with_session_exposes_csrf() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_ADMIN).await;
        let s = session::create(&db, u.id).await.unwrap();
        let (id, sess) = resolve_identity_with_session(&db, Some(&s.id))
            .await
            .unwrap();
        assert_eq!(id.user_id, u.id);
        assert_eq!(sess.csrf_token, s.csrf_token);
    }

    #[test]
    fn rate_limiter_tracks_keys_independently() {
        let limiter = LoginRateLimiter::with_params(2, StdDuration::from_secs(60));
        limiter.record_failure("alice@example.com");
        limiter.record_failure("alice@example.com");
        assert!(limiter.check("alice@example.com").is_err());
        assert!(limiter.check("bob@example.com").is_ok());
    }

    // ---- Dummy hash for timing equalisation ----

    // Both calls return the same cached allocation (pointer equality).
    #[test]
    fn dummy_password_hash_is_stable_across_calls() {
        let a = dummy_password_hash();
        let b = dummy_password_hash();
        assert!(std::ptr::eq(a, b));
    }

    #[test]
    fn dummy_password_hash_is_a_valid_phc_string() {
        assert!(PasswordHash::new(dummy_password_hash()).is_ok());
    }

    #[test]
    fn verify_against_dummy_hash_rejects_arbitrary_inputs() {
        assert!(!password::verify("", dummy_password_hash()));
        assert!(!password::verify("wrong password", dummy_password_hash()));
        assert!(!password::verify("admin", dummy_password_hash()));
    }

    // Logging out is modelled as deleting the session row.
    #[tokio::test]
    async fn logout_deletes_session_so_later_requests_are_anonymous() {
        let db = setup().await;
        let u = seeded_user(&db, ROLE_USER).await;
        let s = session::create(&db, u.id).await.unwrap();
        assert!(resolve_identity(&db, Some(&s.id)).await.is_some());

        session::delete(&db, &s.id).await.unwrap();
        assert!(
            resolve_identity(&db, Some(&s.id)).await.is_none(),
            "deleted session must not resolve"
        );
    }
}
1569}