1use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::fmt::Debug;
6
7#[derive(Debug, thiserror::Error)]
9pub enum StorageError {
10 #[error("Not found: {0}")]
11 NotFound(String),
12
13 #[error("Already exists: {0}")]
14 AlreadyExists(String),
15
16 #[error("Serialization error: {0}")]
17 Serialization(String),
18
19 #[error("Connection error: {0}")]
20 Connection(String),
21
22 #[error("Query error: {0}")]
23 Query(String),
24
25 #[error("Internal error: {0}")]
26 Internal(String),
27}
28
29#[async_trait]
31pub trait StorageBackend: Send + Sync + Debug {
32 fn name(&self) -> &str;
34
35 fn as_any(&self) -> &dyn std::any::Any;
37
38 async fn is_healthy(&self) -> bool;
40
41 async fn set_value(&self, key: &str, value: serde_json::Value) -> Result<(), StorageError>;
43
44 async fn get_value(&self, key: &str) -> Result<Option<serde_json::Value>, StorageError>;
46
47 async fn delete(&self, key: &str) -> Result<bool, StorageError>;
49
50 async fn exists(&self, key: &str) -> Result<bool, StorageError>;
52
53 async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, StorageError>;
55}
56
57#[async_trait]
59pub trait StorageExt {
60 async fn set<T: Serialize + Send + Sync>(
61 &self,
62 key: &str,
63 value: &T,
64 ) -> Result<(), StorageError>;
65 async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StorageError>;
66}
67
68#[async_trait]
69impl<S: StorageBackend + ?Sized> StorageExt for S {
70 async fn set<T: Serialize + Send + Sync>(
71 &self,
72 key: &str,
73 value: &T,
74 ) -> Result<(), StorageError> {
75 let json =
76 serde_json::to_value(value).map_err(|e| StorageError::Serialization(e.to_string()))?;
77 self.set_value(key, json).await
78 }
79
80 async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, StorageError> {
81 match self.get_value(key).await? {
82 Some(json) => {
83 let value = serde_json::from_value(json)
84 .map_err(|e| StorageError::Serialization(e.to_string()))?;
85 Ok(Some(value))
86 }
87 None => Ok(None),
88 }
89 }
90}
91
92#[derive(Debug, Default)]
94pub struct MemoryBackend {
95 data: tokio::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
96}
97
98impl MemoryBackend {
99 pub fn new() -> Self {
100 Self::default()
101 }
102}
103
104#[async_trait]
105impl StorageBackend for MemoryBackend {
106 fn name(&self) -> &str {
107 "memory"
108 }
109
110 fn as_any(&self) -> &dyn std::any::Any {
111 self
112 }
113
114 async fn is_healthy(&self) -> bool {
115 true
116 }
117
118 async fn set_value(&self, key: &str, value: serde_json::Value) -> Result<(), StorageError> {
119 self.data.write().await.insert(key.to_string(), value);
120 Ok(())
121 }
122
123 async fn get_value(&self, key: &str) -> Result<Option<serde_json::Value>, StorageError> {
124 let data = self.data.read().await;
125 Ok(data.get(key).cloned())
126 }
127
128 async fn delete(&self, key: &str) -> Result<bool, StorageError> {
129 Ok(self.data.write().await.remove(key).is_some())
130 }
131
132 async fn exists(&self, key: &str) -> Result<bool, StorageError> {
133 Ok(self.data.read().await.contains_key(key))
134 }
135
136 async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
137 let data = self.data.read().await;
138 let keys: Vec<String> = data
139 .keys()
140 .filter(|k| k.starts_with(prefix))
141 .cloned()
142 .collect();
143 Ok(keys)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestData {
        name: String,
        value: i32,
    }

    /// Happy path: round-trip a typed value, then check exists/list/delete.
    #[tokio::test]
    async fn test_memory_backend() {
        let backend = MemoryBackend::new();

        let data = TestData {
            name: "test".to_string(),
            value: 42,
        };

        backend.set("test:1", &data).await.unwrap();

        let retrieved: Option<TestData> = backend.get("test:1").await.unwrap();
        assert_eq!(retrieved, Some(data));

        assert!(backend.exists("test:1").await.unwrap());
        assert!(!backend.exists("test:2").await.unwrap());

        let keys = backend.list_keys("test:").await.unwrap();
        assert_eq!(keys, vec!["test:1"]);

        assert!(backend.delete("test:1").await.unwrap());
        assert!(!backend.exists("test:1").await.unwrap());
    }

    /// Edge cases: missing keys, overwrites, prefix filtering, and decode failures.
    #[tokio::test]
    async fn test_memory_backend_edge_cases() {
        let backend = MemoryBackend::new();

        // Absent keys: typed get yields None and delete reports nothing removed.
        let missing: Option<TestData> = backend.get("nope").await.unwrap();
        assert_eq!(missing, None);
        assert!(!backend.delete("nope").await.unwrap());

        // Writing the same key twice keeps only the latest value.
        let first = TestData {
            name: "a".to_string(),
            value: 1,
        };
        let second = TestData {
            name: "b".to_string(),
            value: 2,
        };
        backend.set("user:1", &first).await.unwrap();
        backend.set("user:1", &second).await.unwrap();
        let latest: Option<TestData> = backend.get("user:1").await.unwrap();
        assert_eq!(latest, Some(second));

        // Prefix filtering excludes non-matching keys. Sort before comparing so
        // the assertion does not depend on listing order.
        backend.set("user:2", &1).await.unwrap();
        backend.set("order:1", &2).await.unwrap();
        let mut user_keys = backend.list_keys("user:").await.unwrap();
        user_keys.sort();
        assert_eq!(user_keys, vec!["user:1", "user:2"]);

        // An empty prefix matches every stored key.
        assert_eq!(backend.list_keys("").await.unwrap().len(), 3);

        // Decoding a stored value into an incompatible type is a Serialization error.
        let wrong: Result<Option<TestData>, StorageError> = backend.get("order:1").await;
        assert!(matches!(wrong, Err(StorageError::Serialization(_))));
    }
}