1use chrono::{DateTime, Duration, Utc};
8use r2d2::Pool;
9use r2d2_sqlite::SqliteConnectionManager;
10use rand::RngCore;
11use rusqlite::{Connection, params};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::Path;
16
17use crate::Result;
18use crate::config::CloudflareKvConfig;
19
/// A user session: a random identifier, free-form JSON data and lifecycle timestamps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Random identifier (see `generate_session_id`); also the key the stores persist it under.
    pub id: String,
    /// Application data stored in the session, keyed by name.
    pub data: HashMap<String, Value>,
    /// Creation time. The SQLite store persists timestamps at whole-second precision.
    pub created_at: DateTime<Utc>,
    /// Absolute expiry time; `is_expired` compares the current time against it.
    pub expires_at: DateTime<Utc>,
    /// Last time the session was touched or updated through a store.
    pub last_accessed: DateTime<Utc>,
}
34
/// Key in `Session::data` under which every new session stores its CSRF token.
pub const CSRF_TOKEN_KEY: &str = "_csrf_token";
37
38impl Session {
39 pub fn new(max_age_seconds: i64) -> Self {
41 let now = Utc::now();
42 let mut data = HashMap::new();
43 data.insert(
44 CSRF_TOKEN_KEY.to_string(),
45 Value::String(generate_csrf_token()),
46 );
47 Self {
48 id: generate_session_id(),
49 data,
50 created_at: now,
51 expires_at: now + Duration::seconds(max_age_seconds),
52 last_accessed: now,
53 }
54 }
55
56 pub fn is_expired(&self) -> bool {
58 Utc::now() > self.expires_at
59 }
60
61 pub fn to_context(&self) -> Value {
63 let mut map = serde_json::Map::new();
64 map.insert("id".to_string(), Value::String(self.id.clone()));
65 map.insert(
66 "created_at".to_string(),
67 Value::String(self.created_at.to_rfc3339()),
68 );
69 map.insert(
70 "expires_at".to_string(),
71 Value::String(self.expires_at.to_rfc3339()),
72 );
73
74 for (key, value) in &self.data {
76 map.insert(key.clone(), value.clone());
77 }
78
79 Value::Object(map)
80 }
81}
82
83pub fn generate_session_id() -> String {
86 let mut bytes = [0u8; 64];
87 rand::thread_rng().fill_bytes(&mut bytes);
88 hex::encode(&bytes)
89}
90
91pub fn generate_csrf_token() -> String {
94 let mut bytes = [0u8; 32];
95 rand::thread_rng().fill_bytes(&mut bytes);
96 hex::encode(&bytes)
97}
98
mod hex {
    /// Lowercase hexadecimal digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
    ///
    /// Writes into a single pre-sized `String` instead of allocating a
    /// temporary string per byte through `format!`.
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
105
/// Session storage backend chosen at runtime.
///
/// Every operation on `SessionBackend` dispatches to the wrapped store; a few
/// operations (cleanup, listing, counting) are not implemented for the KV variant.
pub enum SessionBackend {
    /// Local SQLite database accessed through an r2d2 connection pool.
    Sqlite(SqliteSessionStore),
    /// Cloudflare KV namespace accessed over the Cloudflare HTTP API.
    CloudflareKv(KvSessionStore),
}
117
118impl SessionBackend {
119 pub async fn create(&self) -> Result<Session> {
121 match self {
122 Self::Sqlite(s) => s.create().await,
123 Self::CloudflareKv(s) => s.create().await,
124 }
125 }
126
127 pub async fn get(&self, id: &str) -> Result<Option<Session>> {
129 match self {
130 Self::Sqlite(s) => s.get(id).await,
131 Self::CloudflareKv(s) => s.get(id).await,
132 }
133 }
134
135 pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
137 match self {
138 Self::Sqlite(s) => s.get_or_create(id).await,
139 Self::CloudflareKv(s) => s.get_or_create(id).await,
140 }
141 }
142
143 pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
145 match self {
146 Self::Sqlite(s) => s.update(id, data).await,
147 Self::CloudflareKv(s) => s.update(id, data).await,
148 }
149 }
150
151 pub async fn touch(&self, id: &str) -> Result<()> {
153 match self {
154 Self::Sqlite(s) => s.touch(id).await,
155 Self::CloudflareKv(s) => s.touch(id).await,
156 }
157 }
158
159 pub async fn delete(&self, id: &str) -> Result<()> {
161 match self {
162 Self::Sqlite(s) => s.delete(id).await,
163 Self::CloudflareKv(s) => s.delete(id).await,
164 }
165 }
166
167 pub async fn cleanup_expired(&self) -> Result<u64> {
169 match self {
170 Self::Sqlite(s) => s.cleanup_expired().await,
171 Self::CloudflareKv(_) => Ok(0), }
173 }
174
175 pub async fn list_session_ids(&self) -> Result<Vec<String>> {
177 match self {
178 Self::Sqlite(s) => s.list_session_ids().await,
179 Self::CloudflareKv(_) => Ok(vec![]), }
181 }
182
183 pub async fn count(&self) -> Result<usize> {
185 match self {
186 Self::Sqlite(s) => s.count().await,
187 Self::CloudflareKv(_) => Ok(0), }
189 }
190
191 pub async fn apply_mutation(
196 &self,
197 id: &str,
198 mutation: &AtomicMutation,
199 ) -> Result<HashMap<String, Value>> {
200 match self {
201 Self::Sqlite(s) => s.apply_atomic_mutation(id, mutation).await,
202 Self::CloudflareKv(s) => {
203 if let Some(mut session) = s.get(id).await? {
205 apply_mutation_in_memory(&mut session.data, mutation);
206 s.update(id, session.data.clone()).await?;
207 Ok(session.data)
208 } else {
209 Ok(HashMap::new())
210 }
211 }
212 }
213 }
214}
215
216impl Clone for SessionBackend {
217 fn clone(&self) -> Self {
218 match self {
219 Self::Sqlite(s) => Self::Sqlite(s.clone()),
220 Self::CloudflareKv(s) => Self::CloudflareKv(s.clone()),
221 }
222 }
223}
224
/// Session store backed by a SQLite `sessions` table.
#[derive(Clone)]
pub struct SqliteSessionStore {
    /// Connection pool; the async methods run their SQLite calls on `spawn_blocking`.
    pool: Pool<SqliteConnectionManager>,
    /// Lifetime in seconds given to newly created sessions.
    max_age: i64,
}
237
238#[derive(Debug)]
240struct SessionCustomizer;
241
242impl r2d2::CustomizeConnection<Connection, rusqlite::Error> for SessionCustomizer {
243 fn on_acquire(&self, conn: &mut Connection) -> std::result::Result<(), rusqlite::Error> {
244 conn.execute_batch("PRAGMA busy_timeout=5000; PRAGMA synchronous=NORMAL;")?;
245 Ok(())
246 }
247}
248
249impl SqliteSessionStore {
250 pub fn new(db_path: impl AsRef<Path>, max_age_seconds: i64) -> Result<Self> {
253 let manager = SqliteConnectionManager::file(db_path);
254 let pool = Pool::builder()
255 .max_size(4)
256 .connection_customizer(Box::new(SessionCustomizer))
257 .build(manager)
258 .map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
259
260 let conn = pool
262 .get()
263 .map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
264 conn.execute_batch("PRAGMA journal_mode=WAL;")?;
265 conn.execute(
266 "CREATE TABLE IF NOT EXISTS sessions (
267 id TEXT PRIMARY KEY,
268 data TEXT NOT NULL DEFAULT '{}',
269 created_at INTEGER NOT NULL,
270 expires_at INTEGER NOT NULL,
271 last_accessed INTEGER NOT NULL
272 )",
273 [],
274 )?;
275 conn.execute(
276 "CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at)",
277 [],
278 )?;
279
280 let now = Utc::now().timestamp();
282 let cleaned = conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
283 if cleaned > 0 {
284 tracing::info!("Cleaned up {} expired sessions", cleaned);
285 }
286
287 Ok(Self {
288 pool,
289 max_age: max_age_seconds,
290 })
291 }
292
293 pub fn in_memory(max_age_seconds: i64) -> Result<Self> {
296 let manager = SqliteConnectionManager::memory();
297 let pool = Pool::builder()
298 .max_size(1)
299 .build(manager)
300 .map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
301
302 let conn = pool
303 .get()
304 .map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
305 conn.execute(
306 "CREATE TABLE sessions (
307 id TEXT PRIMARY KEY,
308 data TEXT NOT NULL DEFAULT '{}',
309 created_at INTEGER NOT NULL,
310 expires_at INTEGER NOT NULL,
311 last_accessed INTEGER NOT NULL
312 )",
313 [],
314 )?;
315
316 Ok(Self {
317 pool,
318 max_age: max_age_seconds,
319 })
320 }
321
322 pub async fn create(&self) -> Result<Session> {
324 let pool = self.pool.clone();
325 let max_age = self.max_age;
326 tokio::task::spawn_blocking(move || {
327 let session = Session::new(max_age);
328 let conn = pool.get()
329 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
330 conn.execute(
331 "INSERT INTO sessions (id, data, created_at, expires_at, last_accessed) VALUES (?1, ?2, ?3, ?4, ?5)",
332 params![
333 session.id,
334 serde_json::to_string(&session.data)?,
335 session.created_at.timestamp(),
336 session.expires_at.timestamp(),
337 session.last_accessed.timestamp(),
338 ],
339 )?;
340 Ok(session)
341 }).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
342 }
343
344 pub async fn get(&self, id: &str) -> Result<Option<Session>> {
346 let pool = self.pool.clone();
347 let id = id.to_string();
348 tokio::task::spawn_blocking(move || {
349 let conn = pool
350 .get()
351 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
352
353 let mut stmt = conn.prepare(
354 "SELECT id, data, created_at, expires_at, last_accessed FROM sessions WHERE id = ?1"
355 )?;
356
357 let session = match stmt.query_row(params![id], |row| {
358 let id: String = row.get(0)?;
359 let data_str: String = row.get(1)?;
360 let created_at: i64 = row.get(2)?;
361 let expires_at: i64 = row.get(3)?;
362 let last_accessed: i64 = row.get(4)?;
363
364 Ok(Session {
365 id,
366 data: serde_json::from_str(&data_str).unwrap_or_default(),
367 created_at: DateTime::from_timestamp(created_at, 0).unwrap_or_else(Utc::now),
368 expires_at: DateTime::from_timestamp(expires_at, 0).unwrap_or_else(Utc::now),
369 last_accessed: DateTime::from_timestamp(last_accessed, 0)
370 .unwrap_or_else(Utc::now),
371 })
372 }) {
373 Ok(s) => Some(s),
374 Err(rusqlite::Error::QueryReturnedNoRows) => None,
375 Err(e) => return Err(e.into()),
376 };
377
378 match session {
380 Some(s) if s.is_expired() => {
381 conn.execute("DELETE FROM sessions WHERE id = ?1", params![s.id])?;
382 Ok(None)
383 }
384 s => Ok(s),
385 }
386 })
387 .await
388 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
389 }
390
391 pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
393 if let Some(session_id) = id {
394 if let Some(session) = self.get(session_id).await? {
395 self.touch(&session.id).await?;
397 return Ok(session);
398 }
399 }
400 self.create().await
401 }
402
403 pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
405 let pool = self.pool.clone();
406 let id = id.to_string();
407 tokio::task::spawn_blocking(move || {
408 let conn = pool
409 .get()
410 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
411 let now = Utc::now().timestamp();
412 conn.execute(
413 "UPDATE sessions SET data = ?1, last_accessed = ?2 WHERE id = ?3",
414 params![serde_json::to_string(&data)?, now, id],
415 )?;
416 Ok(())
417 })
418 .await
419 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
420 }
421
422 pub async fn touch(&self, id: &str) -> Result<()> {
424 let pool = self.pool.clone();
425 let id = id.to_string();
426 tokio::task::spawn_blocking(move || {
427 let conn = pool
428 .get()
429 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
430 let now = Utc::now().timestamp();
431 conn.execute(
432 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
433 params![now, id],
434 )?;
435 Ok(())
436 })
437 .await
438 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
439 }
440
441 pub async fn delete(&self, id: &str) -> Result<()> {
443 let pool = self.pool.clone();
444 let id = id.to_string();
445 tokio::task::spawn_blocking(move || {
446 let conn = pool
447 .get()
448 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
449 conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])?;
450 Ok(())
451 })
452 .await
453 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
454 }
455
456 pub async fn cleanup_expired(&self) -> Result<u64> {
458 let pool = self.pool.clone();
459 tokio::task::spawn_blocking(move || {
460 let conn = pool
461 .get()
462 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
463 let now = Utc::now().timestamp();
464 let deleted =
465 conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
466 Ok(deleted as u64)
467 })
468 .await
469 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
470 }
471
472 pub async fn list_session_ids(&self) -> Result<Vec<String>> {
474 let pool = self.pool.clone();
475 tokio::task::spawn_blocking(move || {
476 let conn = pool
477 .get()
478 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
479 let now = Utc::now().timestamp();
480 let mut stmt = conn.prepare(
481 "SELECT id FROM sessions WHERE expires_at > ?1 ORDER BY last_accessed DESC",
482 )?;
483 let ids: Vec<String> = stmt
484 .query_map(params![now], |row| row.get(0))?
485 .filter_map(|r| r.ok())
486 .collect();
487 Ok(ids)
488 })
489 .await
490 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
491 }
492
493 pub async fn count(&self) -> Result<usize> {
495 let pool = self.pool.clone();
496 tokio::task::spawn_blocking(move || {
497 let conn = pool
498 .get()
499 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
500 let now = Utc::now().timestamp();
501 let count: i64 = conn.query_row(
502 "SELECT COUNT(*) FROM sessions WHERE expires_at > ?1",
503 params![now],
504 |row| row.get(0),
505 )?;
506 Ok(count as usize)
507 })
508 .await
509 .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
510 }
511
512 pub async fn apply_atomic_mutation(
515 &self,
516 id: &str,
517 mutation: &AtomicMutation,
518 ) -> Result<HashMap<String, Value>> {
519 let pool = self.pool.clone();
520 let id = id.to_string();
521 let mutation = mutation.clone();
522 tokio::task::spawn_blocking(move || {
523 let conn = pool.get()
524 .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
525 let now = Utc::now().timestamp();
526
527 match &mutation {
528 AtomicMutation::Increment { key, value } => {
529 let path = format!("$.{}", key);
530 conn.execute(
531 "UPDATE sessions SET data = json_set(data, ?1, COALESCE(json_extract(data, ?1), 0) + ?2), last_accessed = ?3 WHERE id = ?4",
532 params![path, value, now, id],
533 )?;
534 }
535 AtomicMutation::Set { key, value } => {
536 let path = format!("$.{}", key);
537 let json_str = serde_json::to_string(value).unwrap_or_default();
538 conn.execute(
539 "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
540 params![path, json_str, now, id],
541 )?;
542 }
543 AtomicMutation::Push { key, value } => {
544 let path = format!("$.{}", key);
545 let json_val = serde_json::to_string(value).unwrap_or_default();
546 conn.execute(
547 "UPDATE sessions SET data = json_set(data, ?1, \
548 CASE WHEN json_extract(data, ?1) IS NULL THEN json_array(json(?2)) \
549 ELSE json_insert(json_extract(data, ?1), '$[#]', json(?2)) END \
550 ), last_accessed = ?3 WHERE id = ?4",
551 params![path, json_val, now, id],
552 )?;
553 }
554 AtomicMutation::PushMax { key, max, value } => {
555 let path = format!("$.{}", key);
556 let current: String = conn.query_row(
558 "SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
559 params![path, id],
560 |row| row.get(0),
561 ).unwrap_or_else(|_| "[]".to_string());
562 let mut arr: Vec<Value> = serde_json::from_str(¤t).unwrap_or_default();
563 arr.push(value.clone());
564 while arr.len() > *max {
565 arr.remove(0);
566 }
567 let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
568 conn.execute(
569 "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
570 params![path, new_arr, now, id],
571 )?;
572 }
573 AtomicMutation::Unshift { key, value } => {
574 let path = format!("$.{}", key);
575 let current: String = conn.query_row(
576 "SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
577 params![path, id],
578 |row| row.get(0),
579 ).unwrap_or_else(|_| "[]".to_string());
580 let mut arr: Vec<Value> = serde_json::from_str(¤t).unwrap_or_default();
581 arr.insert(0, value.clone());
582 let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
583 conn.execute(
584 "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
585 params![path, new_arr, now, id],
586 )?;
587 }
588 AtomicMutation::Clear { key } => {
589 let path = format!("$.{}", key);
590 conn.execute(
591 "UPDATE sessions SET data = json_set(data, ?1, json_array()), last_accessed = ?2 WHERE id = ?3",
592 params![path, now, id],
593 )?;
594 }
595 }
596
597 let data_str: String = conn.query_row(
599 "SELECT data FROM sessions WHERE id = ?1",
600 params![id],
601 |row| row.get(0),
602 )?;
603
604 let data: HashMap<String, Value> = serde_json::from_str(&data_str).unwrap_or_default();
605 Ok(data)
606 }).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
607 }
608}
609
/// A single mutation applied to one key of a session's data map.
#[derive(Debug, Clone)]
pub enum AtomicMutation {
    /// Adds `value` to the integer stored at `key`; a missing key counts as 0.
    Increment { key: String, value: i64 },
    /// Stores `value` at `key`, replacing any previous value.
    Set { key: String, value: Value },
    /// Appends `value` to the array at `key`, creating the array if the key is absent.
    Push { key: String, value: Value },
    /// Appends `value` to the array at `key`, then drops elements from the front
    /// until at most `max` remain.
    PushMax {
        key: String,
        max: usize,
        value: Value,
    },
    /// Inserts `value` at the front of the array at `key`, creating it if absent.
    Unshift { key: String, value: Value },
    /// Replaces the value at `key` with an empty array.
    Clear { key: String },
}
630
631pub fn apply_mutation_in_memory(data: &mut HashMap<String, Value>, mutation: &AtomicMutation) {
633 match mutation {
634 AtomicMutation::Increment { key, value } => {
635 let current = data.get(key).and_then(|v| v.as_i64()).unwrap_or(0);
636 data.insert(key.clone(), serde_json::json!(current + value));
637 }
638 AtomicMutation::Set { key, value } => {
639 data.insert(key.clone(), value.clone());
640 }
641 AtomicMutation::Push { key, value } => {
642 let arr = data
643 .entry(key.clone())
644 .or_insert_with(|| serde_json::json!([]));
645 if let Some(arr) = arr.as_array_mut() {
646 arr.push(value.clone());
647 }
648 }
649 AtomicMutation::PushMax { key, max, value } => {
650 let arr = data
651 .entry(key.clone())
652 .or_insert_with(|| serde_json::json!([]));
653 if let Some(arr) = arr.as_array_mut() {
654 arr.push(value.clone());
655 while arr.len() > *max {
656 arr.remove(0);
657 }
658 }
659 }
660 AtomicMutation::Unshift { key, value } => {
661 let arr = data
662 .entry(key.clone())
663 .or_insert_with(|| serde_json::json!([]));
664 if let Some(arr) = arr.as_array_mut() {
665 arr.insert(0, value.clone());
666 }
667 }
668 AtomicMutation::Clear { key } => {
669 data.insert(key.clone(), serde_json::json!([]));
670 }
671 }
672}
673
/// Session store backed by a Cloudflare KV namespace, reached through the
/// Cloudflare REST API. Each session is stored as JSON under `session:<id>`.
#[derive(Clone)]
pub struct KvSessionStore {
    /// Cloudflare account that owns the namespace.
    account_id: String,
    /// KV namespace holding the sessions.
    namespace_id: String,
    /// API token sent as a bearer token with every request.
    api_token: String,
    /// Lifetime in seconds given to newly created sessions.
    max_age: i64,
}
686
687impl KvSessionStore {
688 pub fn new(config: &CloudflareKvConfig, max_age_seconds: i64) -> Self {
690 Self {
691 account_id: config.account_id.clone(),
692 namespace_id: config.namespace_id.clone(),
693 api_token: config.api_token.clone(),
694 max_age: max_age_seconds,
695 }
696 }
697
698 fn base_url(&self) -> String {
700 format!(
701 "https://api.cloudflare.com/client/v4/accounts/{}/storage/kv/namespaces/{}",
702 self.account_id, self.namespace_id
703 )
704 }
705
706 fn key(&self, session_id: &str) -> String {
708 format!("session:{}", session_id)
709 }
710
711 fn client() -> &'static reqwest::Client {
713 use std::sync::OnceLock;
714 static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
715 CLIENT.get_or_init(|| {
716 crate::http_client::build_http_client(Some(std::time::Duration::from_secs(10)))
717 .expect("failed to build Cloudflare KV HTTP client")
718 })
719 }
720
721 pub async fn create(&self) -> Result<Session> {
723 let session = Session::new(self.max_age);
724 self.put_session(&session).await?;
725 Ok(session)
726 }
727
728 pub async fn get(&self, id: &str) -> Result<Option<Session>> {
730 let url = format!("{}/values/{}", self.base_url(), self.key(id));
731
732 let response = Self::client()
733 .get(&url)
734 .bearer_auth(&self.api_token)
735 .send()
736 .await
737 .map_err(|e| crate::Error::Session(format!("KV read failed: {}", e)))?;
738
739 if response.status() == reqwest::StatusCode::NOT_FOUND {
740 return Ok(None);
741 }
742
743 if !response.status().is_success() {
744 return Err(crate::Error::Session(format!(
745 "KV read error: HTTP {}",
746 response.status()
747 )));
748 }
749
750 let body = response
751 .text()
752 .await
753 .map_err(|e| crate::Error::Session(format!("KV read body failed: {}", e)))?;
754
755 match serde_json::from_str::<Session>(&body) {
756 Ok(session) if session.is_expired() => {
757 let _ = self.delete(&session.id).await;
759 Ok(None)
760 }
761 Ok(session) => Ok(Some(session)),
762 Err(e) => {
763 tracing::warn!("KV session deserialize failed: {}", e);
764 Ok(None)
765 }
766 }
767 }
768
769 pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
771 if let Some(session_id) = id {
772 if let Some(session) = self.get(session_id).await? {
773 self.touch(&session.id).await?;
774 return Ok(session);
775 }
776 }
777 self.create().await
778 }
779
780 pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
782 if let Some(mut session) = self.get(id).await? {
784 session.data = data;
785 session.last_accessed = Utc::now();
786 self.put_session(&session).await?;
787 }
788 Ok(())
789 }
790
791 pub async fn touch(&self, id: &str) -> Result<()> {
793 if let Some(mut session) = self.get(id).await? {
794 session.last_accessed = Utc::now();
795 self.put_session(&session).await?;
796 }
797 Ok(())
798 }
799
800 pub async fn delete(&self, id: &str) -> Result<()> {
802 let url = format!("{}/values/{}", self.base_url(), self.key(id));
803
804 Self::client()
805 .delete(&url)
806 .bearer_auth(&self.api_token)
807 .send()
808 .await
809 .map_err(|e| crate::Error::Session(format!("KV delete failed: {}", e)))?;
810
811 Ok(())
812 }
813
814 async fn put_session(&self, session: &Session) -> Result<()> {
816 let url = format!(
817 "{}/values/{}?expiration_ttl={}",
818 self.base_url(),
819 self.key(&session.id),
820 self.max_age
821 );
822
823 let body = serde_json::to_string(session)
824 .map_err(|e| crate::Error::Session(format!("KV serialize failed: {}", e)))?;
825
826 let response = Self::client()
827 .put(&url)
828 .bearer_auth(&self.api_token)
829 .header("Content-Type", "application/json")
830 .body(body)
831 .send()
832 .await
833 .map_err(|e| crate::Error::Session(format!("KV write failed: {}", e)))?;
834
835 if !response.status().is_success() {
836 let status = response.status();
837 let body = response.text().await.unwrap_or_default();
838 return Err(crate::Error::Session(format!(
839 "KV write error: HTTP {} — {}",
840 status, body
841 )));
842 }
843
844 Ok(())
845 }
846}
847
/// Extracts the value of cookie `cookie_name` from a raw `Cookie` header.
///
/// The header is split on `;` and each pair is trimmed. The first pair whose
/// name equals `cookie_name` exactly wins; a cookie that merely shares the
/// prefix (e.g. `w_sessionx`) does not match. Returns `None` when the header is
/// absent or contains no such cookie.
pub fn parse_session_cookie(cookie_header: Option<&str>, cookie_name: &str) -> Option<String> {
    cookie_header?
        .split(';')
        .map(str::trim)
        // `strip_prefix` avoids building a `"{name}="` string for every pair.
        .find_map(|pair| pair.strip_prefix(cookie_name)?.strip_prefix('='))
        .map(str::to_string)
}
862
/// Builds the `Set-Cookie` header value for the session cookie.
///
/// The cookie is always `HttpOnly`, `SameSite=Strict`, scoped to `Path=/` and
/// lives for `max_age` seconds; `Secure` is appended when `secure` is true.
pub fn build_session_cookie(
    session_id: &str,
    cookie_name: &str,
    max_age: i64,
    secure: bool,
) -> String {
    let base = format!(
        "{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
        cookie_name, session_id, max_age
    );

    if !secure {
        return base;
    }
    base + "; Secure"
}
881
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder Cloudflare KV settings; these tests never contact the API.
    fn kv_config() -> CloudflareKvConfig {
        CloudflareKvConfig {
            account_id: "acc123".to_string(),
            namespace_id: "ns456".to_string(),
            api_token: "token789".to_string(),
        }
    }

    /// Builds an in-memory SQLite store holding one freshly created session.
    async fn store_with_session() -> (SqliteSessionStore, Session) {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        (store, session)
    }

    /// Applies `mutation` to session `id` (panicking on error) and returns the new data.
    async fn mutate(
        store: &SqliteSessionStore,
        id: &str,
        mutation: AtomicMutation,
    ) -> HashMap<String, Value> {
        store.apply_atomic_mutation(id, &mutation).await.unwrap()
    }

    #[test]
    fn test_generate_session_id() {
        let id = generate_session_id();
        assert_eq!(id.len(), 128);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_generate_csrf_token() {
        let token = generate_csrf_token();
        assert_eq!(token.len(), 64);
        assert!(token.chars().all(|c| c.is_ascii_hexdigit()));
        // 256 random bits: two equal tokens would indicate a broken RNG.
        assert_ne!(token, generate_csrf_token());
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0x00u8, 0x0f, 0xab, 0xff]), "000fabff");
        assert_eq!(hex::encode(b""), "");
    }

    #[test]
    fn test_session_to_context() {
        let session = Session::new(60);
        let ctx = session.to_context();
        assert_eq!(ctx["id"].as_str(), Some(session.id.as_str()));
        let expires = session.expires_at.to_rfc3339();
        assert_eq!(ctx["expires_at"].as_str(), Some(expires.as_str()));
        assert!(ctx[CSRF_TOKEN_KEY].is_string());
    }

    #[tokio::test]
    async fn test_session_store() {
        let (store, session) = store_with_session().await;
        // 64 random bytes, hex-encoded.
        assert_eq!(session.id.len(), 128);

        let retrieved = store.get(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        store.delete(&session.id).await.unwrap();
        let deleted = store.get(&session.id).await.unwrap();
        assert!(deleted.is_none());
    }

    #[tokio::test]
    async fn test_expired_session_is_not_returned() {
        // A negative max age yields sessions that are already expired.
        let store = SqliteSessionStore::in_memory(-10).unwrap();
        let session = store.create().await.unwrap();
        assert!(store.get(&session.id).await.unwrap().is_none());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[test]
    fn test_parse_session_cookie() {
        let header = "w_session=abc123; other=value";
        assert_eq!(
            parse_session_cookie(Some(header), "w_session"),
            Some("abc123".to_string())
        );
        assert_eq!(parse_session_cookie(Some(header), "missing"), None);
        // Names must match exactly, not merely share a prefix.
        assert_eq!(
            parse_session_cookie(Some("w_sessionx=1; w_session=2"), "w_session"),
            Some("2".to_string())
        );
        assert_eq!(parse_session_cookie(None, "w_session"), None);
    }

    #[test]
    fn test_build_session_cookie() {
        assert_eq!(
            build_session_cookie("abc", "w_session", 3600, false),
            "w_session=abc; HttpOnly; SameSite=Strict; Path=/; Max-Age=3600"
        );
        assert_eq!(
            build_session_cookie("abc", "w_session", 60, true),
            "w_session=abc; HttpOnly; SameSite=Strict; Path=/; Max-Age=60; Secure"
        );
    }

    #[test]
    fn test_kv_key_format() {
        let store = KvSessionStore::new(&kv_config(), 3600);
        assert_eq!(store.key("abc123"), "session:abc123");
    }

    #[test]
    fn test_kv_base_url() {
        let store = KvSessionStore::new(&kv_config(), 3600);
        assert_eq!(
            store.base_url(),
            "https://api.cloudflare.com/client/v4/accounts/acc123/storage/kv/namespaces/ns456"
        );
    }

    #[test]
    fn test_session_serialization_roundtrip() {
        let session = Session::new(3600);
        let json = serde_json::to_string(&session).unwrap();
        let deserialized: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, session.id);
        // A new session carries exactly one data entry: its CSRF token.
        assert_eq!(deserialized.data.len(), 1);
        assert!(deserialized.data.contains_key(CSRF_TOKEN_KEY));
    }

    #[tokio::test]
    async fn test_atomic_increment() {
        let (store, session) = store_with_session().await;
        for (step, expected) in [(1, 1), (5, 6), (-2, 4)] {
            let data = mutate(
                &store,
                &session.id,
                AtomicMutation::Increment {
                    key: "counter".to_string(),
                    value: step,
                },
            )
            .await;
            assert_eq!(data.get("counter").and_then(Value::as_i64), Some(expected));
        }
    }

    #[tokio::test]
    async fn test_atomic_set() {
        let (store, session) = store_with_session().await;
        for name in ["Alice", "Bob"] {
            let data = mutate(
                &store,
                &session.id,
                AtomicMutation::Set {
                    key: "name".to_string(),
                    value: serde_json::json!(name),
                },
            )
            .await;
            assert_eq!(data.get("name").and_then(|v| v.as_str()), Some(name));
        }
    }

    #[tokio::test]
    async fn test_atomic_push() {
        let (store, session) = store_with_session().await;
        for (idx, item) in ["first", "second"].into_iter().enumerate() {
            let data = mutate(
                &store,
                &session.id,
                AtomicMutation::Push {
                    key: "items".to_string(),
                    value: serde_json::json!(item),
                },
            )
            .await;
            let items = data.get("items").and_then(|v| v.as_array()).unwrap();
            assert_eq!(items.len(), idx + 1);
            assert_eq!(items[idx].as_str(), Some(item));
        }
    }

    #[tokio::test]
    async fn test_atomic_push_max() {
        let (store, session) = store_with_session().await;
        for i in 1..=3 {
            mutate(
                &store,
                &session.id,
                AtomicMutation::PushMax {
                    key: "log".to_string(),
                    max: 2,
                    value: serde_json::json!(i),
                },
            )
            .await;
        }

        let stored = store.get(&session.id).await.unwrap().unwrap();
        let log = stored.data.get("log").and_then(|v| v.as_array()).unwrap();
        // The oldest entry (1) was evicted.
        assert_eq!(log.len(), 2);
        assert_eq!(log[0].as_i64(), Some(2));
        assert_eq!(log[1].as_i64(), Some(3));
    }

    #[tokio::test]
    async fn test_atomic_unshift() {
        let (store, session) = store_with_session().await;
        let mut data = HashMap::new();
        for item in ["first", "second"] {
            data = mutate(
                &store,
                &session.id,
                AtomicMutation::Unshift {
                    key: "stack".to_string(),
                    value: serde_json::json!(item),
                },
            )
            .await;
        }

        let stack = data.get("stack").and_then(|v| v.as_array()).unwrap();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0].as_str(), Some("second"));
        assert_eq!(stack[1].as_str(), Some("first"));
    }

    #[tokio::test]
    async fn test_atomic_clear() {
        let (store, session) = store_with_session().await;
        for item in ["a", "b"] {
            mutate(
                &store,
                &session.id,
                AtomicMutation::Push {
                    key: "items".to_string(),
                    value: serde_json::json!(item),
                },
            )
            .await;
        }

        let data = mutate(
            &store,
            &session.id,
            AtomicMutation::Clear {
                key: "items".to_string(),
            },
        )
        .await;
        let items = data.get("items").and_then(|v| v.as_array()).unwrap();
        assert_eq!(items.len(), 0);
    }

    #[test]
    fn test_apply_mutation_in_memory() {
        let mut data = HashMap::new();

        apply_mutation_in_memory(
            &mut data,
            &AtomicMutation::Increment {
                key: "x".to_string(),
                value: 3,
            },
        );
        assert_eq!(data.get("x").and_then(|v| v.as_i64()), Some(3));

        apply_mutation_in_memory(
            &mut data,
            &AtomicMutation::Set {
                key: "name".to_string(),
                value: serde_json::json!("test"),
            },
        );
        assert_eq!(data.get("name").and_then(|v| v.as_str()), Some("test"));

        for n in [1, 2] {
            apply_mutation_in_memory(
                &mut data,
                &AtomicMutation::Push {
                    key: "list".to_string(),
                    value: serde_json::json!(n),
                },
            );
        }
        let list = data.get("list").and_then(|v| v.as_array()).unwrap();
        assert_eq!(list.len(), 2);

        for n in 1..=3 {
            apply_mutation_in_memory(
                &mut data,
                &AtomicMutation::PushMax {
                    key: "recent".to_string(),
                    max: 2,
                    value: serde_json::json!(n),
                },
            );
        }
        assert_eq!(data["recent"], serde_json::json!([2, 3]));

        for item in ["a", "b"] {
            apply_mutation_in_memory(
                &mut data,
                &AtomicMutation::Unshift {
                    key: "stack".to_string(),
                    value: serde_json::json!(item),
                },
            );
        }
        assert_eq!(data["stack"], serde_json::json!(["b", "a"]));

        apply_mutation_in_memory(
            &mut data,
            &AtomicMutation::Clear {
                key: "stack".to_string(),
            },
        );
        assert_eq!(data["stack"], serde_json::json!([]));
    }
}