rustvello_sqlite/
client_data_store.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use rustvello_core::client_data_store::ClientDataStore;
8use rustvello_core::error::{RustvelloError, RustvelloResult};
9
10use crate::db::{blocking, lock_err, sql_err, Database};
11
12pub struct SqliteClientDataStore {
17 db: Arc<Database>,
18}
19
20impl SqliteClientDataStore {
21 pub fn new(db: Arc<Database>) -> Self {
22 Self { db }
23 }
24}
25
26#[async_trait]
27impl ClientDataStore for SqliteClientDataStore {
28 async fn store(&self, key: &str, value: &str) -> RustvelloResult<()> {
29 let db = Arc::clone(&self.db);
30 let key = key.to_owned();
31 let value = value.to_owned();
32 blocking(move || {
33 let conn = db.conn.lock().map_err(lock_err)?;
34 conn.execute(
35 "INSERT OR REPLACE INTO client_data (data_key, data_value) VALUES (?1, ?2)",
36 [&key, &value],
37 )
38 .map_err(sql_err)?;
39 Ok(())
40 })
41 .await
42 }
43
44 async fn retrieve(&self, key: &str) -> RustvelloResult<String> {
45 let db = Arc::clone(&self.db);
46 let key = key.to_owned();
47 blocking(move || {
48 let conn = db.conn.lock().map_err(lock_err)?;
49 conn.query_row(
50 "SELECT data_value FROM client_data WHERE data_key = ?1",
51 [&key],
52 |row| row.get(0),
53 )
54 .map_err(|e| match e {
55 rusqlite::Error::QueryReturnedNoRows => {
56 RustvelloError::state_backend(format!("key not found: {key}"))
57 }
58 other => sql_err(other),
59 })
60 })
61 .await
62 }
63
64 async fn purge(&self) -> RustvelloResult<()> {
65 let db = Arc::clone(&self.db);
66 blocking(move || {
67 let conn = db.conn.lock().map_err(lock_err)?;
68 conn.execute("DELETE FROM client_data", [])
69 .map_err(sql_err)?;
70 Ok(())
71 })
72 .await
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store over a fresh in-memory database.
    fn make_store() -> SqliteClientDataStore {
        let db = Arc::new(
            Database::in_memory().expect("opening an in-memory SQLite database should not fail"),
        );
        SqliteClientDataStore::new(db)
    }

    #[tokio::test]
    async fn store_and_retrieve() {
        let store = make_store();
        store.store("k1", "v1").await.unwrap();
        assert_eq!(store.retrieve("k1").await.unwrap(), "v1");
    }

    #[tokio::test]
    async fn retrieve_missing_key_errors() {
        let store = make_store();
        let err = store.retrieve("nonexistent").await;
        assert!(err.is_err());
    }

    #[tokio::test]
    async fn purge_removes_all() {
        let store = make_store();
        store.store("k1", "v1").await.unwrap();
        store.store("k2", "v2").await.unwrap();
        store.purge().await.unwrap();
        assert!(store.retrieve("k1").await.is_err());
        assert!(store.retrieve("k2").await.is_err());
    }

    #[tokio::test]
    async fn upsert_semantics() {
        let store = make_store();
        store.store("k1", "original").await.unwrap();
        store.store("k1", "updated").await.unwrap();
        assert_eq!(store.retrieve("k1").await.unwrap(), "updated");
    }

    /// Overwriting one key must leave every other key untouched; this guards
    /// against the REPLACE conflict target being broader than `data_key`.
    #[tokio::test]
    async fn overwrite_does_not_affect_other_keys() {
        let store = make_store();
        store.store("k1", "v1").await.unwrap();
        store.store("k2", "v2").await.unwrap();
        store.store("k1", "v1-new").await.unwrap();
        assert_eq!(store.retrieve("k1").await.unwrap(), "v1-new");
        assert_eq!(store.retrieve("k2").await.unwrap(), "v2");
    }

    /// Purging an empty (or already purged) table is a no-op, not an error.
    #[tokio::test]
    async fn purge_on_empty_store_succeeds() {
        let store = make_store();
        store.purge().await.unwrap();
        store.purge().await.unwrap();
    }

    /// The table stays usable after a purge.
    #[tokio::test]
    async fn store_after_purge_works() {
        let store = make_store();
        store.store("k1", "before").await.unwrap();
        store.purge().await.unwrap();
        store.store("k1", "after").await.unwrap();
        assert_eq!(store.retrieve("k1").await.unwrap(), "after");
    }
}