rustbasic_core/
session_manager.rs1use sqlx::AnyPool;
2
/// Rewrites PostgreSQL-style numbered placeholders (`$1`, `$2`, ...) into the
/// positional `?` placeholders used by MySQL.
///
/// A `$` that is not immediately followed by an ASCII digit is copied through
/// unchanged. Text inside single-quoted or double-quoted literals and inside
/// backtick-quoted identifiers is copied verbatim, so a literal such as
/// `'costs $5'` is not mistaken for a bind parameter.
///
/// Limitations:
/// * The placeholder *number* is discarded, so the result is only correct when
///   the placeholders appear in ascending order and each one is used once.
/// * SQL comments are not parsed; a quote character inside a comment is treated
///   as the start of a quoted section.
fn replace_postgres_placeholders(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // `Some(delimiter)` while the scanner is inside a quoted literal/identifier.
    let mut open_quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        if let Some(delimiter) = open_quote {
            result.push(ch);
            if ch == '\\' && delimiter != '`' {
                // MySQL's default SQL mode treats backslash as an escape inside
                // string literals, so the next character can never close the
                // literal. Backtick identifiers have no backslash escapes.
                if let Some(escaped) = chars.next() {
                    result.push(escaped);
                }
            } else if ch == delimiter {
                // A doubled delimiter (`''`) closes here and re-opens on the
                // next character, which leaves the quoted text untouched.
                open_quote = None;
            }
            continue;
        }

        match ch {
            '\'' | '"' | '`' => {
                open_quote = Some(ch);
                result.push(ch);
            }
            '$' if matches!(chars.peek(), Some(c) if c.is_ascii_digit()) => {
                // Swallow the whole placeholder number, then emit one `?`.
                while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
                result.push('?');
            }
            _ => result.push(ch),
        }
    }

    result
}
28
29#[derive(Clone, Debug)]
30pub struct RustBasicSessionStore {
31 pub pool: AnyPool,
32}
33
34impl RustBasicSessionStore {
35 pub fn new(pool: AnyPool) -> Self {
36 Self { pool }
37 }
38
39 async fn get_placeholder_query(&self, sql: &str) -> String {
40 let is_mysql = if let Ok(conn) = self.pool.acquire().await {
41 conn.backend_name() == "MySQL"
42 } else {
43 false
44 };
45
46 if is_mysql {
47 replace_postgres_placeholders(sql)
48 } else {
49 sql.to_string()
50 }
51 }
52
53 pub async fn load(&self, id: &str) -> Option<String> {
54 let raw_query = "SELECT payload FROM sessions WHERE id = $1 AND last_activity > $2";
55 let query = self.get_placeholder_query(raw_query).await;
56 let now = chrono::Utc::now().timestamp();
57
58 let row: Option<(String,)> = sqlx::query_as(&query)
59 .bind(id)
60 .bind(now)
61 .fetch_optional(&self.pool)
62 .await
63 .ok()
64 .flatten();
65
66 row.map(|r| r.0)
67 }
68
69 pub async fn store(&self, id: &str, session_json: &str, ip: &str) {
70 let raw_delete_query = "DELETE FROM sessions WHERE id = $1";
71 let delete_query = self.get_placeholder_query(raw_delete_query).await;
72 let _ = sqlx::query(&delete_query).bind(id).execute(&self.pool).await;
73
74 let raw_insert_query = "INSERT INTO sessions (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)";
75 let insert_query = self.get_placeholder_query(raw_insert_query).await;
76 let expires = chrono::Utc::now().timestamp() + 14 * 24 * 60 * 60; let _ = sqlx::query(&insert_query)
79 .bind(id)
80 .bind(session_json)
81 .bind(expires)
82 .bind(ip)
83 .execute(&self.pool)
84 .await;
85 }
86}