Skip to main content

what_core/sessions/
mod.rs

1//! Session management for What framework
2//!
3//! Provides session storage with pluggable backends:
4//! - SQLite (default, for local development and single-server deployments)
5//! - Cloudflare Workers KV (for globally distributed, durable sessions)
6
7use chrono::{DateTime, Duration, Utc};
8use r2d2::Pool;
9use r2d2_sqlite::SqliteConnectionManager;
10use rand::RngCore;
11use rusqlite::{Connection, params};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::path::Path;
16
17use crate::Result;
18use crate::config::CloudflareKvConfig;
19
20/// Session data structure
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Session {
23    /// Unique session ID (128 hex chars)
24    pub id: String,
25    /// Session data as JSON
26    pub data: HashMap<String, Value>,
27    /// When the session was created
28    pub created_at: DateTime<Utc>,
29    /// When the session expires
30    pub expires_at: DateTime<Utc>,
31    /// Last access time
32    pub last_accessed: DateTime<Utc>,
33}
34
/// Key in a session's data map under which its CSRF token is stored.
pub const CSRF_TOKEN_KEY: &str = "_csrf_token";
37
38impl Session {
39    /// Create a new session with generated ID and a CSRF token
40    pub fn new(max_age_seconds: i64) -> Self {
41        let now = Utc::now();
42        let mut data = HashMap::new();
43        data.insert(
44            CSRF_TOKEN_KEY.to_string(),
45            Value::String(generate_csrf_token()),
46        );
47        Self {
48            id: generate_session_id(),
49            data,
50            created_at: now,
51            expires_at: now + Duration::seconds(max_age_seconds),
52            last_accessed: now,
53        }
54    }
55
56    /// Check if session is expired
57    pub fn is_expired(&self) -> bool {
58        Utc::now() > self.expires_at
59    }
60
61    /// Convert session to JSON Value for template context
62    pub fn to_context(&self) -> Value {
63        let mut map = serde_json::Map::new();
64        map.insert("id".to_string(), Value::String(self.id.clone()));
65        map.insert(
66            "created_at".to_string(),
67            Value::String(self.created_at.to_rfc3339()),
68        );
69        map.insert(
70            "expires_at".to_string(),
71            Value::String(self.expires_at.to_rfc3339()),
72        );
73
74        // Merge session data into context
75        for (key, value) in &self.data {
76            map.insert(key.clone(), value.clone());
77        }
78
79        Value::Object(map)
80    }
81}
82
83/// Generate a cryptographically secure session ID
84/// Returns 128 hex characters (64 bytes of random data)
85pub fn generate_session_id() -> String {
86    let mut bytes = [0u8; 64];
87    rand::thread_rng().fill_bytes(&mut bytes);
88    hex::encode(&bytes)
89}
90
91/// Generate a cryptographically secure CSRF token
92/// Returns 64 hex characters (32 bytes of random data)
93pub fn generate_csrf_token() -> String {
94    let mut bytes = [0u8; 32];
95    rand::thread_rng().fill_bytes(&mut bytes);
96    hex::encode(&bytes)
97}
98
// Minimal hex encoder (kept local to avoid adding another dependency)
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encode `bytes` as a lowercase hexadecimal string, two characters per byte.
    ///
    /// The output buffer is allocated once up front rather than formatting a
    /// temporary `String` for every input byte.
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
105
106// ---------------------------------------------------------------------------
107// SessionBackend — enum dispatch for pluggable storage
108// ---------------------------------------------------------------------------
109
110/// Pluggable session storage backend
111pub enum SessionBackend {
112    /// Local SQLite storage (default)
113    Sqlite(SqliteSessionStore),
114    /// Cloudflare Workers KV via REST API
115    CloudflareKv(KvSessionStore),
116}
117
118impl SessionBackend {
119    /// Create a new session
120    pub async fn create(&self) -> Result<Session> {
121        match self {
122            Self::Sqlite(s) => s.create().await,
123            Self::CloudflareKv(s) => s.create().await,
124        }
125    }
126
127    /// Get a session by ID
128    pub async fn get(&self, id: &str) -> Result<Option<Session>> {
129        match self {
130            Self::Sqlite(s) => s.get(id).await,
131            Self::CloudflareKv(s) => s.get(id).await,
132        }
133    }
134
135    /// Get or create a session
136    pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
137        match self {
138            Self::Sqlite(s) => s.get_or_create(id).await,
139            Self::CloudflareKv(s) => s.get_or_create(id).await,
140        }
141    }
142
143    /// Update session data
144    pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
145        match self {
146            Self::Sqlite(s) => s.update(id, data).await,
147            Self::CloudflareKv(s) => s.update(id, data).await,
148        }
149    }
150
151    /// Update last accessed time
152    pub async fn touch(&self, id: &str) -> Result<()> {
153        match self {
154            Self::Sqlite(s) => s.touch(id).await,
155            Self::CloudflareKv(s) => s.touch(id).await,
156        }
157    }
158
159    /// Delete a session
160    pub async fn delete(&self, id: &str) -> Result<()> {
161        match self {
162            Self::Sqlite(s) => s.delete(id).await,
163            Self::CloudflareKv(s) => s.delete(id).await,
164        }
165    }
166
167    /// Clean up expired sessions (no-op for KV which uses TTL)
168    pub async fn cleanup_expired(&self) -> Result<u64> {
169        match self {
170            Self::Sqlite(s) => s.cleanup_expired().await,
171            Self::CloudflareKv(_) => Ok(0), // KV handles expiry via TTL
172        }
173    }
174
175    /// List active session IDs (SQLite only, returns empty for KV)
176    pub async fn list_session_ids(&self) -> Result<Vec<String>> {
177        match self {
178            Self::Sqlite(s) => s.list_session_ids().await,
179            Self::CloudflareKv(_) => Ok(vec![]), // KV doesn't support listing efficiently
180        }
181    }
182
183    /// Count active sessions (SQLite only, returns 0 for KV)
184    pub async fn count(&self) -> Result<usize> {
185        match self {
186            Self::Sqlite(s) => s.count().await,
187            Self::CloudflareKv(_) => Ok(0), // KV doesn't support counting efficiently
188        }
189    }
190
191    /// Apply an atomic mutation directly at the storage level.
192    /// For SQLite, this uses SQL json_set/json_extract for atomicity.
193    /// For KV, falls back to read-modify-write (KV doesn't support atomic ops).
194    /// Returns the updated session data after the mutation.
195    pub async fn apply_mutation(
196        &self,
197        id: &str,
198        mutation: &AtomicMutation,
199    ) -> Result<HashMap<String, Value>> {
200        match self {
201            Self::Sqlite(s) => s.apply_atomic_mutation(id, mutation).await,
202            Self::CloudflareKv(s) => {
203                // KV fallback: read-modify-write (best effort, no true atomicity)
204                if let Some(mut session) = s.get(id).await? {
205                    apply_mutation_in_memory(&mut session.data, mutation);
206                    s.update(id, session.data.clone()).await?;
207                    Ok(session.data)
208                } else {
209                    Ok(HashMap::new())
210                }
211            }
212        }
213    }
214}
215
216impl Clone for SessionBackend {
217    fn clone(&self) -> Self {
218        match self {
219            Self::Sqlite(s) => Self::Sqlite(s.clone()),
220            Self::CloudflareKv(s) => Self::CloudflareKv(s.clone()),
221        }
222    }
223}
224
225// ---------------------------------------------------------------------------
226// SQLite backend
227// ---------------------------------------------------------------------------
228
229/// Session store backed by SQLite with connection pooling.
230/// Uses r2d2 for concurrent reads and spawn_blocking to avoid
231/// blocking the async runtime.
232#[derive(Clone)]
233pub struct SqliteSessionStore {
234    pool: Pool<SqliteConnectionManager>,
235    max_age: i64,
236}
237
238/// Connection customizer that sets WAL mode and busy timeout on each new connection.
239#[derive(Debug)]
240struct SessionCustomizer;
241
242impl r2d2::CustomizeConnection<Connection, rusqlite::Error> for SessionCustomizer {
243    fn on_acquire(&self, conn: &mut Connection) -> std::result::Result<(), rusqlite::Error> {
244        conn.execute_batch("PRAGMA busy_timeout=5000; PRAGMA synchronous=NORMAL;")?;
245        Ok(())
246    }
247}
248
249impl SqliteSessionStore {
250    /// Create a new session store with SQLite database.
251    /// Uses a connection pool (max 4 connections) with WAL mode.
252    pub fn new(db_path: impl AsRef<Path>, max_age_seconds: i64) -> Result<Self> {
253        let manager = SqliteConnectionManager::file(db_path);
254        let pool = Pool::builder()
255            .max_size(4)
256            .connection_customizer(Box::new(SessionCustomizer))
257            .build(manager)
258            .map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
259
260        // Set WAL mode once (requires exclusive lock, so do it before pool fills)
261        let conn = pool
262            .get()
263            .map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
264        conn.execute_batch("PRAGMA journal_mode=WAL;")?;
265        conn.execute(
266            "CREATE TABLE IF NOT EXISTS sessions (
267                id TEXT PRIMARY KEY,
268                data TEXT NOT NULL DEFAULT '{}',
269                created_at INTEGER NOT NULL,
270                expires_at INTEGER NOT NULL,
271                last_accessed INTEGER NOT NULL
272            )",
273            [],
274        )?;
275        conn.execute(
276            "CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at)",
277            [],
278        )?;
279
280        // Clean up expired sessions on startup
281        let now = Utc::now().timestamp();
282        let cleaned = conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
283        if cleaned > 0 {
284            tracing::info!("Cleaned up {} expired sessions", cleaned);
285        }
286
287        Ok(Self {
288            pool,
289            max_age: max_age_seconds,
290        })
291    }
292
293    /// Create a new in-memory session store (for testing).
294    /// Uses max_size=1 because in-memory DBs are per-connection.
295    pub fn in_memory(max_age_seconds: i64) -> Result<Self> {
296        let manager = SqliteConnectionManager::memory();
297        let pool = Pool::builder()
298            .max_size(1)
299            .build(manager)
300            .map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
301
302        let conn = pool
303            .get()
304            .map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
305        conn.execute(
306            "CREATE TABLE sessions (
307                id TEXT PRIMARY KEY,
308                data TEXT NOT NULL DEFAULT '{}',
309                created_at INTEGER NOT NULL,
310                expires_at INTEGER NOT NULL,
311                last_accessed INTEGER NOT NULL
312            )",
313            [],
314        )?;
315
316        Ok(Self {
317            pool,
318            max_age: max_age_seconds,
319        })
320    }
321
322    /// Create a new session
323    pub async fn create(&self) -> Result<Session> {
324        let pool = self.pool.clone();
325        let max_age = self.max_age;
326        tokio::task::spawn_blocking(move || {
327            let session = Session::new(max_age);
328            let conn = pool.get()
329                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
330            conn.execute(
331                "INSERT INTO sessions (id, data, created_at, expires_at, last_accessed) VALUES (?1, ?2, ?3, ?4, ?5)",
332                params![
333                    session.id,
334                    serde_json::to_string(&session.data)?,
335                    session.created_at.timestamp(),
336                    session.expires_at.timestamp(),
337                    session.last_accessed.timestamp(),
338                ],
339            )?;
340            Ok(session)
341        }).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
342    }
343
344    /// Get a session by ID
345    pub async fn get(&self, id: &str) -> Result<Option<Session>> {
346        let pool = self.pool.clone();
347        let id = id.to_string();
348        tokio::task::spawn_blocking(move || {
349            let conn = pool
350                .get()
351                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
352
353            let mut stmt = conn.prepare(
354                "SELECT id, data, created_at, expires_at, last_accessed FROM sessions WHERE id = ?1"
355            )?;
356
357            let session = match stmt.query_row(params![id], |row| {
358                let id: String = row.get(0)?;
359                let data_str: String = row.get(1)?;
360                let created_at: i64 = row.get(2)?;
361                let expires_at: i64 = row.get(3)?;
362                let last_accessed: i64 = row.get(4)?;
363
364                Ok(Session {
365                    id,
366                    data: serde_json::from_str(&data_str).unwrap_or_default(),
367                    created_at: DateTime::from_timestamp(created_at, 0).unwrap_or_else(Utc::now),
368                    expires_at: DateTime::from_timestamp(expires_at, 0).unwrap_or_else(Utc::now),
369                    last_accessed: DateTime::from_timestamp(last_accessed, 0)
370                        .unwrap_or_else(Utc::now),
371                })
372            }) {
373                Ok(s) => Some(s),
374                Err(rusqlite::Error::QueryReturnedNoRows) => None,
375                Err(e) => return Err(e.into()),
376            };
377
378            // Check if session is expired — delete inline with the same connection
379            match session {
380                Some(s) if s.is_expired() => {
381                    conn.execute("DELETE FROM sessions WHERE id = ?1", params![s.id])?;
382                    Ok(None)
383                }
384                s => Ok(s),
385            }
386        })
387        .await
388        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
389    }
390
391    /// Get or create a session
392    pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
393        if let Some(session_id) = id {
394            if let Some(session) = self.get(session_id).await? {
395                // Update last accessed time
396                self.touch(&session.id).await?;
397                return Ok(session);
398            }
399        }
400        self.create().await
401    }
402
403    /// Update session data
404    pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
405        let pool = self.pool.clone();
406        let id = id.to_string();
407        tokio::task::spawn_blocking(move || {
408            let conn = pool
409                .get()
410                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
411            let now = Utc::now().timestamp();
412            conn.execute(
413                "UPDATE sessions SET data = ?1, last_accessed = ?2 WHERE id = ?3",
414                params![serde_json::to_string(&data)?, now, id],
415            )?;
416            Ok(())
417        })
418        .await
419        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
420    }
421
422    /// Update last accessed time
423    pub async fn touch(&self, id: &str) -> Result<()> {
424        let pool = self.pool.clone();
425        let id = id.to_string();
426        tokio::task::spawn_blocking(move || {
427            let conn = pool
428                .get()
429                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
430            let now = Utc::now().timestamp();
431            conn.execute(
432                "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
433                params![now, id],
434            )?;
435            Ok(())
436        })
437        .await
438        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
439    }
440
441    /// Delete a session
442    pub async fn delete(&self, id: &str) -> Result<()> {
443        let pool = self.pool.clone();
444        let id = id.to_string();
445        tokio::task::spawn_blocking(move || {
446            let conn = pool
447                .get()
448                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
449            conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])?;
450            Ok(())
451        })
452        .await
453        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
454    }
455
456    /// Clean up expired sessions
457    pub async fn cleanup_expired(&self) -> Result<u64> {
458        let pool = self.pool.clone();
459        tokio::task::spawn_blocking(move || {
460            let conn = pool
461                .get()
462                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
463            let now = Utc::now().timestamp();
464            let deleted =
465                conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
466            Ok(deleted as u64)
467        })
468        .await
469        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
470    }
471
472    /// List all active (non-expired) session IDs
473    pub async fn list_session_ids(&self) -> Result<Vec<String>> {
474        let pool = self.pool.clone();
475        tokio::task::spawn_blocking(move || {
476            let conn = pool
477                .get()
478                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
479            let now = Utc::now().timestamp();
480            let mut stmt = conn.prepare(
481                "SELECT id FROM sessions WHERE expires_at > ?1 ORDER BY last_accessed DESC",
482            )?;
483            let ids: Vec<String> = stmt
484                .query_map(params![now], |row| row.get(0))?
485                .filter_map(|r| r.ok())
486                .collect();
487            Ok(ids)
488        })
489        .await
490        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
491    }
492
493    /// Get count of active sessions
494    pub async fn count(&self) -> Result<usize> {
495        let pool = self.pool.clone();
496        tokio::task::spawn_blocking(move || {
497            let conn = pool
498                .get()
499                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
500            let now = Utc::now().timestamp();
501            let count: i64 = conn.query_row(
502                "SELECT COUNT(*) FROM sessions WHERE expires_at > ?1",
503                params![now],
504                |row| row.get(0),
505            )?;
506            Ok(count as usize)
507        })
508        .await
509        .map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
510    }
511
512    /// Apply a single atomic mutation directly in SQL using json_set/json_extract.
513    /// Returns the updated session data after the mutation.
514    pub async fn apply_atomic_mutation(
515        &self,
516        id: &str,
517        mutation: &AtomicMutation,
518    ) -> Result<HashMap<String, Value>> {
519        let pool = self.pool.clone();
520        let id = id.to_string();
521        let mutation = mutation.clone();
522        tokio::task::spawn_blocking(move || {
523            let conn = pool.get()
524                .map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
525            let now = Utc::now().timestamp();
526
527            match &mutation {
528                AtomicMutation::Increment { key, value } => {
529                    let path = format!("$.{}", key);
530                    conn.execute(
531                        "UPDATE sessions SET data = json_set(data, ?1, COALESCE(json_extract(data, ?1), 0) + ?2), last_accessed = ?3 WHERE id = ?4",
532                        params![path, value, now, id],
533                    )?;
534                }
535                AtomicMutation::Set { key, value } => {
536                    let path = format!("$.{}", key);
537                    let json_str = serde_json::to_string(value).unwrap_or_default();
538                    conn.execute(
539                        "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
540                        params![path, json_str, now, id],
541                    )?;
542                }
543                AtomicMutation::Push { key, value } => {
544                    let path = format!("$.{}", key);
545                    let json_val = serde_json::to_string(value).unwrap_or_default();
546                    conn.execute(
547                        "UPDATE sessions SET data = json_set(data, ?1, \
548                         CASE WHEN json_extract(data, ?1) IS NULL THEN json_array(json(?2)) \
549                         ELSE json_insert(json_extract(data, ?1), '$[#]', json(?2)) END \
550                         ), last_accessed = ?3 WHERE id = ?4",
551                        params![path, json_val, now, id],
552                    )?;
553                }
554                AtomicMutation::PushMax { key, max, value } => {
555                    let path = format!("$.{}", key);
556                    // Read current array, push, trim oldest, write back — all within one connection
557                    let current: String = conn.query_row(
558                        "SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
559                        params![path, id],
560                        |row| row.get(0),
561                    ).unwrap_or_else(|_| "[]".to_string());
562                    let mut arr: Vec<Value> = serde_json::from_str(&current).unwrap_or_default();
563                    arr.push(value.clone());
564                    while arr.len() > *max {
565                        arr.remove(0);
566                    }
567                    let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
568                    conn.execute(
569                        "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
570                        params![path, new_arr, now, id],
571                    )?;
572                }
573                AtomicMutation::Unshift { key, value } => {
574                    let path = format!("$.{}", key);
575                    let current: String = conn.query_row(
576                        "SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
577                        params![path, id],
578                        |row| row.get(0),
579                    ).unwrap_or_else(|_| "[]".to_string());
580                    let mut arr: Vec<Value> = serde_json::from_str(&current).unwrap_or_default();
581                    arr.insert(0, value.clone());
582                    let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
583                    conn.execute(
584                        "UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
585                        params![path, new_arr, now, id],
586                    )?;
587                }
588                AtomicMutation::Clear { key } => {
589                    let path = format!("$.{}", key);
590                    conn.execute(
591                        "UPDATE sessions SET data = json_set(data, ?1, json_array()), last_accessed = ?2 WHERE id = ?3",
592                        params![path, now, id],
593                    )?;
594                }
595            }
596
597            // Re-read the session data to return updated state
598            let data_str: String = conn.query_row(
599                "SELECT data FROM sessions WHERE id = ?1",
600                params![id],
601                |row| row.get(0),
602            )?;
603
604            let data: HashMap<String, Value> = serde_json::from_str(&data_str).unwrap_or_default();
605            Ok(data)
606        }).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
607    }
608}
609
610/// Atomic session mutation operation — used by the server to apply mutations at the SQL level
611#[derive(Debug, Clone)]
612pub enum AtomicMutation {
613    /// Increment a numeric value: key += value (value can be negative for decrement)
614    Increment { key: String, value: i64 },
615    /// Set a value: key = value
616    Set { key: String, value: Value },
617    /// Push to end of array
618    Push { key: String, value: Value },
619    /// Push to end of array with max size (drops oldest)
620    PushMax {
621        key: String,
622        max: usize,
623        value: Value,
624    },
625    /// Unshift (prepend) to array
626    Unshift { key: String, value: Value },
627    /// Clear array to empty
628    Clear { key: String },
629}
630
631/// Apply a mutation to in-memory session data (used for KV fallback and testing).
632pub fn apply_mutation_in_memory(data: &mut HashMap<String, Value>, mutation: &AtomicMutation) {
633    match mutation {
634        AtomicMutation::Increment { key, value } => {
635            let current = data.get(key).and_then(|v| v.as_i64()).unwrap_or(0);
636            data.insert(key.clone(), serde_json::json!(current + value));
637        }
638        AtomicMutation::Set { key, value } => {
639            data.insert(key.clone(), value.clone());
640        }
641        AtomicMutation::Push { key, value } => {
642            let arr = data
643                .entry(key.clone())
644                .or_insert_with(|| serde_json::json!([]));
645            if let Some(arr) = arr.as_array_mut() {
646                arr.push(value.clone());
647            }
648        }
649        AtomicMutation::PushMax { key, max, value } => {
650            let arr = data
651                .entry(key.clone())
652                .or_insert_with(|| serde_json::json!([]));
653            if let Some(arr) = arr.as_array_mut() {
654                arr.push(value.clone());
655                while arr.len() > *max {
656                    arr.remove(0);
657                }
658            }
659        }
660        AtomicMutation::Unshift { key, value } => {
661            let arr = data
662                .entry(key.clone())
663                .or_insert_with(|| serde_json::json!([]));
664            if let Some(arr) = arr.as_array_mut() {
665                arr.insert(0, value.clone());
666            }
667        }
668        AtomicMutation::Clear { key } => {
669            data.insert(key.clone(), serde_json::json!([]));
670        }
671    }
672}
673
674// ---------------------------------------------------------------------------
675// Cloudflare Workers KV backend
676// ---------------------------------------------------------------------------
677
/// Session store that keeps sessions in Cloudflare Workers KV, accessed
/// over the Cloudflare REST API.
#[derive(Clone)]
pub struct KvSessionStore {
    /// Cloudflare account ID used in the API URL.
    account_id: String,
    /// KV namespace ID used in the API URL.
    namespace_id: String,
    /// Token sent as the bearer credential on every request.
    api_token: String,
    /// Lifetime, in seconds, given to new sessions and used as the KV TTL.
    max_age: i64,
}
686
687impl KvSessionStore {
688    /// Create a new KV session store
689    pub fn new(config: &CloudflareKvConfig, max_age_seconds: i64) -> Self {
690        Self {
691            account_id: config.account_id.clone(),
692            namespace_id: config.namespace_id.clone(),
693            api_token: config.api_token.clone(),
694            max_age: max_age_seconds,
695        }
696    }
697
698    /// Base URL for KV REST API
699    fn base_url(&self) -> String {
700        format!(
701            "https://api.cloudflare.com/client/v4/accounts/{}/storage/kv/namespaces/{}",
702            self.account_id, self.namespace_id
703        )
704    }
705
706    /// KV key for a session ID
707    fn key(&self, session_id: &str) -> String {
708        format!("session:{}", session_id)
709    }
710
711    /// Get the shared HTTP client
712    fn client() -> &'static reqwest::Client {
713        use std::sync::OnceLock;
714        static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
715        CLIENT.get_or_init(|| {
716            crate::http_client::build_http_client(Some(std::time::Duration::from_secs(10)))
717                .expect("failed to build Cloudflare KV HTTP client")
718        })
719    }
720
721    /// Create a new session
722    pub async fn create(&self) -> Result<Session> {
723        let session = Session::new(self.max_age);
724        self.put_session(&session).await?;
725        Ok(session)
726    }
727
728    /// Get a session by ID
729    pub async fn get(&self, id: &str) -> Result<Option<Session>> {
730        let url = format!("{}/values/{}", self.base_url(), self.key(id));
731
732        let response = Self::client()
733            .get(&url)
734            .bearer_auth(&self.api_token)
735            .send()
736            .await
737            .map_err(|e| crate::Error::Session(format!("KV read failed: {}", e)))?;
738
739        if response.status() == reqwest::StatusCode::NOT_FOUND {
740            return Ok(None);
741        }
742
743        if !response.status().is_success() {
744            return Err(crate::Error::Session(format!(
745                "KV read error: HTTP {}",
746                response.status()
747            )));
748        }
749
750        let body = response
751            .text()
752            .await
753            .map_err(|e| crate::Error::Session(format!("KV read body failed: {}", e)))?;
754
755        match serde_json::from_str::<Session>(&body) {
756            Ok(session) if session.is_expired() => {
757                // Shouldn't normally happen (KV TTL should handle this), but be safe
758                let _ = self.delete(&session.id).await;
759                Ok(None)
760            }
761            Ok(session) => Ok(Some(session)),
762            Err(e) => {
763                tracing::warn!("KV session deserialize failed: {}", e);
764                Ok(None)
765            }
766        }
767    }
768
769    /// Get or create a session
770    pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
771        if let Some(session_id) = id {
772            if let Some(session) = self.get(session_id).await? {
773                self.touch(&session.id).await?;
774                return Ok(session);
775            }
776        }
777        self.create().await
778    }
779
780    /// Update session data
781    pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
782        // Read current session, update data, write back
783        if let Some(mut session) = self.get(id).await? {
784            session.data = data;
785            session.last_accessed = Utc::now();
786            self.put_session(&session).await?;
787        }
788        Ok(())
789    }
790
791    /// Update last accessed time (read + re-PUT to refresh TTL)
792    pub async fn touch(&self, id: &str) -> Result<()> {
793        if let Some(mut session) = self.get(id).await? {
794            session.last_accessed = Utc::now();
795            self.put_session(&session).await?;
796        }
797        Ok(())
798    }
799
800    /// Delete a session
801    pub async fn delete(&self, id: &str) -> Result<()> {
802        let url = format!("{}/values/{}", self.base_url(), self.key(id));
803
804        Self::client()
805            .delete(&url)
806            .bearer_auth(&self.api_token)
807            .send()
808            .await
809            .map_err(|e| crate::Error::Session(format!("KV delete failed: {}", e)))?;
810
811        Ok(())
812    }
813
814    /// Write a session to KV with TTL
815    async fn put_session(&self, session: &Session) -> Result<()> {
816        let url = format!(
817            "{}/values/{}?expiration_ttl={}",
818            self.base_url(),
819            self.key(&session.id),
820            self.max_age
821        );
822
823        let body = serde_json::to_string(session)
824            .map_err(|e| crate::Error::Session(format!("KV serialize failed: {}", e)))?;
825
826        let response = Self::client()
827            .put(&url)
828            .bearer_auth(&self.api_token)
829            .header("Content-Type", "application/json")
830            .body(body)
831            .send()
832            .await
833            .map_err(|e| crate::Error::Session(format!("KV write failed: {}", e)))?;
834
835        if !response.status().is_success() {
836            let status = response.status();
837            let body = response.text().await.unwrap_or_default();
838            return Err(crate::Error::Session(format!(
839                "KV write error: HTTP {} — {}",
840                status, body
841            )));
842        }
843
844        Ok(())
845    }
846}
847
848// ---------------------------------------------------------------------------
849// Cookie utilities
850// ---------------------------------------------------------------------------
851
/// Extract the value of the cookie named `cookie_name` from a `Cookie` header.
///
/// The header is split on `;`, each pair is trimmed, and the first pair that
/// starts with `cookie_name` immediately followed by `=` wins. Everything
/// after that `=` is returned verbatim (it may be empty and may itself
/// contain `=`).
///
/// Returns `None` when `cookie_header` is `None` or no pair matches. A cookie
/// whose name merely begins with `cookie_name` (e.g. `w_session_old`) does
/// not match.
pub fn parse_session_cookie(cookie_header: Option<&str>, cookie_name: &str) -> Option<String> {
    cookie_header?
        .split(';')
        // `strip_prefix` both tests and removes the "name=" prefix, so no
        // per-pair `format!` allocation or manual byte slicing is needed.
        .find_map(|pair| pair.trim().strip_prefix(cookie_name)?.strip_prefix('='))
        .map(str::to_string)
}
862
/// Render the `Set-Cookie` header value for a session cookie.
///
/// The cookie is always emitted with `HttpOnly`, `SameSite=Strict`, `Path=/`
/// and `Max-Age=<max_age>`; `; Secure` is appended only when `secure` is true.
/// `cookie_name` and `session_id` are inserted as given, without escaping.
pub fn build_session_cookie(
    session_id: &str,
    cookie_name: &str,
    max_age: i64,
    secure: bool,
) -> String {
    let base = format!(
        "{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
        cookie_name, session_id, max_age
    );

    if secure {
        base + "; Secure"
    } else {
        base
    }
}
881
882#[cfg(test)]
883mod tests {
884    use super::*;
885
886    #[test]
887    fn test_generate_session_id() {
888        let id = generate_session_id();
889        assert_eq!(id.len(), 128); // 64 bytes = 128 hex chars
890        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
891    }
892
893    #[tokio::test]
894    async fn test_session_store() {
895        let store = SqliteSessionStore::in_memory(3600).unwrap();
896
897        // Create session
898        let session = store.create().await.unwrap();
899        assert_eq!(session.id.len(), 128);
900
901        // Get session
902        let retrieved = store.get(&session.id).await.unwrap();
903        assert!(retrieved.is_some());
904        assert_eq!(retrieved.unwrap().id, session.id);
905
906        // Delete session
907        store.delete(&session.id).await.unwrap();
908        let deleted = store.get(&session.id).await.unwrap();
909        assert!(deleted.is_none());
910    }
911
912    #[test]
913    fn test_parse_session_cookie() {
914        let header = "w_session=abc123; other=value";
915        let result = parse_session_cookie(Some(header), "w_session");
916        assert_eq!(result, Some("abc123".to_string()));
917
918        let result = parse_session_cookie(Some(header), "missing");
919        assert_eq!(result, None);
920    }
921
922    #[test]
923    fn test_kv_key_format() {
924        let config = CloudflareKvConfig {
925            account_id: "acc123".to_string(),
926            namespace_id: "ns456".to_string(),
927            api_token: "token789".to_string(),
928        };
929        let store = KvSessionStore::new(&config, 3600);
930        assert_eq!(store.key("abc123"), "session:abc123");
931    }
932
933    #[test]
934    fn test_kv_base_url() {
935        let config = CloudflareKvConfig {
936            account_id: "acc123".to_string(),
937            namespace_id: "ns456".to_string(),
938            api_token: "token789".to_string(),
939        };
940        let store = KvSessionStore::new(&config, 3600);
941        assert_eq!(
942            store.base_url(),
943            "https://api.cloudflare.com/client/v4/accounts/acc123/storage/kv/namespaces/ns456"
944        );
945    }
946
947    #[test]
948    fn test_session_serialization_roundtrip() {
949        let session = Session::new(3600);
950        let json = serde_json::to_string(&session).unwrap();
951        let deserialized: Session = serde_json::from_str(&json).unwrap();
952        assert_eq!(deserialized.id, session.id);
953        // Session now includes _csrf_token by default
954        assert_eq!(deserialized.data.len(), 1);
955        assert!(deserialized.data.contains_key(CSRF_TOKEN_KEY));
956    }
957
958    #[tokio::test]
959    async fn test_atomic_increment() {
960        let store = SqliteSessionStore::in_memory(3600).unwrap();
961        let session = store.create().await.unwrap();
962
963        // Increment from zero
964        let data = store
965            .apply_atomic_mutation(
966                &session.id,
967                &AtomicMutation::Increment {
968                    key: "counter".to_string(),
969                    value: 1,
970                },
971            )
972            .await
973            .unwrap();
974        assert_eq!(data.get("counter").and_then(|v| v.as_i64()), Some(1));
975
976        // Increment again
977        let data = store
978            .apply_atomic_mutation(
979                &session.id,
980                &AtomicMutation::Increment {
981                    key: "counter".to_string(),
982                    value: 5,
983                },
984            )
985            .await
986            .unwrap();
987        assert_eq!(data.get("counter").and_then(|v| v.as_i64()), Some(6));
988
989        // Decrement (negative increment)
990        let data = store
991            .apply_atomic_mutation(
992                &session.id,
993                &AtomicMutation::Increment {
994                    key: "counter".to_string(),
995                    value: -2,
996                },
997            )
998            .await
999            .unwrap();
1000        assert_eq!(data.get("counter").and_then(|v| v.as_i64()), Some(4));
1001    }
1002
1003    #[tokio::test]
1004    async fn test_atomic_set() {
1005        let store = SqliteSessionStore::in_memory(3600).unwrap();
1006        let session = store.create().await.unwrap();
1007
1008        let data = store
1009            .apply_atomic_mutation(
1010                &session.id,
1011                &AtomicMutation::Set {
1012                    key: "name".to_string(),
1013                    value: serde_json::json!("Alice"),
1014                },
1015            )
1016            .await
1017            .unwrap();
1018        assert_eq!(data.get("name").and_then(|v| v.as_str()), Some("Alice"));
1019
1020        // Overwrite
1021        let data = store
1022            .apply_atomic_mutation(
1023                &session.id,
1024                &AtomicMutation::Set {
1025                    key: "name".to_string(),
1026                    value: serde_json::json!("Bob"),
1027                },
1028            )
1029            .await
1030            .unwrap();
1031        assert_eq!(data.get("name").and_then(|v| v.as_str()), Some("Bob"));
1032    }
1033
1034    #[tokio::test]
1035    async fn test_atomic_push() {
1036        let store = SqliteSessionStore::in_memory(3600).unwrap();
1037        let session = store.create().await.unwrap();
1038
1039        // Push to non-existent array creates it
1040        let data = store
1041            .apply_atomic_mutation(
1042                &session.id,
1043                &AtomicMutation::Push {
1044                    key: "items".to_string(),
1045                    value: serde_json::json!("first"),
1046                },
1047            )
1048            .await
1049            .unwrap();
1050        let items = data.get("items").and_then(|v| v.as_array()).unwrap();
1051        assert_eq!(items.len(), 1);
1052        assert_eq!(items[0].as_str(), Some("first"));
1053
1054        // Push another
1055        let data = store
1056            .apply_atomic_mutation(
1057                &session.id,
1058                &AtomicMutation::Push {
1059                    key: "items".to_string(),
1060                    value: serde_json::json!("second"),
1061                },
1062            )
1063            .await
1064            .unwrap();
1065        let items = data.get("items").and_then(|v| v.as_array()).unwrap();
1066        assert_eq!(items.len(), 2);
1067        assert_eq!(items[1].as_str(), Some("second"));
1068    }
1069
1070    #[tokio::test]
1071    async fn test_atomic_push_max() {
1072        let store = SqliteSessionStore::in_memory(3600).unwrap();
1073        let session = store.create().await.unwrap();
1074
1075        // Push 3 items with max 2
1076        for i in 1..=3 {
1077            store
1078                .apply_atomic_mutation(
1079                    &session.id,
1080                    &AtomicMutation::PushMax {
1081                        key: "log".to_string(),
1082                        max: 2,
1083                        value: serde_json::json!(i),
1084                    },
1085                )
1086                .await
1087                .unwrap();
1088        }
1089
1090        let data = store.get(&session.id).await.unwrap().unwrap();
1091        let log = data.data.get("log").and_then(|v| v.as_array()).unwrap();
1092        assert_eq!(log.len(), 2);
1093        // Should have items 2 and 3 (item 1 was dropped)
1094        assert_eq!(log[0].as_i64(), Some(2));
1095        assert_eq!(log[1].as_i64(), Some(3));
1096    }
1097
1098    #[tokio::test]
1099    async fn test_atomic_unshift() {
1100        let store = SqliteSessionStore::in_memory(3600).unwrap();
1101        let session = store.create().await.unwrap();
1102
1103        store
1104            .apply_atomic_mutation(
1105                &session.id,
1106                &AtomicMutation::Unshift {
1107                    key: "stack".to_string(),
1108                    value: serde_json::json!("first"),
1109                },
1110            )
1111            .await
1112            .unwrap();
1113        let data = store
1114            .apply_atomic_mutation(
1115                &session.id,
1116                &AtomicMutation::Unshift {
1117                    key: "stack".to_string(),
1118                    value: serde_json::json!("second"),
1119                },
1120            )
1121            .await
1122            .unwrap();
1123
1124        let stack = data.get("stack").and_then(|v| v.as_array()).unwrap();
1125        assert_eq!(stack.len(), 2);
1126        assert_eq!(stack[0].as_str(), Some("second"));
1127        assert_eq!(stack[1].as_str(), Some("first"));
1128    }
1129
1130    #[tokio::test]
1131    async fn test_atomic_clear() {
1132        let store = SqliteSessionStore::in_memory(3600).unwrap();
1133        let session = store.create().await.unwrap();
1134
1135        // Add some items first
1136        store
1137            .apply_atomic_mutation(
1138                &session.id,
1139                &AtomicMutation::Push {
1140                    key: "items".to_string(),
1141                    value: serde_json::json!("a"),
1142                },
1143            )
1144            .await
1145            .unwrap();
1146        store
1147            .apply_atomic_mutation(
1148                &session.id,
1149                &AtomicMutation::Push {
1150                    key: "items".to_string(),
1151                    value: serde_json::json!("b"),
1152                },
1153            )
1154            .await
1155            .unwrap();
1156
1157        // Clear
1158        let data = store
1159            .apply_atomic_mutation(
1160                &session.id,
1161                &AtomicMutation::Clear {
1162                    key: "items".to_string(),
1163                },
1164            )
1165            .await
1166            .unwrap();
1167        let items = data.get("items").and_then(|v| v.as_array()).unwrap();
1168        assert_eq!(items.len(), 0);
1169    }
1170
1171    #[test]
1172    fn test_apply_mutation_in_memory() {
1173        let mut data = HashMap::new();
1174
1175        apply_mutation_in_memory(
1176            &mut data,
1177            &AtomicMutation::Increment {
1178                key: "x".to_string(),
1179                value: 3,
1180            },
1181        );
1182        assert_eq!(data.get("x").and_then(|v| v.as_i64()), Some(3));
1183
1184        apply_mutation_in_memory(
1185            &mut data,
1186            &AtomicMutation::Set {
1187                key: "name".to_string(),
1188                value: serde_json::json!("test"),
1189            },
1190        );
1191        assert_eq!(data.get("name").and_then(|v| v.as_str()), Some("test"));
1192
1193        apply_mutation_in_memory(
1194            &mut data,
1195            &AtomicMutation::Push {
1196                key: "list".to_string(),
1197                value: serde_json::json!(1),
1198            },
1199        );
1200        apply_mutation_in_memory(
1201            &mut data,
1202            &AtomicMutation::Push {
1203                key: "list".to_string(),
1204                value: serde_json::json!(2),
1205            },
1206        );
1207        let list = data.get("list").and_then(|v| v.as_array()).unwrap();
1208        assert_eq!(list.len(), 2);
1209    }
1210}