Skip to main content

solid_pod_rs/storage/
memory.rs

//! In-memory storage backend.
//!
//! Designed for tests. All resources live in a single
//! `RwLock<HashMap<String, Entry>>` (an `Entry` pairs the body `Bytes`
//! with its `ResourceMeta`), shared between clones through an `Arc`.
//! Change events are broadcast to all registered watchers; a watcher only
//! receives events for paths that are equal to, or descend from, the
//! path it was registered with (a watcher on `/` sees every path except
//! `/` itself).
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use bytes::Bytes;
14use sha2::{Digest, Sha256};
15use tokio::sync::{broadcast, mpsc, RwLock};
16
17use crate::error::PodError;
18use crate::storage::{ResourceMeta, Storage, StorageEvent};
19
20/// In-memory `Storage` implementation.
21#[derive(Clone)]
22pub struct MemoryBackend {
23    inner: Arc<Inner>,
24}
25
26struct Inner {
27    data: RwLock<HashMap<String, Entry>>,
28    events: broadcast::Sender<StorageEvent>,
29}
30
31#[derive(Clone)]
32struct Entry {
33    body: Bytes,
34    meta: ResourceMeta,
35}
36
37impl Default for MemoryBackend {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl MemoryBackend {
44    /// Create a new empty backend.
45    pub fn new() -> Self {
46        let (events, _) = broadcast::channel(256);
47        Self {
48            inner: Arc::new(Inner {
49                data: RwLock::new(HashMap::new()),
50                events,
51            }),
52        }
53    }
54
55    fn compute_etag(body: &[u8]) -> String {
56        hex::encode(Sha256::digest(body))
57    }
58
59    fn normalize(path: &str) -> String {
60        if path.is_empty() {
61            "/".into()
62        } else if path.starts_with('/') {
63            path.to_string()
64        } else {
65            format!("/{path}")
66        }
67    }
68
69    fn is_under(child: &str, container: &str) -> bool {
70        if container == "/" {
71            return child != "/";
72        }
73        let c = container.trim_end_matches('/');
74        child == c || child.starts_with(&format!("{c}/"))
75    }
76}
77
78#[async_trait]
79impl Storage for MemoryBackend {
80    async fn get(&self, path: &str) -> Result<(Bytes, ResourceMeta), PodError> {
81        let path = Self::normalize(path);
82        let guard = self.inner.data.read().await;
83        guard
84            .get(&path)
85            .map(|e| (e.body.clone(), e.meta.clone()))
86            .ok_or(PodError::NotFound(path))
87    }
88
89    async fn put(
90        &self,
91        path: &str,
92        body: Bytes,
93        content_type: &str,
94    ) -> Result<ResourceMeta, PodError> {
95        let path = Self::normalize(path);
96        let etag = Self::compute_etag(&body);
97        let meta = ResourceMeta {
98            etag,
99            modified: chrono::Utc::now(),
100            size: body.len() as u64,
101            content_type: content_type.to_string(),
102            links: Vec::new(),
103        };
104        let mut guard = self.inner.data.write().await;
105        let existed = guard.contains_key(&path);
106        guard.insert(
107            path.clone(),
108            Entry {
109                body,
110                meta: meta.clone(),
111            },
112        );
113        drop(guard);
114        let event = if existed {
115            StorageEvent::Updated(path)
116        } else {
117            StorageEvent::Created(path)
118        };
119        let _ = self.inner.events.send(event);
120        Ok(meta)
121    }
122
123    async fn delete(&self, path: &str) -> Result<(), PodError> {
124        let path = Self::normalize(path);
125        let mut guard = self.inner.data.write().await;
126        match guard.remove(&path) {
127            Some(_) => {
128                drop(guard);
129                let _ = self.inner.events.send(StorageEvent::Deleted(path));
130                Ok(())
131            }
132            None => Err(PodError::NotFound(path)),
133        }
134    }
135
136    async fn list(&self, container: &str) -> Result<Vec<String>, PodError> {
137        let container = Self::normalize(container);
138        let container = if container.ends_with('/') {
139            container
140        } else {
141            format!("{container}/")
142        };
143        let guard = self.inner.data.read().await;
144        let mut seen = std::collections::BTreeSet::new();
145        for key in guard.keys() {
146            if !key.starts_with(&container) {
147                continue;
148            }
149            let remainder = &key[container.len()..];
150            if remainder.is_empty() {
151                continue;
152            }
153            let name = match remainder.find('/') {
154                Some(pos) => &remainder[..=pos],
155                None => remainder,
156            };
157            seen.insert(name.to_string());
158        }
159        Ok(seen.into_iter().collect())
160    }
161
162    async fn head(&self, path: &str) -> Result<ResourceMeta, PodError> {
163        let path = Self::normalize(path);
164        let guard = self.inner.data.read().await;
165        guard
166            .get(&path)
167            .map(|e| e.meta.clone())
168            .ok_or(PodError::NotFound(path))
169    }
170
171    async fn exists(&self, path: &str) -> Result<bool, PodError> {
172        let path = Self::normalize(path);
173        let guard = self.inner.data.read().await;
174        Ok(guard.contains_key(&path))
175    }
176
177    async fn watch(&self, path: &str) -> Result<mpsc::Receiver<StorageEvent>, PodError> {
178        let filter_path = Self::normalize(path);
179        let mut rx = self.inner.events.subscribe();
180        let (tx, out_rx) = mpsc::channel(64);
181        tokio::spawn(async move {
182            while let Ok(event) = rx.recv().await {
183                let target = match &event {
184                    StorageEvent::Created(p)
185                    | StorageEvent::Updated(p)
186                    | StorageEvent::Deleted(p) => p.clone(),
187                };
188                if MemoryBackend::is_under(&target, &filter_path)
189                    && tx.send(event).await.is_err()
190                {
191                    return;
192                }
193            }
194        });
195        Ok(out_rx)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Receive the next watcher event, failing the test after one second.
    async fn next_event(rx: &mut mpsc::Receiver<StorageEvent>) -> StorageEvent {
        tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timed out waiting for a storage event")
            .expect("watch channel closed unexpectedly")
    }

    #[tokio::test]
    async fn put_get_roundtrip() {
        let m = MemoryBackend::new();
        m.put("/foo", Bytes::from_static(b"hello"), "text/plain")
            .await
            .unwrap();
        let (body, meta) = m.get("/foo").await.unwrap();
        assert_eq!(&body[..], b"hello");
        assert_eq!(meta.size, 5);
        assert_eq!(meta.content_type, "text/plain");
    }

    #[tokio::test]
    async fn etag_is_sha256_hex_of_body() {
        let m = MemoryBackend::new();
        let meta = m
            .put("/foo", Bytes::from_static(b"hello"), "text/plain")
            .await
            .unwrap();
        assert_eq!(
            meta.etag,
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
        assert_eq!(m.head("/foo").await.unwrap().etag, meta.etag);
    }

    #[tokio::test]
    async fn relative_paths_are_normalized() {
        let m = MemoryBackend::new();
        m.put("foo", Bytes::from_static(b"x"), "text/plain")
            .await
            .unwrap();
        assert!(m.exists("/foo").await.unwrap());
        assert!(m.get("foo").await.is_ok());
    }

    #[tokio::test]
    async fn delete_removes_and_missing_is_not_found() {
        let m = MemoryBackend::new();
        m.put("/foo", Bytes::from_static(b"x"), "text/plain")
            .await
            .unwrap();
        m.delete("/foo").await.unwrap();
        assert!(!m.exists("/foo").await.unwrap());
        assert!(matches!(
            m.delete("/foo").await,
            Err(PodError::NotFound(p)) if p == "/foo"
        ));
        assert!(matches!(
            m.get("/foo").await,
            Err(PodError::NotFound(p)) if p == "/foo"
        ));
    }

    #[tokio::test]
    async fn list_direct_children_only() {
        let m = MemoryBackend::new();
        m.put("/a/b", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/a/c/d", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        let mut items = m.list("/a").await.unwrap();
        items.sort();
        assert_eq!(items, vec!["b".to_string(), "c/".to_string()]);
    }

    #[tokio::test]
    async fn list_root_and_unknown_container() {
        let m = MemoryBackend::new();
        m.put("/a/b", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/c", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        let mut items = m.list("/").await.unwrap();
        items.sort();
        assert_eq!(items, vec!["a/".to_string(), "c".to_string()]);
        assert!(m.list("/missing").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn watch_receives_created_event() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/").await.unwrap();
        m.put("/x", Bytes::from_static(b"hi"), "text/plain")
            .await
            .unwrap();
        let event = next_event(&mut rx).await;
        assert_eq!(event, StorageEvent::Created("/x".into()));
    }

    #[tokio::test]
    async fn watch_reports_update_and_delete() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/").await.unwrap();
        m.put("/x", Bytes::from_static(b"a"), "text/plain")
            .await
            .unwrap();
        m.put("/x", Bytes::from_static(b"b"), "text/plain")
            .await
            .unwrap();
        m.delete("/x").await.unwrap();
        assert_eq!(next_event(&mut rx).await, StorageEvent::Created("/x".into()));
        assert_eq!(next_event(&mut rx).await, StorageEvent::Updated("/x".into()));
        assert_eq!(next_event(&mut rx).await, StorageEvent::Deleted("/x".into()));
    }

    #[tokio::test]
    async fn watch_filters_by_path() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/a").await.unwrap();
        // Neither a sibling nor a path that merely shares the prefix text
        // should be delivered.
        m.put("/ab", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/b", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/a/x", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        assert_eq!(
            next_event(&mut rx).await,
            StorageEvent::Created("/a/x".into())
        );
    }
}