// rustbasic_core/session_manager.rs
use sqlx::AnyPool;

/// Rewrites PostgreSQL-style positional placeholders (`$1`, `$2`, ...) into
/// MySQL-style `?` placeholders.
///
/// Every `$` immediately followed by one or more ASCII digits is replaced by a
/// single `?`. A `$` not followed by a digit is copied through unchanged.
///
/// Text inside single-quoted, double-quoted or backtick-quoted spans is copied
/// verbatim, so a literal such as `'$1'` is not mistaken for a placeholder.
/// Inside `'...'` and `"..."` a backslash escapes the next character (MySQL's
/// default string syntax). A doubled quote (`''`) needs no special handling,
/// because it simply closes and immediately reopens the span.
///
/// Note: `?` placeholders are purely positional. The result is only correct
/// when the `$n` placeholders appear in ascending order and each is used once
/// (`$1 ... $2 ...`).
fn replace_postgres_placeholders(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // The quote character of the span we are currently inside, if any.
    let mut quote: Option<char> = None;

    while let Some(ch) = chars.next() {
        match quote {
            Some(q) => {
                result.push(ch);
                if ch == '\\' && q != '`' {
                    // Copy the escaped character verbatim so `\'` does not end the span.
                    if let Some(escaped) = chars.next() {
                        result.push(escaped);
                    }
                } else if ch == q {
                    quote = None;
                }
            }
            None => match ch {
                '\'' | '"' | '`' => {
                    quote = Some(ch);
                    result.push(ch);
                }
                '$' if matches!(chars.peek(), Some(c) if c.is_ascii_digit()) => {
                    // Swallow the whole digit run (`$12` is one placeholder).
                    while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
                    result.push('?');
                }
                _ => result.push(ch),
            },
        }
    }

    result
}

29#[derive(Clone, Debug)]
30pub struct RustBasicSessionStore {
31 pub pool: AnyPool,
32}
33
34impl RustBasicSessionStore {
35 pub fn new(pool: AnyPool) -> Self {
36 Self { pool }
37 }
38
39 async fn get_placeholder_query(&self, sql: &str) -> String {
40 let is_mysql = if let Ok(conn) = self.pool.acquire().await {
41 conn.backend_name() == "MySQL"
42 } else {
43 false
44 };
45
46 if is_mysql {
47 replace_postgres_placeholders(sql)
48 } else {
49 sql.to_string()
50 }
51 }
52
53 pub async fn load(&self, id: &str) -> Option<String> {
54 let raw_query = "SELECT payload FROM sessions WHERE id = $1 AND last_activity > $2";
55 let query = self.get_placeholder_query(raw_query).await;
56 let now = chrono::Utc::now().timestamp();
57
58 use sqlx::Row;
59 let row_opt = sqlx::query(&query)
60 .bind(id)
61 .bind(now)
62 .fetch_optional(&self.pool)
63 .await
64 .ok()
65 .flatten();
66
67 if let Some(row) = row_opt {
68 if let Ok(s) = row.try_get::<String, _>(0) {
69 return Some(s);
70 }
71 if let Ok(bytes) = row.try_get::<Vec<u8>, _>(0) {
72 if let Ok(s) = String::from_utf8(bytes) {
73 return Some(s);
74 }
75 }
76 }
77 None
78 }
79
80 pub async fn store(&self, id: &str, session_json: &str, ip: &str) {
81 let raw_delete_query = "DELETE FROM sessions WHERE id = $1";
82 let delete_query = self.get_placeholder_query(raw_delete_query).await;
83 let _ = sqlx::query(&delete_query).bind(id).execute(&self.pool).await;
84
85 let raw_insert_query = "INSERT INTO sessions (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)";
86 let insert_query = self.get_placeholder_query(raw_insert_query).await;
87 let expires = chrono::Utc::now().timestamp() + 14 * 24 * 60 * 60; let _ = sqlx::query(&insert_query)
90 .bind(id)
91 .bind(session_json)
92 .bind(expires)
93 .bind(ip)
94 .execute(&self.pool)
95 .await;
96 }
97}