sa_token_storage_database/
lib.rs1use std::time::Duration;
12
13use async_trait::async_trait;
14use chrono::{DateTime, Utc};
15use sa_token_adapter::storage::{SaStorage, StorageError, StorageResult};
16use sqlx::{Pool, Postgres};
17
18#[derive(Clone)]
20pub struct DatabaseStorage {
21 pool: Pool<Postgres>,
22}
23
24impl DatabaseStorage {
25 pub async fn new(database_url: &str) -> StorageResult<Self> {
27 let pool = Pool::<Postgres>::connect(database_url)
28 .await
29 .map_err(|e| StorageError::ConnectionError(e.to_string()))?;
30
31 let storage = Self { pool };
32 storage.migrate().await?;
33 Ok(storage)
34 }
35
36 pub fn from_pool(pool: Pool<Postgres>) -> Self {
38 Self { pool }
39 }
40
41 pub async fn migrate(&self) -> StorageResult<()> {
43 let ddl = include_str!("../migrations/001_sa_token_storage.sql");
44 for statement in ddl.split(';').map(str::trim).filter(|s| !s.is_empty()) {
45 sqlx::query(statement)
46 .execute(&self.pool)
47 .await
48 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
49 }
50 Ok(())
51 }
52
53 async fn delete_expired(&self, key: &str) -> StorageResult<()> {
54 sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
55 .bind(key)
56 .execute(&self.pool)
57 .await
58 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
59 Ok(())
60 }
61
62 fn is_expired(expire_at: Option<DateTime<Utc>>) -> bool {
63 expire_at.is_some_and(|t| Utc::now() > t)
64 }
65}
66
/// Translates a glob-style key pattern into a PostgreSQL `LIKE` pattern.
///
/// `*` becomes `%`. The `LIKE` metacharacters `%` and `_`, and the escape character `\`
/// itself, are prefixed with `\` so they match literally; the query using the result must
/// therefore declare `ESCAPE '\'`. Every other character (including `?`) is copied as-is.
pub fn pattern_to_like(pattern: &str) -> String {
    let mut like = String::with_capacity(pattern.len());
    pattern.chars().for_each(|c| match c {
        '*' => like.push('%'),
        '%' | '_' | '\\' => like.extend(['\\', c]),
        _ => like.push(c),
    });
    like
}
82
83#[async_trait]
84impl SaStorage for DatabaseStorage {
85 async fn get(&self, key: &str) -> StorageResult<Option<String>> {
86 let row: Option<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
87 "SELECT value, expire_at FROM sa_token_storage WHERE key = $1",
88 )
89 .bind(key)
90 .fetch_optional(&self.pool)
91 .await
92 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
93
94 match row {
95 Some((_value, expire_at)) if Self::is_expired(expire_at) => {
96 self.delete_expired(key).await?;
97 Ok(None)
98 }
99 Some((value, _)) => Ok(Some(value)),
100 None => Ok(None),
101 }
102 }
103
104 async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
105 let expire_at: Option<DateTime<Utc>> = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
106
107 sqlx::query(
108 r#"
109 INSERT INTO sa_token_storage (key, value, expire_at, updated_at)
110 VALUES ($1, $2, $3, NOW())
111 ON CONFLICT (key) DO UPDATE
112 SET value = EXCLUDED.value,
113 expire_at = EXCLUDED.expire_at,
114 updated_at = NOW()
115 "#,
116 )
117 .bind(key)
118 .bind(value)
119 .bind(expire_at)
120 .execute(&self.pool)
121 .await
122 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
123
124 Ok(())
125 }
126
127 async fn delete(&self, key: &str) -> StorageResult<()> {
128 sqlx::query("DELETE FROM sa_token_storage WHERE key = $1")
129 .bind(key)
130 .execute(&self.pool)
131 .await
132 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
133 Ok(())
134 }
135
136 async fn exists(&self, key: &str) -> StorageResult<bool> {
137 Ok(self.get(key).await?.is_some())
138 }
139
140 async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
141 let expire_at = Utc::now() + chrono::Duration::from_std(ttl).unwrap();
142 let updated = sqlx::query(
143 "UPDATE sa_token_storage SET expire_at = $1, updated_at = NOW() WHERE key = $2",
144 )
145 .bind(expire_at)
146 .bind(key)
147 .execute(&self.pool)
148 .await
149 .map_err(|e| StorageError::OperationFailed(e.to_string()))?
150 .rows_affected();
151
152 if updated == 0 {
153 return Err(StorageError::KeyNotFound(key.to_string()));
154 }
155 Ok(())
156 }
157
158 async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
159 let row: Option<Option<DateTime<Utc>>> = sqlx::query_scalar(
160 "SELECT expire_at FROM sa_token_storage WHERE key = $1",
161 )
162 .bind(key)
163 .fetch_optional(&self.pool)
164 .await
165 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
166
167 match row {
168 None => Ok(None),
169 Some(None) => Ok(None),
170 Some(Some(expire_at)) if Self::is_expired(Some(expire_at)) => {
171 self.delete_expired(key).await?;
172 Ok(None)
173 }
174 Some(Some(expire_at)) => {
175 let remaining = expire_at.signed_duration_since(Utc::now());
176 if remaining.num_milliseconds() <= 0 {
177 Ok(Some(Duration::ZERO))
178 } else {
179 Ok(Some(
180 remaining
181 .to_std()
182 .unwrap_or(Duration::ZERO),
183 ))
184 }
185 }
186 }
187 }
188
189 async fn clear(&self) -> StorageResult<()> {
190 sqlx::query("TRUNCATE TABLE sa_token_storage")
191 .execute(&self.pool)
192 .await
193 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
194 Ok(())
195 }
196
197 async fn keys(&self, pattern: &str) -> StorageResult<Vec<String>> {
198 let like_pattern = pattern_to_like(pattern);
199 let rows: Vec<(String, Option<DateTime<Utc>>)> = sqlx::query_as(
200 "SELECT key, expire_at FROM sa_token_storage WHERE key LIKE $1 ESCAPE '\\'",
201 )
202 .bind(like_pattern)
203 .fetch_all(&self.pool)
204 .await
205 .map_err(|e| StorageError::OperationFailed(e.to_string()))?;
206
207 let mut keys = Vec::new();
208 for (key, expire_at) in rows {
209 if Self::is_expired(expire_at) {
210 self.delete_expired(&key).await?;
211 } else {
212 keys.push(key);
213 }
214 }
215 Ok(keys)
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::pattern_to_like;

    /// `*` maps to `%`, while literal `%` and `_` are escaped with a backslash.
    #[test]
    fn like_pattern_escapes_wildcards() {
        let cases = [("sa:token:*", "sa:token:%"), ("a%b_c", "a\\%b\\_c")];
        for (input, expected) in cases {
            assert_eq!(pattern_to_like(input), expected);
        }
    }
}
229
#[cfg(all(test, feature = "postgres"))]
mod postgres_tests {
    use super::*;

    /// Connection string for the integration database, if one is configured.
    fn database_url() -> Option<String> {
        std::env::var("DATABASE_URL").ok()
    }

    /// Round-trips a key through set/get/exists/ttl/keys/delete on a live PostgreSQL server.
    #[tokio::test]
    #[ignore = "requires PostgreSQL (set DATABASE_URL)"]
    async fn database_storage_roundtrip() {
        // Even when run with `--ignored`, skip quietly if no database is configured.
        let url = match database_url() {
            Some(url) => url,
            None => return,
        };
        let storage = DatabaseStorage::new(&url).await.expect("connect");
        let key = "sa:test:1";

        storage
            .set(key, "v1", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        assert_eq!(storage.get(key).await.unwrap(), Some("v1".into()));
        assert!(storage.exists(key).await.unwrap());
        assert!(storage.ttl(key).await.unwrap().is_some());

        let listed = storage.keys("sa:test:*").await.unwrap();
        assert!(listed.iter().any(|k| k == key));

        storage.delete(key).await.unwrap();
        assert!(!storage.exists(key).await.unwrap());
    }
}