Skip to main content

rustbasic_core/
session_manager.rs

1use axum_session::{DatabasePool, DatabaseError};
2use async_trait::async_trait;
3use sqlx::AnyPool;
4use dashmap::DashMap;
5use once_cell::sync::Lazy;
6
7/// Map global untuk melacak IP per Session ID secara sementara.
8/// Digunakan karena trait DatabasePool tidak memberikan akses ke request/IP secara langsung.
9pub static IP_TRACKER: Lazy<DashMap<String, String>> = Lazy::new(DashMap::new);
10
11#[derive(Clone, Debug)]
12pub struct RustBasicSessionStore {
13    pub pool: AnyPool,
14}
15
16impl RustBasicSessionStore {
17    pub fn new(pool: AnyPool) -> Self {
18        Self { pool }
19    }
20
21    async fn get_placeholder_query(&self, sql: &str) -> String {
22        let is_mysql = if let Ok(conn) = self.pool.acquire().await {
23            conn.backend_name() == "MySQL"
24        } else {
25            false
26        };
27
28        if is_mysql {
29            let re = regex::Regex::new(r"\$\d+").unwrap();
30            re.replace_all(sql, "?").into_owned()
31        } else {
32            sql.to_string()
33        }
34    }
35}
36
37#[async_trait]
38impl DatabasePool for RustBasicSessionStore {
39    async fn initiate(&self, _table_name: &str) -> Result<(), DatabaseError> {
40        Ok(())
41    }
42
43    async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
44        let raw_query = format!("DELETE FROM {} WHERE id = $1", table_name);
45        let query = self.get_placeholder_query(&raw_query).await;
46        sqlx::query(&query)
47            .bind(id)
48            .execute(&self.pool)
49            .await
50            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
51        
52        // Bersihkan tracker
53        IP_TRACKER.remove(id);
54        
55        Ok(())
56    }
57
58    async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
59        let raw_query = format!("SELECT payload FROM {} WHERE id = $1 AND last_activity > $2", table_name);
60        let query = self.get_placeholder_query(&raw_query).await;
61        let now = chrono::Utc::now().timestamp();
62        
63        let row: Option<(String,)> = sqlx::query_as(&query)
64            .bind(id)
65            .bind(now)
66            .fetch_optional(&self.pool)
67            .await
68            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
69
70        Ok(row.map(|r| r.0))
71    }
72
73    async fn store(&self, id: &str, session: &str, expires: i64, table_name: &str) -> Result<(), DatabaseError> {
74        // Ambil IP dari tracker (jika ada)
75        let ip = IP_TRACKER.get(id).map(|i| i.clone()).unwrap_or_else(|| "unknown".to_string());
76
77        let raw_delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
78        let delete_query = self.get_placeholder_query(&raw_delete_query).await;
79        sqlx::query(&delete_query).bind(id).execute(&self.pool).await.ok();
80
81        let raw_insert_query = format!(
82            "INSERT INTO {} (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)",
83            table_name
84        );
85        let insert_query = self.get_placeholder_query(&raw_insert_query).await;
86
87        sqlx::query(&insert_query)
88            .bind(id)
89            .bind(session)
90            .bind(expires)
91            .bind(ip)
92            .execute(&self.pool)
93            .await
94            .map_err(|e| DatabaseError::GenericInsertError(e.to_string()))?;
95        Ok(())
96    }
97
98    async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
99        let now = chrono::Utc::now().timestamp();
100        let raw_select_query = format!("SELECT id FROM {} WHERE last_activity < $1", table_name);
101        let select_query = self.get_placeholder_query(&raw_select_query).await;
102        let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&select_query)
103            .bind(now)
104            .fetch_all(&self.pool)
105            .await
106            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
107            .into_iter()
108            .map(|r| r.0)
109            .collect();
110
111        // Bersihkan tracker untuk ID yang expired
112        for id in &ids {
113            IP_TRACKER.remove(id);
114        }
115
116        let raw_delete_query = format!("DELETE FROM {} WHERE last_activity < $1", table_name);
117        let delete_query = self.get_placeholder_query(&raw_delete_query).await;
118        sqlx::query(&delete_query)
119            .bind(now)
120            .execute(&self.pool)
121            .await
122            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
123
124        Ok(ids)
125    }
126
127    async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
128        let query = format!("SELECT COUNT(*) FROM {}", table_name);
129        let count: (i64,) = sqlx::query_as(&query)
130            .fetch_one(&self.pool)
131            .await
132            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
133        Ok(count.0)
134    }
135
136    async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
137        let raw_query = format!("SELECT id FROM {} WHERE id = $1", table_name);
138        let query = self.get_placeholder_query(&raw_query).await;
139        let row: Option<(String,)> = sqlx::query_as(&query)
140            .bind(id)
141            .fetch_optional(&self.pool)
142            .await
143            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
144        Ok(row.is_some())
145    }
146
147    async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
148        let query = format!("DELETE FROM {}", table_name);
149        sqlx::query(&query)
150            .execute(&self.pool)
151            .await
152            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
153        
154        IP_TRACKER.clear();
155        
156        Ok(())
157    }
158
159    async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
160        let query = format!("SELECT id FROM {}", table_name);
161        let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&query)
162            .fetch_all(&self.pool)
163            .await
164            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
165            .into_iter()
166            .map(|r| r.0)
167            .collect();
168        Ok(ids)
169    }
170
171    fn auto_handles_expiry(&self) -> bool {
172        false
173    }
174}