Skip to main content

sa_token_storage_database/
lib.rs

// Author: 金书记
//
//! # sa-token-storage-database
//!
//! 基于 sqlx 的关系型数据库存储实现(默认 PostgreSQL,可选 MySQL)。
//!
//! ## DDL
//!
//! 见 [`migrations/001_sa_token_storage.sql`](../../migrations/001_sa_token_storage.sql)。

11use std::time::Duration;
12
13use async_trait::async_trait;
14use chrono::{DateTime, Utc};
15use sa_token_adapter::storage::{SaStorage, StorageError, StorageResult};
16use sqlx::{Pool, Postgres};
17
18/// PostgreSQL 存储实现
19#[derive(Clone)]
20pub struct DatabaseStorage {
21    pool: Pool<Postgres>,
22}
23
24impl DatabaseStorage {
25    /// 连接数据库并确保表结构存在
26    pub async fn new(database_url: &str) -> StorageResult<Self> {
27        let pool = Pool::<Postgres>::connect(database_url)
28            .await
29            .map_err(|e| StorageError::ConnectionError(e.to_string()))?;
30
31        let storage = Self { pool };
32        storage.migrate().await?;
33        Ok(storage)
34    }
35
36    /// 使用已有连接池
37    pub fn from_pool(pool: Pool<Postgres>) -> Self {
38        Self { pool }
39    }
40
41    /// 执行内嵌 DDL(幂等)
42    pub async fn migrate(&self) -> StorageResult<()> {
43        let ddl = include_str!("../migrations/001_sa_token_storage.sql");
44        for statement in ddl.split(';').map(str::trim).filter(|s| !s.is_empty()) {
45            sqlx::query(statement)
46                .execute(&self.pool)
47                .await
48                .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
49        }
50        Ok(())
51    }
52
53    async fn delete_expired(&self, key: &str) -> StorageResult<()> {
54        sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
55            .bind(key)
56            .execute(&self.pool)
57            .await
58            .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
59        Ok(())
60    }
61
62    fn is_expired(expire_at: Option<DateTime<Utc>>) -> bool {
63        expire_at.is_some_and(|t| Utc::now() > t)
64    }
65}
66
/// Translates a glob-style key pattern into a SQL `LIKE` pattern.
///
/// Each `*` becomes `%` (any run of characters). Literal `%`, `_` and `\` are
/// prefixed with `\` so they match themselves when the query uses
/// `ESCAPE '\'`. Every other character, including `?`, is copied verbatim.
pub fn pattern_to_like(pattern: &str) -> String {
    pattern
        .chars()
        .fold(String::with_capacity(pattern.len()), |mut like, ch| {
            if ch == '*' {
                like.push('%');
            } else {
                // LIKE metacharacters (and the escape char itself) must be escaped.
                if matches!(ch, '%' | '_' | '\\') {
                    like.push('\\');
                }
                like.push(ch);
            }
            like
        })
}
82
83#[async_trait]
84impl SaStorage for DatabaseStorage {
85    async fn get(&self, key: &str) -> StorageResult<Option<String>> {
86        let row: Option<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
87            "SELECT value, expire_at FROM sa_token_storage WHERE key = $1",
88        )
89        .bind(key)
90        .fetch_optional(&self.pool)
91        .await
92        .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
93
94        match row {
95            Some((_value, expire_at)) if Self::is_expired(expire_at) => {
96                self.delete_expired(key).await?;
97                Ok(None)
98            }
99            Some((value, _)) => Ok(Some(value)),
100            None => Ok(None),
101        }
102    }
103
104    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
105        let expire_at: Option<DateTime<Utc>> = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
106
107        sqlx::query(
108            r#"
109            INSERT INTO sa_token_storage (key, value, expire_at, updated_at)
110            VALUES ($1, $2, $3, NOW())
111            ON CONFLICT (key) DO UPDATE
112            SET value = EXCLUDED.value,
113                expire_at = EXCLUDED.expire_at,
114                updated_at = NOW()
115            "#,
116        )
117        .bind(key)
118        .bind(value)
119        .bind(expire_at)
120        .execute(&self.pool)
121        .await
122        .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
123
124        Ok(())
125    }
126
127    async fn delete(&self, key: &str) -> StorageResult<()> {
128        sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
129            .bind(key)
130            .execute(&self.pool)
131            .await
132            .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
133        Ok(())
134    }
135
136    async fn exists(&self, key: &str) -> StorageResult<bool> {
137        Ok(self.get(key).await?.is_some())
138    }
139
140    async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
141        let expire_at = Utc::now() + chrono::Duration::from_std(ttl).unwrap();
142        let updated = sqlx::query(
143            "UPDATE sa_token_storage SET expire_at = $1, updated_at = NOW() WHERE key = $2",
144        )
145        .bind(expire_at)
146        .bind(key)
147        .execute(&self.pool)
148        .await
149        .map_err(|e| StorageError::OperationFailed(e.to_string()))?
150        .rows_affected();
151
152        if updated == 0 {
153            return Err(StorageError::KeyNotFound(key.to_string()));
154        }
155        Ok(())
156    }
157
158    async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
159        let row: Option<Option<DateTime<Utc>>> = sqlx::query_scalar(
160            "SELECT expire_at FROM sa_token_storage WHERE key = $1",
161        )
162        .bind(key)
163        .fetch_optional(&self.pool)
164        .await
165        .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
166
167        match row {
168            None => Ok(None),
169            Some(None) => Ok(None),
170            Some(Some(expire_at)) if Self::is_expired(Some(expire_at)) => {
171                self.delete_expired(key).await?;
172                Ok(None)
173            }
174            Some(Some(expire_at)) => {
175                let remaining = expire_at.signed_duration_since(Utc::now());
176                if remaining.num_milliseconds() <= 0 {
177                    Ok(Some(Duration::ZERO))
178                } else {
179                    Ok(Some(
180                        remaining
181                            .to_std()
182                            .unwrap_or(Duration::ZERO),
183                    ))
184                }
185            }
186        }
187    }
188
189    async fn clear(&self) -> StorageResult<()> {
190        sqlx::query("TRUNCATE TABLE sa_token_storage")
191            .execute(&self.pool)
192            .await
193            .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
194        Ok(())
195    }
196
197    async fn keys(&self, pattern: &str) -> StorageResult<Vec<String>> {
198        let like_pattern = pattern_to_like(pattern);
199        let rows: Vec<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
200            "SELECT key, expire_at FROM sa_token_storage WHERE key LIKE $1 ESCAPE '\\'",
201        )
202        .bind(like_pattern)
203        .fetch_all(&self.pool)
204        .await
205        .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
206
207        let mut keys = Vec::new();
208        for (key, expire_at) in rows {
209            if Self::is_expired(expire_at) {
210                self.delete_expired(&key).await?;
211            } else {
212                keys.push(key);
213            }
214        }
215        Ok(keys)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn like_pattern_escapes_wildcards() {
        assert_eq!(pattern_to_like("sa:token:*"), "sa:token:%");
        assert_eq!(pattern_to_like("a%b_c"), "a\\%b\\_c");
    }

    #[test]
    fn like_pattern_escapes_backslash() {
        // The escape character itself must be doubled; otherwise `ESCAPE '\'`
        // would treat it as escaping the character that follows.
        assert_eq!(pattern_to_like("a\\b"), "a\\\\b");
        assert_eq!(pattern_to_like("a\\*"), "a\\\\%");
    }

    #[test]
    fn like_pattern_edge_cases() {
        assert_eq!(pattern_to_like(""), "");
        assert_eq!(pattern_to_like("*"), "%");
        assert_eq!(pattern_to_like("**"), "%%");
        assert_eq!(pattern_to_like("令牌:*"), "令牌:%");
    }
}
229
#[cfg(all(test, feature = "postgres"))]
mod postgres_tests {
    use super::*;

    /// Connection string for the live-database tests, if one is configured.
    fn database_url() -> Option<String> {
        match std::env::var("DATABASE_URL") {
            Ok(url) => Some(url),
            Err(_) => None,
        }
    }

    /// End-to-end set/get/exists/ttl/keys/delete check against a real PostgreSQL.
    ///
    /// Ignored by default. Run it with `--ignored` and `DATABASE_URL` set; it
    /// returns early without asserting anything when `DATABASE_URL` is unset.
    #[tokio::test]
    #[ignore = "requires PostgreSQL (set DATABASE_URL)"]
    async fn database_storage_roundtrip() {
        let url = match database_url() {
            Some(url) => url,
            None => return,
        };
        let key = "sa:test:1";
        let storage = DatabaseStorage::new(&url).await.expect("connect");

        storage
            .set(key, "v1", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        assert_eq!(storage.get(key).await.unwrap(), Some(String::from("v1")));
        assert!(storage.exists(key).await.unwrap());

        let remaining = storage.ttl(key).await.unwrap();
        assert!(remaining.is_some());

        let matched = storage.keys("sa:test:*").await.unwrap();
        assert!(matched.iter().any(|k| k == key));

        storage.delete(key).await.unwrap();
        let still_there = storage.exists(key).await.unwrap();
        assert!(!still_there);
    }
}