solid_pod_rs_idp/
session.rs1use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use parking_lot::RwLock;
14use rand::RngCore;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
19pub enum SessionError {
20 #[error("unknown session")]
22 Unknown,
23 #[error("session expired")]
25 Expired,
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub struct SessionId(String);
31
32impl SessionId {
33 pub fn generate() -> Self {
35 let mut buf = [0u8; 32];
36 rand::rngs::OsRng.fill_bytes(&mut buf);
37 Self(hex::encode(buf))
38 }
39
40 pub fn as_str(&self) -> &str {
43 &self.0
44 }
45
46 pub fn from_raw(s: impl Into<String>) -> Self {
49 Self(s.into())
50 }
51}
52
/// A single-use authorization code together with the parameters it was issued for.
///
/// Records are created by [`SessionStore::issue_code`] and consumed by
/// [`SessionStore::take_code`].
#[derive(Debug, Clone)]
pub struct AuthCodeRecord {
    /// The code value; codes minted by `issue_code` are 32 random bytes, hex-encoded.
    pub code: String,
    /// Client id the code was issued to.
    pub client_id: String,
    /// Account id the code was issued for.
    pub account_id: String,
    /// Redirect URI recorded at issue time.
    pub redirect_uri: String,
    /// Optional code challenge recorded at issue time (presumably PKCE — TODO confirm);
    /// the store itself never inspects it.
    pub code_challenge: Option<String>,
    /// When the code was minted; `take_code` rejects codes older than the code TTL.
    pub issued_at: Instant,
    /// Scope requested at issue time, if any.
    pub requested_scope: Option<String>,
}
71
/// Server-side state of one login session.
#[derive(Debug, Clone)]
pub struct SessionRecord {
    /// Account the session was created for.
    pub account_id: String,
    /// When the session was created. `SessionStore::lookup` does not check this field.
    pub created_at: Instant,
    /// When the session was last successfully looked up; the store's idle timeout is
    /// measured from this instant.
    pub last_access: Instant,
}

impl SessionRecord {
    /// Builds a record for `account_id` whose creation and last-access times are both "now".
    fn new(account_id: String) -> Self {
        // Read the clock once so a fresh record has `created_at == last_access` exactly,
        // instead of two slightly different instants from two separate reads.
        let now = Instant::now();
        Self {
            account_id,
            created_at: now,
            last_access: now,
        }
    }
}
92
93#[derive(Clone, Default)]
95pub struct SessionStore {
96 inner: Arc<RwLock<Inner>>,
97 session_ttl: Duration,
99 code_ttl: Duration,
102}
103
104#[derive(Default)]
105struct Inner {
106 sessions: HashMap<String, SessionRecord>,
107 codes: HashMap<String, AuthCodeRecord>,
108}
109
110impl SessionStore {
111 pub fn new() -> Self {
113 Self {
114 inner: Arc::new(RwLock::new(Inner::default())),
115 session_ttl: Duration::from_secs(14 * 24 * 3600),
116 code_ttl: Duration::from_secs(10 * 60),
117 }
118 }
119
120 pub fn with_ttls(mut self, session_ttl: Duration, code_ttl: Duration) -> Self {
122 self.session_ttl = session_ttl;
123 self.code_ttl = code_ttl;
124 self
125 }
126
127 pub fn create_session(&self, account_id: impl Into<String>) -> SessionId {
129 let id = SessionId::generate();
130 self.inner
131 .write()
132 .sessions
133 .insert(id.as_str().to_string(), SessionRecord::new(account_id.into()));
134 id
135 }
136
137 pub fn lookup(&self, id: &SessionId) -> Result<SessionRecord, SessionError> {
139 let mut inner = self.inner.write();
140 let entry = inner
141 .sessions
142 .get_mut(id.as_str())
143 .ok_or(SessionError::Unknown)?;
144 if entry.last_access.elapsed() > self.session_ttl {
145 inner.sessions.remove(id.as_str());
146 return Err(SessionError::Expired);
147 }
148 entry.last_access = Instant::now();
149 Ok(entry.clone())
150 }
151
152 pub fn revoke(&self, id: &SessionId) {
154 self.inner.write().sessions.remove(id.as_str());
155 }
156
157 pub fn issue_code(
159 &self,
160 client_id: impl Into<String>,
161 account_id: impl Into<String>,
162 redirect_uri: impl Into<String>,
163 code_challenge: Option<String>,
164 requested_scope: Option<String>,
165 ) -> AuthCodeRecord {
166 let mut buf = [0u8; 32];
167 rand::rngs::OsRng.fill_bytes(&mut buf);
168 let code = hex::encode(buf);
169 let rec = AuthCodeRecord {
170 code: code.clone(),
171 client_id: client_id.into(),
172 account_id: account_id.into(),
173 redirect_uri: redirect_uri.into(),
174 code_challenge,
175 issued_at: Instant::now(),
176 requested_scope,
177 };
178 self.inner.write().codes.insert(code, rec.clone());
179 rec
180 }
181
182 pub fn take_code(&self, code: &str) -> Option<AuthCodeRecord> {
186 let mut inner = self.inner.write();
187 let rec = inner.codes.remove(code)?;
188 if rec.issued_at.elapsed() > self.code_ttl {
189 return None;
190 }
191 Some(rec)
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_ids_are_unique() {
        let first = SessionId::generate();
        let second = SessionId::generate();
        assert_ne!(first.as_str(), second.as_str());
        // 32 random bytes, hex-encoded.
        assert_eq!(first.as_str().len(), 64);
    }

    #[test]
    fn session_create_lookup_revoke_roundtrip() {
        let store = SessionStore::new();
        let sid = store.create_session("acct-1");

        let found = store.lookup(&sid).unwrap();
        assert_eq!(found.account_id, "acct-1");

        store.revoke(&sid);
        let after_revoke = store.lookup(&sid);
        assert!(matches!(after_revoke, Err(SessionError::Unknown)));
    }

    #[test]
    fn session_expiry_is_enforced() {
        let session_ttl = Duration::from_millis(1);
        let code_ttl = Duration::from_secs(60);
        let store = SessionStore::new().with_ttls(session_ttl, code_ttl);
        let sid = store.create_session("acct-2");

        std::thread::sleep(Duration::from_millis(10));

        let outcome = store.lookup(&sid);
        assert!(matches!(outcome, Err(SessionError::Expired)));
    }

    #[test]
    fn auth_code_is_single_use() {
        let store = SessionStore::new();
        let issued = store.issue_code("c-1", "acct-3", "https://app/cb", None, None);

        let redeemed = store.take_code(&issued.code).unwrap();
        assert_eq!(redeemed.account_id, "acct-3");

        let second_attempt = store.take_code(&issued.code);
        assert!(second_attempt.is_none());
    }

    #[test]
    fn auth_code_expires() {
        let session_ttl = Duration::from_secs(60);
        let code_ttl = Duration::from_millis(1);
        let store = SessionStore::new().with_ttls(session_ttl, code_ttl);
        let issued = store.issue_code("c-1", "acct-4", "https://app/cb", None, None);

        std::thread::sleep(Duration::from_millis(10));

        assert!(store.take_code(&issued.code).is_none());
    }
}