Skip to main content

solid_pod_rs/storage/
memory.rs

//! In-memory storage backend.
//!
//! Designed for tests. All state lives behind a single shared `Arc<Inner>`,
//! whose `RwLock<HashMap<String, Entry>>` maps each normalised path to an
//! `Entry` pairing the body `Bytes` with its `ResourceMeta`. Change
//! events are broadcast to all registered watchers; a watcher only
//! receives events for paths that are equal to, or descend from, the
//! path it was registered with.
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use bytes::Bytes;
14use sha2::{Digest, Sha256};
15use tokio::sync::{broadcast, mpsc, RwLock};
16
17use crate::error::PodError;
18use crate::storage::{ResourceMeta, Storage, StorageEvent};
19
20/// In-memory `Storage` implementation.
21#[derive(Clone)]
22pub struct MemoryBackend {
23    inner: Arc<Inner>,
24}
25
26struct Inner {
27    data: RwLock<HashMap<String, Entry>>,
28    events: broadcast::Sender<StorageEvent>,
29}
30
31#[derive(Clone)]
32struct Entry {
33    body: Bytes,
34    meta: ResourceMeta,
35}
36
37impl Default for MemoryBackend {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl MemoryBackend {
44    /// Create a new empty backend.
45    pub fn new() -> Self {
46        let (events, _) = broadcast::channel(256);
47        Self {
48            inner: Arc::new(Inner {
49                data: RwLock::new(HashMap::new()),
50                events,
51            }),
52        }
53    }
54
55    fn compute_etag(body: &[u8]) -> String {
56        hex::encode(Sha256::digest(body))
57    }
58
59    fn normalize(path: &str) -> String {
60        if path.is_empty() {
61            "/".into()
62        } else if path.starts_with('/') {
63            path.to_string()
64        } else {
65            format!("/{path}")
66        }
67    }
68
69    fn is_under(child: &str, container: &str) -> bool {
70        if container == "/" {
71            return child != "/";
72        }
73        let c = container.trim_end_matches('/');
74        child == c || child.starts_with(&format!("{c}/"))
75    }
76}
77
78#[async_trait]
79impl Storage for MemoryBackend {
80    async fn get(&self, path: &str) -> Result<(Bytes, ResourceMeta), PodError> {
81        let path = Self::normalize(path);
82        let guard = self.inner.data.read().await;
83        guard
84            .get(&path)
85            .map(|e| (e.body.clone(), e.meta.clone()))
86            .ok_or(PodError::NotFound(path))
87    }
88
89    async fn put(
90        &self,
91        path: &str,
92        body: Bytes,
93        content_type: &str,
94    ) -> Result<ResourceMeta, PodError> {
95        let path = Self::normalize(path);
96        let etag = Self::compute_etag(&body);
97        let meta = ResourceMeta {
98            etag,
99            modified: chrono::Utc::now(),
100            size: body.len() as u64,
101            content_type: content_type.to_string(),
102            links: Vec::new(),
103        };
104        let mut guard = self.inner.data.write().await;
105        let existed = guard.contains_key(&path);
106        guard.insert(
107            path.clone(),
108            Entry {
109                body,
110                meta: meta.clone(),
111            },
112        );
113        drop(guard);
114        let event = if existed {
115            StorageEvent::Updated(path)
116        } else {
117            StorageEvent::Created(path)
118        };
119        let _ = self.inner.events.send(event);
120        Ok(meta)
121    }
122
123    async fn delete(&self, path: &str) -> Result<(), PodError> {
124        let path = Self::normalize(path);
125        let mut guard = self.inner.data.write().await;
126        match guard.remove(&path) {
127            Some(_) => {
128                drop(guard);
129                let _ = self.inner.events.send(StorageEvent::Deleted(path));
130                Ok(())
131            }
132            None => Err(PodError::NotFound(path)),
133        }
134    }
135
136    async fn list(&self, container: &str) -> Result<Vec<String>, PodError> {
137        let container = Self::normalize(container);
138        let container = if container.ends_with('/') {
139            container
140        } else {
141            format!("{container}/")
142        };
143        let guard = self.inner.data.read().await;
144        let mut seen = std::collections::BTreeSet::new();
145        for key in guard.keys() {
146            if !key.starts_with(&container) {
147                continue;
148            }
149            let remainder = &key[container.len()..];
150            if remainder.is_empty() {
151                continue;
152            }
153            let name = match remainder.find('/') {
154                Some(pos) => &remainder[..=pos],
155                None => remainder,
156            };
157            seen.insert(name.to_string());
158        }
159        Ok(seen.into_iter().collect())
160    }
161
162    async fn head(&self, path: &str) -> Result<ResourceMeta, PodError> {
163        let path = Self::normalize(path);
164        let guard = self.inner.data.read().await;
165        guard
166            .get(&path)
167            .map(|e| e.meta.clone())
168            .ok_or(PodError::NotFound(path))
169    }
170
171    async fn exists(&self, path: &str) -> Result<bool, PodError> {
172        let path = Self::normalize(path);
173        let guard = self.inner.data.read().await;
174        Ok(guard.contains_key(&path))
175    }
176
177    async fn create_container(&self, path: &str) -> Result<ResourceMeta, PodError> {
178        let container = Self::normalize(path);
179        let container = if container.ends_with('/') {
180            container
181        } else {
182            format!("{container}/")
183        };
184        let meta = ResourceMeta::new("container", 0, "application/ld+json");
185        let mut guard = self.inner.data.write().await;
186        guard.insert(
187            container.clone(),
188            Entry {
189                body: Bytes::new(),
190                meta: meta.clone(),
191            },
192        );
193        drop(guard);
194        let _ = self.inner.events.send(StorageEvent::Created(container));
195        Ok(meta)
196    }
197
198    async fn watch(&self, path: &str) -> Result<mpsc::Receiver<StorageEvent>, PodError> {
199        let filter_path = Self::normalize(path);
200        let mut rx = self.inner.events.subscribe();
201        let (tx, out_rx) = mpsc::channel(64);
202        tokio::spawn(async move {
203            while let Ok(event) = rx.recv().await {
204                let target = match &event {
205                    StorageEvent::Created(p)
206                    | StorageEvent::Updated(p)
207                    | StorageEvent::Deleted(p) => p.clone(),
208                };
209                if MemoryBackend::is_under(&target, &filter_path)
210                    && tx.send(event).await.is_err()
211                {
212                    return;
213                }
214            }
215        });
216        Ok(out_rx)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Wait up to one second for the next event on a watch channel.
    async fn next_event(rx: &mut mpsc::Receiver<StorageEvent>) -> StorageEvent {
        tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timed out waiting for a storage event")
            .expect("watch channel closed unexpectedly")
    }

    #[tokio::test]
    async fn put_get_roundtrip() {
        let m = MemoryBackend::new();
        m.put("/foo", Bytes::from_static(b"hello"), "text/plain")
            .await
            .unwrap();
        let (body, meta) = m.get("/foo").await.unwrap();
        assert_eq!(&body[..], b"hello");
        assert_eq!(meta.size, 5);
        assert_eq!(meta.content_type, "text/plain");
    }

    #[tokio::test]
    async fn relative_paths_are_normalised() {
        let m = MemoryBackend::new();
        m.put("foo", Bytes::from_static(b"x"), "text/plain")
            .await
            .unwrap();
        assert!(m.exists("/foo").await.unwrap());
        let (body, _) = m.get("/foo").await.unwrap();
        assert_eq!(&body[..], b"x");
    }

    #[tokio::test]
    async fn delete_removes_entry_and_reports_missing() {
        let m = MemoryBackend::new();
        m.put("/gone", Bytes::from_static(b"bye"), "text/plain")
            .await
            .unwrap();
        m.delete("/gone").await.unwrap();
        assert!(!m.exists("/gone").await.unwrap());
        assert!(matches!(m.get("/gone").await, Err(PodError::NotFound(_))));
        assert!(matches!(m.delete("/gone").await, Err(PodError::NotFound(_))));
    }

    #[tokio::test]
    async fn list_direct_children_only() {
        let m = MemoryBackend::new();
        m.put("/a/b", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/a/c/d", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        let mut items = m.list("/a").await.unwrap();
        items.sort();
        assert_eq!(items, vec!["b".to_string(), "c/".to_string()]);
    }

    #[tokio::test]
    async fn created_container_is_listed_with_trailing_slash() {
        let m = MemoryBackend::new();
        m.create_container("/c").await.unwrap();
        m.put("/c/item", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        assert_eq!(m.list("/").await.unwrap(), vec!["c/".to_string()]);
        assert_eq!(m.list("/c/").await.unwrap(), vec!["item".to_string()]);
    }

    #[tokio::test]
    async fn watch_receives_created_event() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/").await.unwrap();
        m.put("/x", Bytes::from_static(b"hi"), "text/plain")
            .await
            .unwrap();
        assert_eq!(next_event(&mut rx).await, StorageEvent::Created("/x".into()));
    }

    #[tokio::test]
    async fn second_put_emits_updated_event() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/").await.unwrap();
        m.put("/x", Bytes::from_static(b"one"), "text/plain")
            .await
            .unwrap();
        m.put("/x", Bytes::from_static(b"two"), "text/plain")
            .await
            .unwrap();
        assert_eq!(next_event(&mut rx).await, StorageEvent::Created("/x".into()));
        assert_eq!(next_event(&mut rx).await, StorageEvent::Updated("/x".into()));
    }

    #[tokio::test]
    async fn watch_ignores_paths_outside_its_subtree() {
        let m = MemoryBackend::new();
        let mut rx = m.watch("/a").await.unwrap();
        // `/ab` shares a string prefix with `/a` but is a sibling, not a child.
        m.put("/ab", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/b", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        m.put("/a/x", Bytes::from_static(b""), "text/plain")
            .await
            .unwrap();
        assert_eq!(
            next_event(&mut rx).await,
            StorageEvent::Created("/a/x".into())
        );
    }
}
263}