solid_pod_rs/storage/
memory.rs1use std::collections::HashMap;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use bytes::Bytes;
14use sha2::{Digest, Sha256};
15use tokio::sync::{broadcast, mpsc, RwLock};
16
17use crate::error::PodError;
18use crate::storage::{ResourceMeta, Storage, StorageEvent};
19
20#[derive(Clone)]
22pub struct MemoryBackend {
23 inner: Arc<Inner>,
24}
25
26struct Inner {
27 data: RwLock<HashMap<String, Entry>>,
28 events: broadcast::Sender<StorageEvent>,
29}
30
31#[derive(Clone)]
32struct Entry {
33 body: Bytes,
34 meta: ResourceMeta,
35}
36
37impl Default for MemoryBackend {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl MemoryBackend {
44 pub fn new() -> Self {
46 let (events, _) = broadcast::channel(256);
47 Self {
48 inner: Arc::new(Inner {
49 data: RwLock::new(HashMap::new()),
50 events,
51 }),
52 }
53 }
54
55 fn compute_etag(body: &[u8]) -> String {
56 hex::encode(Sha256::digest(body))
57 }
58
59 fn normalize(path: &str) -> String {
60 if path.is_empty() {
61 "/".into()
62 } else if path.starts_with('/') {
63 path.to_string()
64 } else {
65 format!("/{path}")
66 }
67 }
68
69 fn is_under(child: &str, container: &str) -> bool {
70 if container == "/" {
71 return child != "/";
72 }
73 let c = container.trim_end_matches('/');
74 child == c || child.starts_with(&format!("{c}/"))
75 }
76}
77
78#[async_trait]
79impl Storage for MemoryBackend {
80 async fn get(&self, path: &str) -> Result<(Bytes, ResourceMeta), PodError> {
81 let path = Self::normalize(path);
82 let guard = self.inner.data.read().await;
83 guard
84 .get(&path)
85 .map(|e| (e.body.clone(), e.meta.clone()))
86 .ok_or(PodError::NotFound(path))
87 }
88
89 async fn put(
90 &self,
91 path: &str,
92 body: Bytes,
93 content_type: &str,
94 ) -> Result<ResourceMeta, PodError> {
95 let path = Self::normalize(path);
96 let etag = Self::compute_etag(&body);
97 let meta = ResourceMeta {
98 etag,
99 modified: chrono::Utc::now(),
100 size: body.len() as u64,
101 content_type: content_type.to_string(),
102 links: Vec::new(),
103 };
104 let mut guard = self.inner.data.write().await;
105 let existed = guard.contains_key(&path);
106 guard.insert(
107 path.clone(),
108 Entry {
109 body,
110 meta: meta.clone(),
111 },
112 );
113 drop(guard);
114 let event = if existed {
115 StorageEvent::Updated(path)
116 } else {
117 StorageEvent::Created(path)
118 };
119 let _ = self.inner.events.send(event);
120 Ok(meta)
121 }
122
123 async fn delete(&self, path: &str) -> Result<(), PodError> {
124 let path = Self::normalize(path);
125 let mut guard = self.inner.data.write().await;
126 match guard.remove(&path) {
127 Some(_) => {
128 drop(guard);
129 let _ = self.inner.events.send(StorageEvent::Deleted(path));
130 Ok(())
131 }
132 None => Err(PodError::NotFound(path)),
133 }
134 }
135
136 async fn list(&self, container: &str) -> Result<Vec<String>, PodError> {
137 let container = Self::normalize(container);
138 let container = if container.ends_with('/') {
139 container
140 } else {
141 format!("{container}/")
142 };
143 let guard = self.inner.data.read().await;
144 let mut seen = std::collections::BTreeSet::new();
145 for key in guard.keys() {
146 if !key.starts_with(&container) {
147 continue;
148 }
149 let remainder = &key[container.len()..];
150 if remainder.is_empty() {
151 continue;
152 }
153 let name = match remainder.find('/') {
154 Some(pos) => &remainder[..=pos],
155 None => remainder,
156 };
157 seen.insert(name.to_string());
158 }
159 Ok(seen.into_iter().collect())
160 }
161
162 async fn head(&self, path: &str) -> Result<ResourceMeta, PodError> {
163 let path = Self::normalize(path);
164 let guard = self.inner.data.read().await;
165 guard
166 .get(&path)
167 .map(|e| e.meta.clone())
168 .ok_or(PodError::NotFound(path))
169 }
170
171 async fn exists(&self, path: &str) -> Result<bool, PodError> {
172 let path = Self::normalize(path);
173 let guard = self.inner.data.read().await;
174 Ok(guard.contains_key(&path))
175 }
176
177 async fn create_container(&self, path: &str) -> Result<ResourceMeta, PodError> {
178 let container = Self::normalize(path);
179 let container = if container.ends_with('/') {
180 container
181 } else {
182 format!("{container}/")
183 };
184 let meta = ResourceMeta::new("container", 0, "application/ld+json");
185 let mut guard = self.inner.data.write().await;
186 guard.insert(
187 container.clone(),
188 Entry {
189 body: Bytes::new(),
190 meta: meta.clone(),
191 },
192 );
193 drop(guard);
194 let _ = self.inner.events.send(StorageEvent::Created(container));
195 Ok(meta)
196 }
197
198 async fn watch(&self, path: &str) -> Result<mpsc::Receiver<StorageEvent>, PodError> {
199 let filter_path = Self::normalize(path);
200 let mut rx = self.inner.events.subscribe();
201 let (tx, out_rx) = mpsc::channel(64);
202 tokio::spawn(async move {
203 while let Ok(event) = rx.recv().await {
204 let target = match &event {
205 StorageEvent::Created(p)
206 | StorageEvent::Updated(p)
207 | StorageEvent::Deleted(p) => p.clone(),
208 };
209 if MemoryBackend::is_under(&target, &filter_path)
210 && tx.send(event).await.is_err()
211 {
212 return;
213 }
214 }
215 });
216 Ok(out_rx)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored resource is returned with the same bytes, size and content type.
    #[tokio::test]
    async fn put_get_roundtrip() {
        let backend = MemoryBackend::new();
        let payload = Bytes::from_static(b"hello");
        backend.put("/foo", payload, "text/plain").await.unwrap();

        let (body, meta) = backend.get("/foo").await.unwrap();
        assert_eq!(&body[..], b"hello");
        assert_eq!(meta.size, 5);
        assert_eq!(meta.content_type, "text/plain");
    }

    /// Listing a container reports nested paths once, as a segment ending in `/`.
    #[tokio::test]
    async fn list_direct_children_only() {
        let backend = MemoryBackend::new();
        for path in ["/a/b", "/a/c/d"] {
            backend
                .put(path, Bytes::from_static(b""), "text/plain")
                .await
                .unwrap();
        }

        let mut children = backend.list("/a").await.unwrap();
        children.sort();
        let expected = vec!["b".to_string(), "c/".to_string()];
        assert_eq!(children, expected);
    }

    /// A watcher on the root sees the `Created` event for a newly stored path.
    #[tokio::test]
    async fn watch_receives_created_event() {
        let backend = MemoryBackend::new();
        let mut events = backend.watch("/").await.unwrap();
        backend
            .put("/x", Bytes::from_static(b"hi"), "text/plain")
            .await
            .unwrap();

        // Bound the wait so a missing event fails the test instead of hanging.
        let deadline = std::time::Duration::from_secs(1);
        let received = tokio::time::timeout(deadline, events.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, StorageEvent::Created("/x".into()));
    }
}
263}