// stac_server/backend/memory.rs

1use crate::{Backend, DEFAULT_LIMIT, Error, Result};
2use futures_core::Stream;
3use serde_json::Map;
4use stac::api::{
5    CollectionsClient, ItemCollection, ItemsClient, Search, StreamItemsClient, TransactionClient,
6    stream_pages,
7};
8use stac::{Collection, Item};
9use std::{
10    collections::{BTreeMap, HashMap},
11    sync::{Arc, RwLock},
12};
13
14/// A naive backend that stores collections and items in memory.
15///
16/// This backend is meant to be used for testing and toy servers, not for production.
17#[derive(Clone, Debug)]
18pub struct MemoryBackend {
19    collections: Arc<RwLock<BTreeMap<String, Collection>>>,
20    items: Arc<RwLock<HashMap<String, Vec<Item>>>>,
21}
22
23impl MemoryBackend {
24    /// Creates a new memory backend.
25    ///
26    /// # Examples
27    ///
28    /// ```
29    /// use stac_server::MemoryBackend;
30    /// let backend = MemoryBackend::new();
31    /// ```
32    pub fn new() -> MemoryBackend {
33        MemoryBackend {
34            collections: Arc::new(RwLock::new(BTreeMap::new())),
35            items: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38}
39
40impl ItemsClient for MemoryBackend {
41    type Error = Error;
42
43    async fn search(&self, mut search: Search) -> Result<ItemCollection> {
44        let items = self.items.read().unwrap();
45        if search.collections.is_empty() {
46            search.collections = items.keys().cloned().collect();
47        }
48        let mut item_references = Vec::new();
49        for collection in &search.collections {
50            if let Some(items) = items.get(collection) {
51                item_references.extend(
52                    items
53                        .iter()
54                        .filter(|item| search.matches(item).unwrap_or_default()),
55                );
56            }
57        }
58        let limit = search.limit.unwrap_or(DEFAULT_LIMIT).try_into()?;
59        let skip = search
60            .additional_fields
61            .get("skip")
62            .and_then(|skip| {
63                skip.as_u64()
64                    .or_else(|| skip.as_str().and_then(|skip| skip.parse::<u64>().ok()))
65            })
66            .unwrap_or_default()
67            .try_into()?;
68        let len = item_references.len();
69        let items = item_references
70            .into_iter()
71            .skip(skip)
72            .take(limit)
73            .map(|item| stac::api::Item::try_from(item.clone()).map_err(Error::from))
74            .collect::<Result<Vec<_>>>()?;
75        let mut item_collection = ItemCollection::new(items)?;
76        if len > item_collection.items.len() + skip {
77            let mut next = Map::new();
78            let _ = next.insert("skip".to_string(), (skip + limit).into());
79            item_collection.next = Some(next);
80        }
81        if skip > 0 {
82            let mut prev = Map::new();
83            let skip = skip.saturating_sub(limit);
84            let _ = prev.insert("skip".to_string(), skip.into());
85            item_collection.prev = Some(prev);
86        }
87        Ok(item_collection)
88    }
89
90    async fn item(&self, collection_id: &str, item_id: &str) -> Result<Option<Item>> {
91        let items = self.items.read().unwrap();
92        Ok(items
93            .get(collection_id)
94            .and_then(|items| items.iter().find(|item| item.id == item_id).cloned()))
95    }
96}
97
98impl CollectionsClient for MemoryBackend {
99    type Error = Error;
100
101    async fn collections(&self) -> Result<Vec<Collection>> {
102        let collections = self.collections.read().unwrap();
103        Ok(collections.values().cloned().collect())
104    }
105
106    async fn collection(&self, id: &str) -> Result<Option<Collection>> {
107        let collections = self.collections.read().unwrap();
108        Ok(collections.get(id).cloned())
109    }
110}
111
112impl TransactionClient for MemoryBackend {
113    type Error = Error;
114
115    async fn add_collection(&mut self, collection: Collection) -> Result<()> {
116        let mut collections = self.collections.write().unwrap();
117        let _ = collections.insert(collection.id.clone(), collection);
118        Ok(())
119    }
120
121    async fn add_item(&mut self, item: Item) -> Result<()> {
122        if let Some(collection_id) = item.collection.clone() {
123            if CollectionsClient::collection(self, &collection_id)
124                .await?
125                .is_none()
126            {
127                Err(Error::MemoryBackend(format!(
128                    "no collection with id='{collection_id}'",
129                )))
130            } else {
131                let mut items = self.items.write().unwrap();
132                items.entry(collection_id).or_default().push(item);
133                Ok(())
134            }
135        } else {
136            Err(Error::MemoryBackend(format!(
137                "collection not set on item: {}",
138                item.id
139            )))
140        }
141    }
142}
143
144impl StreamItemsClient for MemoryBackend {
145    type Error = Error;
146
147    async fn search_stream(
148        &self,
149        search: Search,
150    ) -> Result<impl Stream<Item = std::result::Result<stac::api::Item, Error>> + Send> {
151        let page = ItemsClient::search(self, search.clone()).await?;
152        Ok(stream_pages(self.clone(), search, page))
153    }
154}
155
156impl Backend for MemoryBackend {
157    fn has_item_search(&self) -> bool {
158        true
159    }
160
161    fn has_filter(&self) -> bool {
162        false
163    }
164}
165
166impl Default for MemoryBackend {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use stac::api::{Items, StreamCollectionsClient};

    /// Builds a backend holding one collection (`collection-id`) with three
    /// items, `item-a`, `item-b` and `item-c`, added in that order.
    async fn populated_backend() -> MemoryBackend {
        let mut backend = MemoryBackend::new();
        backend
            .add_collection(Collection::new("collection-id", "a description"))
            .await
            .unwrap();
        for id in ["item-a", "item-b", "item-c"] {
            backend
                .add_item(Item::new(id).collection("collection-id"))
                .await
                .unwrap();
        }
        backend
    }

    #[tokio::test]
    async fn stream_items_across_pages_with_real_backend() {
        let backend = populated_backend().await;
        let one_per_page = Search::default().limit(1u64);
        let collected = backend.collect_items(one_per_page).await.unwrap();
        assert_eq!(collected.len(), 3);
    }

    #[tokio::test]
    async fn item_count_uses_streaming_path() {
        let backend = populated_backend().await;
        let one_per_page = Search::default().limit(1u64);
        let total = backend.item_count(one_per_page).await.unwrap();
        assert_eq!(total, 3);
    }

    #[tokio::test]
    async fn search_honors_numeric_skip_token() {
        let backend = populated_backend().await;
        let mut second_page = Search::default().limit(1u64);
        let _ = second_page
            .additional_fields
            .insert("skip".to_string(), 1.into());

        let page = backend.search(second_page).await.unwrap();

        assert_eq!(page.items.len(), 1);
        let first_id = page.items[0].get("id").and_then(|value| value.as_str());
        assert_eq!(first_id, Some("item-b"));
    }

    #[tokio::test]
    async fn collections_stream_with_real_backend() {
        let backend = populated_backend().await;

        let collections = backend.collect_collections().await.unwrap();
        assert_eq!(collections.len(), 1);
        assert_eq!(collections[0].id, "collection-id");

        let page = backend
            .items("collection-id", Items::default())
            .await
            .unwrap();
        assert_eq!(page.items.len(), 3);
    }
}