stac_server/backend/
memory.rs1use crate::{Backend, DEFAULT_LIMIT, Error, Result};
2use futures_core::Stream;
3use serde_json::Map;
4use stac::api::{
5 CollectionsClient, ItemCollection, ItemsClient, Search, StreamItemsClient, TransactionClient,
6 stream_pages,
7};
8use stac::{Collection, Item};
9use std::{
10 collections::{BTreeMap, HashMap},
11 sync::{Arc, RwLock},
12};
13
14#[derive(Clone, Debug)]
18pub struct MemoryBackend {
19 collections: Arc<RwLock<BTreeMap<String, Collection>>>,
20 items: Arc<RwLock<HashMap<String, Vec<Item>>>>,
21}
22
23impl MemoryBackend {
24 pub fn new() -> MemoryBackend {
33 MemoryBackend {
34 collections: Arc::new(RwLock::new(BTreeMap::new())),
35 items: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38}
39
40impl ItemsClient for MemoryBackend {
41 type Error = Error;
42
43 async fn search(&self, mut search: Search) -> Result<ItemCollection> {
44 let items = self.items.read().unwrap();
45 if search.collections.is_empty() {
46 search.collections = items.keys().cloned().collect();
47 }
48 let mut item_references = Vec::new();
49 for collection in &search.collections {
50 if let Some(items) = items.get(collection) {
51 item_references.extend(
52 items
53 .iter()
54 .filter(|item| search.matches(item).unwrap_or_default()),
55 );
56 }
57 }
58 let limit = search.limit.unwrap_or(DEFAULT_LIMIT).try_into()?;
59 let skip = search
60 .additional_fields
61 .get("skip")
62 .and_then(|skip| {
63 skip.as_u64()
64 .or_else(|| skip.as_str().and_then(|skip| skip.parse::<u64>().ok()))
65 })
66 .unwrap_or_default()
67 .try_into()?;
68 let len = item_references.len();
69 let items = item_references
70 .into_iter()
71 .skip(skip)
72 .take(limit)
73 .map(|item| stac::api::Item::try_from(item.clone()).map_err(Error::from))
74 .collect::<Result<Vec<_>>>()?;
75 let mut item_collection = ItemCollection::new(items)?;
76 if len > item_collection.items.len() + skip {
77 let mut next = Map::new();
78 let _ = next.insert("skip".to_string(), (skip + limit).into());
79 item_collection.next = Some(next);
80 }
81 if skip > 0 {
82 let mut prev = Map::new();
83 let skip = skip.saturating_sub(limit);
84 let _ = prev.insert("skip".to_string(), skip.into());
85 item_collection.prev = Some(prev);
86 }
87 Ok(item_collection)
88 }
89
90 async fn item(&self, collection_id: &str, item_id: &str) -> Result<Option<Item>> {
91 let items = self.items.read().unwrap();
92 Ok(items
93 .get(collection_id)
94 .and_then(|items| items.iter().find(|item| item.id == item_id).cloned()))
95 }
96}
97
98impl CollectionsClient for MemoryBackend {
99 type Error = Error;
100
101 async fn collections(&self) -> Result<Vec<Collection>> {
102 let collections = self.collections.read().unwrap();
103 Ok(collections.values().cloned().collect())
104 }
105
106 async fn collection(&self, id: &str) -> Result<Option<Collection>> {
107 let collections = self.collections.read().unwrap();
108 Ok(collections.get(id).cloned())
109 }
110}
111
112impl TransactionClient for MemoryBackend {
113 type Error = Error;
114
115 async fn add_collection(&mut self, collection: Collection) -> Result<()> {
116 let mut collections = self.collections.write().unwrap();
117 let _ = collections.insert(collection.id.clone(), collection);
118 Ok(())
119 }
120
121 async fn add_item(&mut self, item: Item) -> Result<()> {
122 if let Some(collection_id) = item.collection.clone() {
123 if CollectionsClient::collection(self, &collection_id)
124 .await?
125 .is_none()
126 {
127 Err(Error::MemoryBackend(format!(
128 "no collection with id='{collection_id}'",
129 )))
130 } else {
131 let mut items = self.items.write().unwrap();
132 items.entry(collection_id).or_default().push(item);
133 Ok(())
134 }
135 } else {
136 Err(Error::MemoryBackend(format!(
137 "collection not set on item: {}",
138 item.id
139 )))
140 }
141 }
142}
143
144impl StreamItemsClient for MemoryBackend {
145 type Error = Error;
146
147 async fn search_stream(
148 &self,
149 search: Search,
150 ) -> Result<impl Stream<Item = std::result::Result<stac::api::Item, Error>> + Send> {
151 let page = ItemsClient::search(self, search.clone()).await?;
152 Ok(stream_pages(self.clone(), search, page))
153 }
154}
155
156impl Backend for MemoryBackend {
157 fn has_item_search(&self) -> bool {
158 true
159 }
160
161 fn has_filter(&self) -> bool {
162 false
163 }
164}
165
166impl Default for MemoryBackend {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use stac::api::{Items, StreamCollectionsClient};

    /// Builds a backend holding one collection, `collection-id`, with three items added in
    /// order: `item-a`, `item-b`, `item-c`.
    async fn populated_backend() -> MemoryBackend {
        let mut backend = MemoryBackend::new();
        backend
            .add_collection(Collection::new("collection-id", "a description"))
            .await
            .unwrap();
        for id in ["item-a", "item-b", "item-c"] {
            backend
                .add_item(Item::new(id).collection("collection-id"))
                .await
                .unwrap();
        }
        backend
    }

    #[tokio::test]
    async fn stream_items_across_pages_with_real_backend() {
        let backend = populated_backend().await;
        let search = Search::default().limit(1u64);
        let items = backend.collect_items(search).await.unwrap();
        assert_eq!(items.len(), 3);
    }

    #[tokio::test]
    async fn item_count_uses_streaming_path() {
        let backend = populated_backend().await;
        let search = Search::default().limit(1u64);
        let count = backend.item_count(search).await.unwrap();
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn search_honors_numeric_skip_token() {
        let backend = populated_backend().await;
        let mut search = Search::default().limit(1u64);
        let _ = search
            .additional_fields
            .insert("skip".to_string(), 1.into());

        let page = backend.search(search).await.unwrap();

        assert_eq!(page.items.len(), 1);
        assert_eq!(
            page.items[0].get("id").and_then(|value| value.as_str()),
            Some("item-b")
        );
    }

    #[tokio::test]
    async fn search_honors_string_skip_token() {
        let backend = populated_backend().await;
        let mut search = Search::default().limit(1u64);
        let _ = search
            .additional_fields
            .insert("skip".to_string(), "2".into());

        let page = backend.search(search).await.unwrap();

        assert_eq!(page.items.len(), 1);
        assert_eq!(
            page.items[0].get("id").and_then(|value| value.as_str()),
            Some("item-c")
        );
        // The last match was returned, so there is nothing left to page forward to.
        assert!(page.next.is_none());
    }

    #[tokio::test]
    async fn search_sets_next_and_prev_skip_tokens() {
        let backend = populated_backend().await;

        let first = backend
            .search(Search::default().limit(1u64))
            .await
            .unwrap();
        assert!(first.prev.is_none());
        assert_eq!(
            first
                .next
                .as_ref()
                .and_then(|next| next.get("skip"))
                .and_then(|skip| skip.as_u64()),
            Some(1)
        );

        let mut search = Search::default().limit(1u64);
        let _ = search
            .additional_fields
            .insert("skip".to_string(), 1.into());
        let middle = backend.search(search).await.unwrap();
        assert_eq!(
            middle
                .next
                .as_ref()
                .and_then(|next| next.get("skip"))
                .and_then(|skip| skip.as_u64()),
            Some(2)
        );
        assert_eq!(
            middle
                .prev
                .as_ref()
                .and_then(|prev| prev.get("skip"))
                .and_then(|skip| skip.as_u64()),
            Some(0)
        );
    }

    #[tokio::test]
    async fn collections_stream_with_real_backend() {
        let backend = populated_backend().await;
        let collections = backend.collect_collections().await.unwrap();
        assert_eq!(collections.len(), 1);
        assert_eq!(collections[0].id, "collection-id");
        let items = backend
            .items("collection-id", Items::default())
            .await
            .unwrap();
        assert_eq!(items.items.len(), 3);
    }
}