rustbasic_core/
session_manager.rs1use axum_session::{DatabasePool, DatabaseError};
2use async_trait::async_trait;
3use sqlx::AnyPool;
4use dashmap::DashMap;
5use once_cell::sync::Lazy;
6
7pub static IP_TRACKER: Lazy<DashMap<String, String>> = Lazy::new(DashMap::new);
10
11#[derive(Clone, Debug)]
12pub struct RustBasicSessionStore {
13 pub pool: AnyPool,
14}
15
16impl RustBasicSessionStore {
17 pub fn new(pool: AnyPool) -> Self {
18 Self { pool }
19 }
20
21 async fn get_placeholder_query(&self, sql: &str) -> String {
22 let is_mysql = if let Ok(conn) = self.pool.acquire().await {
23 conn.backend_name() == "MySQL"
24 } else {
25 false
26 };
27
28 if is_mysql {
29 let re = regex::Regex::new(r"\$\d+").unwrap();
30 re.replace_all(sql, "?").into_owned()
31 } else {
32 sql.to_string()
33 }
34 }
35}
36
37#[async_trait]
38impl DatabasePool for RustBasicSessionStore {
39 async fn initiate(&self, _table_name: &str) -> Result<(), DatabaseError> {
40 Ok(())
41 }
42
43 async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
44 let raw_query = format!("DELETE FROM {} WHERE id = $1", table_name);
45 let query = self.get_placeholder_query(&raw_query).await;
46 sqlx::query(&query)
47 .bind(id)
48 .execute(&self.pool)
49 .await
50 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
51
52 IP_TRACKER.remove(id);
54
55 Ok(())
56 }
57
58 async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
59 let raw_query = format!("SELECT payload FROM {} WHERE id = $1 AND last_activity > $2", table_name);
60 let query = self.get_placeholder_query(&raw_query).await;
61 let now = chrono::Utc::now().timestamp();
62
63 let row: Option<(String,)> = sqlx::query_as(&query)
64 .bind(id)
65 .bind(now)
66 .fetch_optional(&self.pool)
67 .await
68 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
69
70 Ok(row.map(|r| r.0))
71 }
72
73 async fn store(&self, id: &str, session: &str, expires: i64, table_name: &str) -> Result<(), DatabaseError> {
74 let ip = IP_TRACKER.get(id).map(|i| i.clone()).unwrap_or_else(|| "unknown".to_string());
76
77 let raw_delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
78 let delete_query = self.get_placeholder_query(&raw_delete_query).await;
79 sqlx::query(&delete_query).bind(id).execute(&self.pool).await.ok();
80
81 let raw_insert_query = format!(
82 "INSERT INTO {} (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)",
83 table_name
84 );
85 let insert_query = self.get_placeholder_query(&raw_insert_query).await;
86
87 sqlx::query(&insert_query)
88 .bind(id)
89 .bind(session)
90 .bind(expires)
91 .bind(ip)
92 .execute(&self.pool)
93 .await
94 .map_err(|e| DatabaseError::GenericInsertError(e.to_string()))?;
95 Ok(())
96 }
97
98 async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
99 let now = chrono::Utc::now().timestamp();
100 let raw_select_query = format!("SELECT id FROM {} WHERE last_activity < $1", table_name);
101 let select_query = self.get_placeholder_query(&raw_select_query).await;
102 let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&select_query)
103 .bind(now)
104 .fetch_all(&self.pool)
105 .await
106 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
107 .into_iter()
108 .map(|r| r.0)
109 .collect();
110
111 for id in &ids {
113 IP_TRACKER.remove(id);
114 }
115
116 let raw_delete_query = format!("DELETE FROM {} WHERE last_activity < $1", table_name);
117 let delete_query = self.get_placeholder_query(&raw_delete_query).await;
118 sqlx::query(&delete_query)
119 .bind(now)
120 .execute(&self.pool)
121 .await
122 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
123
124 Ok(ids)
125 }
126
127 async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
128 let query = format!("SELECT COUNT(*) FROM {}", table_name);
129 let count: (i64,) = sqlx::query_as(&query)
130 .fetch_one(&self.pool)
131 .await
132 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
133 Ok(count.0)
134 }
135
136 async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
137 let raw_query = format!("SELECT id FROM {} WHERE id = $1", table_name);
138 let query = self.get_placeholder_query(&raw_query).await;
139 let row: Option<(String,)> = sqlx::query_as(&query)
140 .bind(id)
141 .fetch_optional(&self.pool)
142 .await
143 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
144 Ok(row.is_some())
145 }
146
147 async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
148 let query = format!("DELETE FROM {}", table_name);
149 sqlx::query(&query)
150 .execute(&self.pool)
151 .await
152 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
153
154 IP_TRACKER.clear();
155
156 Ok(())
157 }
158
159 async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
160 let query = format!("SELECT id FROM {}", table_name);
161 let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&query)
162 .fetch_all(&self.pool)
163 .await
164 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
165 .into_iter()
166 .map(|r| r.0)
167 .collect();
168 Ok(ids)
169 }
170
171 fn auto_handles_expiry(&self) -> bool {
172 false
173 }
174}