tako_rs_plugins/middleware/
session.rs1use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::sync::atomic::AtomicBool;
26use std::sync::atomic::AtomicU64;
27use std::sync::atomic::Ordering;
28use std::time::Duration;
29use std::time::Instant;
30
31use http::HeaderValue;
32use parking_lot::Mutex;
33use scc::HashMap as SccHashMap;
34use serde::Serialize;
35use serde::de::DeserializeOwned;
36use tako_rs_core::middleware::IntoMiddleware;
37use tako_rs_core::middleware::Next;
38use tako_rs_core::types::Request;
39use tako_rs_core::types::Response;
40
/// Expiration policy for server-side sessions.
///
/// A session is treated as expired once it has gone unused for longer than
/// `idle_secs`, or — when `absolute_secs` is set — once more than
/// `absolute_secs` have elapsed since it was created, whichever happens first.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SessionTtl {
  /// Maximum number of seconds a session may sit unused before it expires.
  pub idle_secs: u64,
  /// Optional hard lifetime cap in seconds since creation, regardless of
  /// activity. `None` disables the absolute limit.
  pub absolute_secs: Option<u64>,
}

impl Default for SessionTtl {
  /// One hour of idle time, with a 24-hour absolute lifetime.
  fn default() -> Self {
    Self {
      idle_secs: 3_600,
      absolute_secs: Some(86_400),
    }
  }
}
59
/// Value emitted for the `SameSite` attribute of the session cookie.
#[derive(Clone, Copy, Debug)]
pub enum SameSite {
  /// Rendered as `SameSite=Strict`.
  Strict,
  /// Rendered as `SameSite=Lax`.
  Lax,
  /// Rendered as `SameSite=None`.
  None,
}

impl SameSite {
  /// Returns the attribute value exactly as written into a `Set-Cookie` header.
  fn as_str(self) -> &'static str {
    match self {
      Self::Strict => "Strict",
      Self::Lax => "Lax",
      Self::None => "None",
    }
  }
}
77
78#[derive(Clone)]
79struct SessionEntry {
80 data: serde_json::Map<String, serde_json::Value>,
81 created_at: Instant,
82 last_seen_at: Instant,
83}
84
85#[derive(Clone)]
87struct Store(Arc<SccHashMap<String, SessionEntry>>);
88
89impl Store {
90 fn new() -> Self {
91 Self(Arc::new(SccHashMap::new()))
92 }
93
94 fn get(&self, id: &str) -> Option<SessionEntry> {
95 self.0.get_sync(id).map(|e| e.clone())
96 }
97
98 fn upsert(&self, id: String, entry: SessionEntry) {
99 let _ = self.0.upsert_sync(id, entry);
100 }
101
102 fn remove(&self, id: &str) {
103 let _ = self.0.remove_sync(id);
104 }
105
106 fn revoke_all(&self) {
107 self.0.clear_sync();
108 }
109
110 fn revoke_predicate(&self, mut keep: impl FnMut(&str, &SessionEntry) -> bool) {
111 self.0.retain_sync(|k, v| keep(k, v));
112 }
113
114 fn retain_expired(&self, ttl: SessionTtl) {
115 let now = Instant::now();
116 let idle = Duration::from_secs(ttl.idle_secs);
117 let absolute = ttl.absolute_secs.map(Duration::from_secs);
118 self.0.retain_sync(|_, v| {
119 if now.duration_since(v.last_seen_at) > idle {
120 return false;
121 }
122 if let Some(abs) = absolute
123 && now.duration_since(v.created_at) > abs
124 {
125 return false;
126 }
127 true
128 });
129 }
130}
131
132#[derive(Clone)]
134pub struct SessionStoreHandle {
135 store: Store,
136}
137
138impl SessionStoreHandle {
139 pub fn revoke_all(&self) {
141 self.store.revoke_all();
142 }
143
144 pub fn revoke_where<F>(&self, mut pred: F)
146 where
147 F: FnMut(&str, &serde_json::Map<String, serde_json::Value>) -> bool,
148 {
149 self.store.revoke_predicate(|k, v| !pred(k, &v.data));
150 }
151}
152
153#[derive(Clone)]
155pub struct Session {
156 data: Arc<Mutex<serde_json::Map<String, serde_json::Value>>>,
157 dirty: Arc<AtomicBool>,
158 rotation_counter: Arc<AtomicU64>,
159 destroyed: Arc<AtomicBool>,
160}
161
162impl Session {
163 fn new(data: serde_json::Map<String, serde_json::Value>) -> Self {
164 Self {
165 data: Arc::new(Mutex::new(data)),
166 dirty: Arc::new(AtomicBool::new(false)),
167 rotation_counter: Arc::new(AtomicU64::new(0)),
168 destroyed: Arc::new(AtomicBool::new(false)),
169 }
170 }
171
172 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
174 self
175 .data
176 .lock()
177 .get(key)
178 .and_then(|v| serde_json::from_value(v.clone()).ok())
179 }
180
181 pub fn set<T: Serialize>(&self, key: &str, value: T) {
183 if let Ok(v) = serde_json::to_value(value) {
184 self.data.lock().insert(key.to_string(), v);
185 self.dirty.store(true, Ordering::Relaxed);
186 }
187 }
188
189 pub fn remove(&self, key: &str) {
191 if self.data.lock().remove(key).is_some() {
192 self.dirty.store(true, Ordering::Relaxed);
193 }
194 }
195
196 pub fn clear(&self) {
201 let mut guard = self.data.lock();
202 if !guard.is_empty() {
203 guard.clear();
204 self.dirty.store(true, Ordering::Relaxed);
205 }
206 }
207
208 pub fn destroy(&self) {
213 self.data.lock().clear();
214 self.destroyed.store(true, Ordering::Release);
215 self.dirty.store(true, Ordering::Relaxed);
216 }
217
218 fn is_destroyed(&self) -> bool {
219 self.destroyed.load(Ordering::Acquire)
220 }
221
222 pub fn rotate(&self) {
226 self.rotation_counter.fetch_add(1, Ordering::AcqRel);
227 self.dirty.store(true, Ordering::Relaxed);
228 }
229
230 fn is_dirty(&self) -> bool {
231 self.dirty.load(Ordering::Relaxed)
232 }
233
234 pub fn rotation_requested(&self) -> bool {
239 self.rotation_counter.load(Ordering::Acquire) > 0
240 }
241
242 fn snapshot(&self) -> serde_json::Map<String, serde_json::Value> {
243 self.data.lock().clone()
244 }
245}
246
247pub struct SessionMiddleware {
249 cookie_name: String,
250 ttl: SessionTtl,
251 path: String,
252 domain: Option<String>,
253 secure: bool,
254 http_only: bool,
255 same_site: SameSite,
256 store: Store,
257}
258
259impl Default for SessionMiddleware {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265impl SessionMiddleware {
266 pub fn new() -> Self {
268 Self {
269 cookie_name: "tako_session".to_string(),
270 ttl: SessionTtl::default(),
271 path: "/".to_string(),
272 domain: None,
273 secure: false,
274 http_only: true,
275 same_site: SameSite::Lax,
276 store: Store::new(),
277 }
278 }
279
280 pub fn cookie_name(mut self, name: &str) -> Self {
282 self.cookie_name = name.to_string();
283 self
284 }
285
286 pub fn ttl_secs(mut self, secs: u64) -> Self {
289 self.ttl.idle_secs = secs;
290 self
291 }
292
293 pub fn ttl(mut self, ttl: SessionTtl) -> Self {
295 self.ttl = ttl;
296 self
297 }
298
299 pub fn path(mut self, path: &str) -> Self {
301 self.path = path.to_string();
302 self
303 }
304
305 pub fn domain(mut self, domain: &str) -> Self {
307 self.domain = Some(domain.to_string());
308 self
309 }
310
311 pub fn secure(mut self, secure: bool) -> Self {
313 self.secure = secure;
314 self
315 }
316
317 pub fn http_only(mut self, on: bool) -> Self {
319 self.http_only = on;
320 self
321 }
322
323 pub fn same_site(mut self, ss: SameSite) -> Self {
326 self.same_site = ss;
327 self
328 }
329
330 pub fn handle(&self) -> SessionStoreHandle {
332 SessionStoreHandle {
333 store: self.store.clone(),
334 }
335 }
336}
337
338fn generate_session_id() -> String {
339 uuid::Uuid::new_v4().simple().to_string()
344}
345
346fn extract_cookie_value<'a>(req: &'a Request, cookie_name: &str) -> Option<&'a str> {
347 req
348 .headers()
349 .get(http::header::COOKIE)
350 .and_then(|v| v.to_str().ok())
351 .and_then(|cookies| {
352 cookies.split(';').find_map(|pair| {
353 let pair = pair.trim();
354 let (name, value) = pair.split_once('=')?;
355 if name.trim() == cookie_name {
356 Some(value.trim())
357 } else {
358 None
359 }
360 })
361 })
362}
363
364#[allow(clippy::too_many_arguments)]
365fn build_cookie(
366 cookie_name: &str,
367 sid: &str,
368 path: &str,
369 domain: Option<&str>,
370 ttl_secs: u64,
371 secure: bool,
372 http_only: bool,
373 same_site: SameSite,
374) -> String {
375 let mut s = format!("{cookie_name}={sid}; Path={path}");
376 if let Some(d) = domain {
377 s.push_str("; Domain=");
378 s.push_str(d);
379 }
380 s.push_str(&format!("; Max-Age={ttl_secs}"));
381 if http_only {
382 s.push_str("; HttpOnly");
383 }
384 if secure {
385 s.push_str("; Secure");
386 }
387 s.push_str("; SameSite=");
388 s.push_str(same_site.as_str());
389 s
390}
391
392fn build_expired_cookie(
393 cookie_name: &str,
394 path: &str,
395 domain: Option<&str>,
396 secure: bool,
397 http_only: bool,
398 same_site: SameSite,
399) -> String {
400 let mut s = format!("{cookie_name}=; Path={path}");
403 if let Some(d) = domain {
404 s.push_str("; Domain=");
405 s.push_str(d);
406 }
407 s.push_str("; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT");
408 if http_only {
409 s.push_str("; HttpOnly");
410 }
411 if secure {
412 s.push_str("; Secure");
413 }
414 s.push_str("; SameSite=");
415 s.push_str(same_site.as_str());
416 s
417}
418
419impl IntoMiddleware for SessionMiddleware {
420 fn into_middleware(
421 self,
422 ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
423 + Clone
424 + Send
425 + Sync
426 + 'static {
427 let store = self.store.clone();
428 let cookie_name = Arc::new(self.cookie_name);
429 let ttl = self.ttl;
430 let path = Arc::new(self.path);
431 let domain = self.domain.map(Arc::new);
432 let secure = self.secure;
433 let http_only = self.http_only;
434 let same_site = self.same_site;
435
436 {
439 let store = store.clone();
440 let interval = Duration::from_secs(ttl.idle_secs.clamp(60, 3_600));
441 #[cfg(not(feature = "compio"))]
442 tokio::spawn(async move {
443 let mut tick = tokio::time::interval(interval);
444 loop {
445 tick.tick().await;
446 store.retain_expired(ttl);
447 }
448 });
449 #[cfg(feature = "compio")]
450 compio::runtime::spawn(async move {
451 loop {
452 compio::time::sleep(interval).await;
453 store.retain_expired(ttl);
454 }
455 })
456 .detach();
457 }
458
459 move |mut req: Request, next: Next| {
460 let store = store.clone();
461 let cookie_name = cookie_name.clone();
462 let path = path.clone();
463 let domain = domain.clone();
464
465 Box::pin(async move {
466 let now = Instant::now();
467 let idle = Duration::from_secs(ttl.idle_secs);
468 let absolute = ttl.absolute_secs.map(Duration::from_secs);
469
470 let inbound_id = extract_cookie_value(&req, &cookie_name).map(str::to_string);
471 let (sid, data, created_at, was_existing) = match inbound_id {
472 Some(ref id) => match store.get(id) {
473 Some(entry)
474 if now.duration_since(entry.last_seen_at) <= idle
475 && absolute.is_none_or(|abs| now.duration_since(entry.created_at) <= abs) =>
476 {
477 (id.clone(), entry.data, entry.created_at, true)
478 }
479 _ => {
480 if let Some(id) = inbound_id.as_ref() {
481 store.remove(id);
482 }
483 (generate_session_id(), serde_json::Map::new(), now, false)
484 }
485 },
486 None => (generate_session_id(), serde_json::Map::new(), now, false),
487 };
488
489 let session = Session::new(data);
490 req.extensions_mut().insert(session.clone());
491
492 let resp_outcome = next.run(req).await;
493 let mut resp = resp_outcome;
494
495 let dirty = session.is_dirty();
496 let rotated = session.rotation_requested();
497 let destroyed = session.is_destroyed();
498
499 if destroyed {
503 if was_existing {
504 store.remove(&sid);
505 }
506 let expired = build_expired_cookie(
507 &cookie_name,
508 &path,
509 domain.as_deref().map(String::as_str),
510 secure,
511 http_only,
512 same_site,
513 );
514 if let Ok(v) = HeaderValue::from_str(&expired) {
515 resp.headers_mut().append(http::header::SET_COOKIE, v);
516 }
517 let _ = dirty;
518 return resp;
519 }
520
521 let effective_sid = if rotated {
523 if was_existing {
524 store.remove(&sid);
525 }
526 generate_session_id()
527 } else {
528 sid
529 };
530
531 let updated_entry = SessionEntry {
535 data: session.snapshot(),
536 created_at,
537 last_seen_at: now,
538 };
539 store.upsert(effective_sid.clone(), updated_entry);
540
541 let max_age = match absolute {
545 Some(abs) => {
546 let elapsed = now.duration_since(created_at);
547 let absolute_remaining = abs.saturating_sub(elapsed);
548 absolute_remaining.as_secs().min(idle.as_secs())
549 }
550 None => idle.as_secs(),
551 };
552
553 let cookie_value = build_cookie(
554 &cookie_name,
555 &effective_sid,
556 &path,
557 domain.as_deref().map(String::as_str),
558 max_age,
559 secure,
560 http_only,
561 same_site,
562 );
563 if let Ok(v) = HeaderValue::from_str(&cookie_value) {
564 resp.headers_mut().append(http::header::SET_COOKIE, v);
565 }
566
567 let _ = dirty;
568
569 resp
570 })
571 }
572 }
573}