// rustbasic_core/session_manager.rs
use async_trait::async_trait;
use axum_session::{DatabaseError, DatabasePool};
use dashmap::DashMap;
use once_cell::sync::Lazy;
use sqlx::AnyPool;

7pub static IP_TRACKER: Lazy<DashMap<String, String>> = Lazy::new(DashMap::new);
10
11#[derive(Clone, Debug)]
12pub struct RustBasicSessionStore {
13 pub pool: AnyPool,
14}
15
16impl RustBasicSessionStore {
17 pub fn new(pool: AnyPool) -> Self {
18 Self { pool }
19 }
20}
21
22#[async_trait]
23impl DatabasePool for RustBasicSessionStore {
24 async fn initiate(&self, _table_name: &str) -> Result<(), DatabaseError> {
25 Ok(())
26 }
27
28 async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
29 let query = format!("DELETE FROM {} WHERE id = $1", table_name);
30 sqlx::query(&query)
31 .bind(id)
32 .execute(&self.pool)
33 .await
34 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
35
36 IP_TRACKER.remove(id);
38
39 Ok(())
40 }
41
42 async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
43 let query = format!("SELECT payload FROM {} WHERE id = $1 AND last_activity > $2", table_name);
44 let now = chrono::Utc::now().timestamp();
45
46 let row: Option<(String,)> = sqlx::query_as(&query)
47 .bind(id)
48 .bind(now)
49 .fetch_optional(&self.pool)
50 .await
51 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
52
53 Ok(row.map(|r| r.0))
54 }
55
56 async fn store(&self, id: &str, session: &str, expires: i64, table_name: &str) -> Result<(), DatabaseError> {
57 let ip = IP_TRACKER.get(id).map(|i| i.clone()).unwrap_or_else(|| "unknown".to_string());
59
60 let delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
61 sqlx::query(&delete_query).bind(id).execute(&self.pool).await.ok();
62
63 let insert_query = format!(
64 "INSERT INTO {} (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)",
65 table_name
66 );
67
68 sqlx::query(&insert_query)
69 .bind(id)
70 .bind(session)
71 .bind(expires)
72 .bind(ip)
73 .execute(&self.pool)
74 .await
75 .map_err(|e| DatabaseError::GenericInsertError(e.to_string()))?;
76 Ok(())
77 }
78
79 async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
80 let now = chrono::Utc::now().timestamp();
81 let select_query = format!("SELECT id FROM {} WHERE last_activity < $1", table_name);
82 let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&select_query)
83 .bind(now)
84 .fetch_all(&self.pool)
85 .await
86 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
87 .into_iter()
88 .map(|r| r.0)
89 .collect();
90
91 for id in &ids {
93 IP_TRACKER.remove(id);
94 }
95
96 let delete_query = format!("DELETE FROM {} WHERE last_activity < $1", table_name);
97 sqlx::query(&delete_query)
98 .bind(now)
99 .execute(&self.pool)
100 .await
101 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
102
103 Ok(ids)
104 }
105
106 async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
107 let query = format!("SELECT COUNT(*) FROM {}", table_name);
108 let count: (i64,) = sqlx::query_as(&query)
109 .fetch_one(&self.pool)
110 .await
111 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
112 Ok(count.0)
113 }
114
115 async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
116 let query = format!("SELECT id FROM {} WHERE id = $1", table_name);
117 let row: Option<(String,)> = sqlx::query_as(&query)
118 .bind(id)
119 .fetch_optional(&self.pool)
120 .await
121 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
122 Ok(row.is_some())
123 }
124
125 async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
126 let query = format!("DELETE FROM {}", table_name);
127 sqlx::query(&query)
128 .execute(&self.pool)
129 .await
130 .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
131
132 IP_TRACKER.clear();
133
134 Ok(())
135 }
136
137 async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
138 let query = format!("SELECT id FROM {}", table_name);
139 let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&query)
140 .fetch_all(&self.pool)
141 .await
142 .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
143 .into_iter()
144 .map(|r| r.0)
145 .collect();
146 Ok(ids)
147 }
148
149 fn auto_handles_expiry(&self) -> bool {
150 false
151 }
152}