Skip to main content

rustbasic_core/
session_manager.rs

1use axum_session::{DatabasePool, DatabaseError};
2use async_trait::async_trait;
3use sqlx::AnyPool;
4use dashmap::DashMap;
5use once_cell::sync::Lazy;
6
7/// Map global untuk melacak IP per Session ID secara sementara.
8/// Digunakan karena trait DatabasePool tidak memberikan akses ke request/IP secara langsung.
9pub static IP_TRACKER: Lazy<DashMap<String, String>> = Lazy::new(DashMap::new);
10
11#[derive(Clone, Debug)]
12pub struct RustBasicSessionStore {
13    pub pool: AnyPool,
14}
15
16impl RustBasicSessionStore {
17    pub fn new(pool: AnyPool) -> Self {
18        Self { pool }
19    }
20}
21
22#[async_trait]
23impl DatabasePool for RustBasicSessionStore {
24    async fn initiate(&self, _table_name: &str) -> Result<(), DatabaseError> {
25        Ok(())
26    }
27
28    async fn delete_one_by_id(&self, id: &str, table_name: &str) -> Result<(), DatabaseError> {
29        let query = format!("DELETE FROM {} WHERE id = $1", table_name);
30        sqlx::query(&query)
31            .bind(id)
32            .execute(&self.pool)
33            .await
34            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
35        
36        // Bersihkan tracker
37        IP_TRACKER.remove(id);
38        
39        Ok(())
40    }
41
42    async fn load(&self, id: &str, table_name: &str) -> Result<Option<String>, DatabaseError> {
43        let query = format!("SELECT payload FROM {} WHERE id = $1 AND last_activity > $2", table_name);
44        let now = chrono::Utc::now().timestamp();
45        
46        let row: Option<(String,)> = sqlx::query_as(&query)
47            .bind(id)
48            .bind(now)
49            .fetch_optional(&self.pool)
50            .await
51            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
52
53        Ok(row.map(|r| r.0))
54    }
55
56    async fn store(&self, id: &str, session: &str, expires: i64, table_name: &str) -> Result<(), DatabaseError> {
57        // Ambil IP dari tracker (jika ada)
58        let ip = IP_TRACKER.get(id).map(|i| i.clone()).unwrap_or_else(|| "unknown".to_string());
59
60        let delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
61        sqlx::query(&delete_query).bind(id).execute(&self.pool).await.ok();
62
63        let insert_query = format!(
64            "INSERT INTO {} (id, payload, last_activity, ip_address) VALUES ($1, $2, $3, $4)",
65            table_name
66        );
67
68        sqlx::query(&insert_query)
69            .bind(id)
70            .bind(session)
71            .bind(expires)
72            .bind(ip)
73            .execute(&self.pool)
74            .await
75            .map_err(|e| DatabaseError::GenericInsertError(e.to_string()))?;
76        Ok(())
77    }
78
79    async fn delete_by_expiry(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
80        let now = chrono::Utc::now().timestamp();
81        let select_query = format!("SELECT id FROM {} WHERE last_activity < $1", table_name);
82        let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&select_query)
83            .bind(now)
84            .fetch_all(&self.pool)
85            .await
86            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
87            .into_iter()
88            .map(|r| r.0)
89            .collect();
90
91        // Bersihkan tracker untuk ID yang expired
92        for id in &ids {
93            IP_TRACKER.remove(id);
94        }
95
96        let delete_query = format!("DELETE FROM {} WHERE last_activity < $1", table_name);
97        sqlx::query(&delete_query)
98            .bind(now)
99            .execute(&self.pool)
100            .await
101            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
102
103        Ok(ids)
104    }
105
106    async fn count(&self, table_name: &str) -> Result<i64, DatabaseError> {
107        let query = format!("SELECT COUNT(*) FROM {}", table_name);
108        let count: (i64,) = sqlx::query_as(&query)
109            .fetch_one(&self.pool)
110            .await
111            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
112        Ok(count.0)
113    }
114
115    async fn exists(&self, id: &str, table_name: &str) -> Result<bool, DatabaseError> {
116        let query = format!("SELECT id FROM {} WHERE id = $1", table_name);
117        let row: Option<(String,)> = sqlx::query_as(&query)
118            .bind(id)
119            .fetch_optional(&self.pool)
120            .await
121            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?;
122        Ok(row.is_some())
123    }
124
125    async fn delete_all(&self, table_name: &str) -> Result<(), DatabaseError> {
126        let query = format!("DELETE FROM {}", table_name);
127        sqlx::query(&query)
128            .execute(&self.pool)
129            .await
130            .map_err(|e| DatabaseError::GenericDeleteError(e.to_string()))?;
131        
132        IP_TRACKER.clear();
133        
134        Ok(())
135    }
136
137    async fn get_ids(&self, table_name: &str) -> Result<Vec<String>, DatabaseError> {
138        let query = format!("SELECT id FROM {}", table_name);
139        let ids: Vec<String> = sqlx::query_as::<_, (String,)>(&query)
140            .fetch_all(&self.pool)
141            .await
142            .map_err(|e| DatabaseError::GenericSelectError(e.to_string()))?
143            .into_iter()
144            .map(|r| r.0)
145            .collect();
146        Ok(ids)
147    }
148
149    fn auto_handles_expiry(&self) -> bool {
150        false
151    }
152}