solid_pod_rs_idp/
invites.rs1use std::collections::HashMap;
16use std::time::Duration;
17
18use async_trait::async_trait;
19use chrono::{DateTime, Utc};
20use parking_lot::RwLock;
21use rand::rngs::OsRng;
22use rand::RngCore;
23use thiserror::Error;
24
25#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct Invite {
28 pub token: String,
30 pub max_uses: Option<u32>,
33 pub expires_at: Option<DateTime<Utc>>,
35}
36
37#[derive(Debug, Error)]
39pub enum InviteStoreError {
40 #[error("backend: {0}")]
42 Backend(String),
43}
44
45#[async_trait]
50pub trait InviteStore: Send + Sync + 'static {
51 async fn insert(&self, invite: Invite) -> Result<(), InviteStoreError>;
56
57 async fn get(&self, token: &str) -> Result<Option<Invite>, InviteStoreError>;
59}
60
61#[derive(Default)]
63pub struct InMemoryInviteStore {
64 inner: RwLock<HashMap<String, Invite>>,
65}
66
67impl InMemoryInviteStore {
68 pub fn new() -> Self {
70 Self::default()
71 }
72
73 pub fn snapshot(&self) -> Vec<Invite> {
77 self.inner.read().values().cloned().collect()
78 }
79}
80
81#[async_trait]
82impl InviteStore for InMemoryInviteStore {
83 async fn insert(&self, invite: Invite) -> Result<(), InviteStoreError> {
84 self.inner
85 .write()
86 .entry(invite.token.clone())
87 .or_insert(invite);
88 Ok(())
89 }
90
91 async fn get(&self, token: &str) -> Result<Option<Invite>, InviteStoreError> {
92 Ok(self.inner.read().get(token).cloned())
93 }
94}
95
96pub fn mint_token() -> String {
99 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
100 let mut buf = [0u8; 32];
101 OsRng.fill_bytes(&mut buf);
102 URL_SAFE_NO_PAD.encode(buf)
103}
104
/// Parses a human-friendly duration string.
///
/// Surrounding whitespace is ignored. Accepted forms:
///
/// * a bare unsigned integer, interpreted as seconds (`"60"`);
/// * an unsigned integer immediately followed by one unit suffix:
///   `s` (seconds), `m` (minutes), `h` (hours), `d` (days) or `w` (weeks).
///
/// Multiplying by the unit saturates at `u64::MAX` seconds instead of
/// overflowing.
///
/// # Errors
///
/// Returns a descriptive message when the input is empty, is all digits but
/// does not fit in a `u64`, has no valid number before the unit, or uses an
/// unknown unit.
pub fn parse_duration(input: &str) -> Result<Duration, String> {
    let text = input.trim();
    if text.is_empty() {
        return Err("empty duration".to_string());
    }

    // Fast path: a plain integer is already a number of seconds.
    if let Ok(secs) = text.parse::<u64>() {
        return Ok(Duration::from_secs(secs));
    }

    // The unit starts at the first character that is not an ASCII digit.
    let unit_start = match text.find(|c: char| !c.is_ascii_digit()) {
        Some(idx) => idx,
        None => return Err(format!("no unit suffix in {text:?}")),
    };
    let (digits, unit) = text.split_at(unit_start);

    // The number is validated before the unit, so input such as "abc" is
    // reported as an invalid number rather than an unknown unit.
    let count = digits
        .parse::<u64>()
        .map_err(|e| format!("invalid number {digits:?}: {e}"))?;

    let seconds_per_unit: u64 = match unit {
        "s" => 1,
        "m" => 60,
        "h" => 3_600,
        "d" => 86_400,
        "w" => 604_800,
        other => return Err(format!("unknown duration unit {other:?}")),
    };

    Ok(Duration::from_secs(count.saturating_mul(seconds_per_unit)))
}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted invite can be read back unchanged; an unknown token
    /// yields `None`.
    #[tokio::test]
    async fn inmemory_store_round_trips() {
        let store = InMemoryInviteStore::new();
        let invite = Invite {
            token: "tok-1".into(),
            max_uses: Some(3),
            expires_at: None,
        };

        store.insert(invite.clone()).await.unwrap();

        let fetched = store.get("tok-1").await.unwrap();
        assert_eq!(fetched, Some(invite));

        let absent = store.get("missing").await.unwrap();
        assert!(absent.is_none());
    }

    /// Two minted tokens differ, use only the URL-safe base64 alphabet, and
    /// are 43 characters long (32 bytes, unpadded).
    #[test]
    fn mint_token_is_base64url_and_uniqueish() {
        let (first, second) = (mint_token(), mint_token());
        assert_ne!(first, second);
        assert!(first
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')));
        assert_eq!(first.len(), 43);
    }

    /// Every supported unit, plus the bare-seconds form, parses to the
    /// expected number of seconds.
    #[test]
    fn parse_duration_accepts_common_units() {
        let cases: [(&str, u64); 6] = [
            ("30s", 30),
            ("5m", 300),
            ("2h", 7_200),
            ("7d", 604_800),
            ("1w", 604_800),
            ("60", 60),
        ];
        for (input, expected_secs) in cases {
            assert_eq!(
                parse_duration(input).unwrap(),
                Duration::from_secs(expected_secs)
            );
        }
    }

    /// Empty input, an unknown unit, and non-numeric input are all rejected.
    #[test]
    fn parse_duration_rejects_bad_input() {
        for bad in ["", "1y", "abc"] {
            assert!(parse_duration(bad).is_err());
        }
    }
}