1use crate::{
2 adapter::{AdapterKind, DatabaseConfig},
3 error::{DataError, DataResult},
4 query::{FilterOperator, Query, SortDirection},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::{cmp::Ordering, collections::BTreeMap};
9
/// A single record's column values, keyed by column name.
///
/// A `BTreeMap` keeps the keys sorted, so iteration order over a row is
/// deterministic.
pub type Row = BTreeMap<String, Value>;

/// A [`Row`] paired with the id the repository assigned to it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StoredRow {
    /// Row identifier. In `MemoryRepo` it is assigned by `insert` from a
    /// counter shared across all tables.
    pub id: u64,
    /// The row's column values.
    pub data: Row,
}
17
18pub trait AdapterDriver: Send + Sync {
19 fn kind(&self) -> AdapterKind;
20}
21
22#[derive(Debug, Default, Clone, Copy)]
23pub struct PostgresAdapter;
24
25impl AdapterDriver for PostgresAdapter {
26 fn kind(&self) -> AdapterKind {
27 AdapterKind::Postgres
28 }
29}
30
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MySqlAdapter;
33
34impl AdapterDriver for MySqlAdapter {
35 fn kind(&self) -> AdapterKind {
36 AdapterKind::MySql
37 }
38}
39
40#[derive(Debug, Default, Clone, Copy)]
41pub struct SqliteAdapter;
42
43impl AdapterDriver for SqliteAdapter {
44 fn kind(&self) -> AdapterKind {
45 AdapterKind::Sqlite
46 }
47}
48
49pub fn adapter_for(config: &DatabaseConfig) -> DataResult<Box<dyn AdapterDriver>> {
50 match config.adapter {
51 AdapterKind::Postgres => Ok(Box::new(PostgresAdapter)),
52 AdapterKind::MySql => Ok(Box::new(MySqlAdapter)),
53 AdapterKind::Sqlite => Ok(Box::new(SqliteAdapter)),
54 AdapterKind::None => Err(DataError::Adapter(
55 "database adapter is `none`; select postgres/mysql/sqlite in shelly.data.toml"
56 .to_string(),
57 )),
58 }
59}
60
/// Minimal CRUD-plus-query interface over named tables of [`StoredRow`]s.
pub trait Repo {
    /// Reports which backend this repository was configured with.
    fn adapter_kind(&self) -> AdapterKind;
    /// Stores `data` in `table` and returns it together with its assigned id.
    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow>;
    /// Replaces the data of row `id` in `table` and returns the updated row.
    /// (`MemoryRepo` returns `DataError::Query` when the row does not exist.)
    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow>;
    /// Removes row `id` from `table`.
    /// (`MemoryRepo` returns `DataError::Query` when the row does not exist.)
    fn delete(&mut self, table: &str, id: u64) -> DataResult<()>;
    /// Looks up row `id` in `table`.
    /// (`MemoryRepo` returns `Ok(None)` for a missing row or table.)
    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>>;
    /// Returns the rows of `table` selected by `query`'s filters, sorts and
    /// pagination.
    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>>;
}
69
70pub struct MemoryRepo {
71 driver: Box<dyn AdapterDriver>,
72 tables: BTreeMap<String, Vec<StoredRow>>,
73 next_id: u64,
74}
75
76impl MemoryRepo {
77 pub fn new(driver: Box<dyn AdapterDriver>) -> Self {
78 Self {
79 driver,
80 tables: BTreeMap::new(),
81 next_id: 1,
82 }
83 }
84}
85
86impl Repo for MemoryRepo {
87 fn adapter_kind(&self) -> AdapterKind {
88 self.driver.kind()
89 }
90
91 fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow> {
92 let entry = self.tables.entry(table.to_string()).or_default();
93 let row = StoredRow {
94 id: self.next_id,
95 data,
96 };
97 self.next_id += 1;
98 entry.push(row.clone());
99 Ok(row)
100 }
101
102 fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow> {
103 let rows = self.tables.entry(table.to_string()).or_default();
104 let Some(existing) = rows.iter_mut().find(|row| row.id == id) else {
105 return Err(DataError::Query(format!(
106 "row id {id} not found in table `{table}`"
107 )));
108 };
109 existing.data = data;
110 Ok(existing.clone())
111 }
112
113 fn delete(&mut self, table: &str, id: u64) -> DataResult<()> {
114 let rows = self.tables.entry(table.to_string()).or_default();
115 let initial_len = rows.len();
116 rows.retain(|row| row.id != id);
117 if rows.len() == initial_len {
118 return Err(DataError::Query(format!(
119 "row id {id} not found in table `{table}`"
120 )));
121 }
122 Ok(())
123 }
124
125 fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>> {
126 Ok(self
127 .tables
128 .get(table)
129 .and_then(|rows| rows.iter().find(|row| row.id == id))
130 .cloned())
131 }
132
133 fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>> {
134 let mut rows = self.tables.get(table).cloned().unwrap_or_default();
135
136 if !query.filters.is_empty() {
137 rows.retain(|row| {
138 query
139 .filters
140 .iter()
141 .all(|filter| matches_filter(row, filter))
142 });
143 }
144
145 for sort in query.sorts.iter().rev() {
146 rows.sort_by(|left, right| compare_for_sort(left, right, sort.field.as_str()));
147 if sort.direction == SortDirection::Desc {
148 rows.reverse();
149 }
150 }
151
152 if let Some(pagination) = query.pagination {
153 let offset = (pagination.page.saturating_sub(1)) * pagination.per_page;
154 rows = rows
155 .into_iter()
156 .skip(offset)
157 .take(pagination.per_page)
158 .collect();
159 }
160
161 Ok(rows)
162 }
163}
164
165fn matches_filter(row: &StoredRow, filter: &crate::query::Filter) -> bool {
166 let Some(candidate) = row.data.get(&filter.field) else {
167 return false;
168 };
169 match filter.op {
170 FilterOperator::Eq => candidate == &filter.value,
171 FilterOperator::Neq => candidate != &filter.value,
172 FilterOperator::Contains => candidate
173 .as_str()
174 .zip(filter.value.as_str())
175 .is_some_and(|(left, right)| left.contains(right)),
176 FilterOperator::Gt => {
177 compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Greater)
178 }
179 FilterOperator::Gte => compare_numbers(candidate, &filter.value)
180 .is_some_and(|ord| ord == Ordering::Greater || ord == Ordering::Equal),
181 FilterOperator::Lt => {
182 compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Less)
183 }
184 FilterOperator::Lte => compare_numbers(candidate, &filter.value)
185 .is_some_and(|ord| ord == Ordering::Less || ord == Ordering::Equal),
186 }
187}
188
189fn compare_for_sort(left: &StoredRow, right: &StoredRow, field: &str) -> Ordering {
190 let left_value = left.data.get(field);
191 let right_value = right.data.get(field);
192 match (left_value, right_value) {
193 (Some(Value::Number(left_num)), Some(Value::Number(right_num))) => left_num
194 .as_f64()
195 .partial_cmp(&right_num.as_f64())
196 .unwrap_or(Ordering::Equal),
197 (Some(Value::String(left_text)), Some(Value::String(right_text))) => {
198 left_text.cmp(right_text)
199 }
200 _ => left.id.cmp(&right.id),
201 }
202}
203
204fn compare_numbers(left: &Value, right: &Value) -> Option<Ordering> {
205 left.as_f64()
206 .zip(right.as_f64())
207 .and_then(|(left, right)| left.partial_cmp(&right))
208}
209
// Unit tests for adapter selection and the in-memory `MemoryRepo`.
#[cfg(test)]
mod tests {
    use super::{adapter_for, DatabaseConfig, MemoryRepo, Repo, Row};
    use crate::{AdapterKind, DataError, Filter, FilterOperator, Query, SortDirection};
    use serde_json::json;

    // A repo built on a driver from `adapter_for` can insert a row and list it
    // back through a combined contains-filter + sort query.
    #[test]
    fn memory_repo_works_for_adapter_selection() {
        let mut repo = MemoryRepo::new(
            adapter_for(&DatabaseConfig {
                adapter: AdapterKind::Sqlite,
                url: None,
                url_env: None,
            })
            .unwrap(),
        );
        let mut row = Row::new();
        row.insert("title".to_string(), json!("Alpha"));
        row.insert("score".to_string(), json!(10));
        repo.insert("posts", row).unwrap();

        let rows = repo
            .list(
                "posts",
                &Query::new()
                    .where_filter(Filter::contains("title", "Al"))
                    .order_by("score", SortDirection::Desc),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].data.get("title"), Some(&json!("Alpha")));
    }

    // `AdapterKind::None` is rejected with an adapter error; every concrete
    // kind yields a driver that reports that same kind back.
    #[test]
    fn adapter_for_rejects_none_and_selects_expected_driver() {
        let none_result = adapter_for(&DatabaseConfig {
            adapter: AdapterKind::None,
            url: None,
            url_env: None,
        });
        assert!(matches!(none_result, Err(DataError::Adapter(_))));

        for kind in [
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ] {
            let driver = adapter_for(&DatabaseConfig {
                adapter: kind,
                url: None,
                url_env: None,
            })
            .expect("driver should be created");
            assert_eq!(driver.kind(), kind);
        }
    }

    // Exercises the happy path and the "missing row" paths of find, update
    // and delete.
    #[test]
    fn update_delete_and_find_cover_missing_rows() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        let mut row = Row::new();
        row.insert("title".to_string(), json!("Draft"));
        let inserted = repo.insert("posts", row).expect("insert should work");

        assert_eq!(
            repo.find("posts", inserted.id)
                .expect("find should not fail")
                .map(|it| it.id),
            Some(inserted.id)
        );
        // Unknown ids and unknown tables both report "not found" as Ok(None).
        assert!(repo
            .find("posts", 999)
            .expect("find should not fail")
            .is_none());
        assert!(repo
            .find("missing_table", inserted.id)
            .expect("find should not fail")
            .is_none());

        let mut updated = Row::new();
        updated.insert("title".to_string(), json!("Published"));
        let updated_row = repo
            .update("posts", inserted.id, updated)
            .expect("update should work");
        assert_eq!(updated_row.data.get("title"), Some(&json!("Published")));

        let update_err = repo
            .update("posts", 404, Row::new())
            .expect_err("missing row should fail update");
        assert!(matches!(update_err, DataError::Query(_)));

        // Deleting the same id twice: the second attempt finds nothing to remove.
        repo.delete("posts", inserted.id)
            .expect("delete should remove row");
        let delete_err = repo
            .delete("posts", inserted.id)
            .expect_err("deleting missing row should fail");
        assert!(matches!(delete_err, DataError::Query(_)));
    }

    // Covers every filter operator, a sort on a field no row has, pagination,
    // and a two-key sort. Rows get ids 1 (Alpha), 2 (Beta), 3 (Gamma).
    #[test]
    fn list_applies_filters_sorts_and_pagination() {
        let mut repo = MemoryRepo::new(Box::new(super::SqliteAdapter));

        let mut alpha = Row::new();
        alpha.insert("title".to_string(), json!("Alpha"));
        alpha.insert("score".to_string(), json!(10));
        alpha.insert("tag".to_string(), json!("core"));
        repo.insert("posts", alpha).expect("insert alpha");

        let mut beta = Row::new();
        beta.insert("title".to_string(), json!("Beta"));
        beta.insert("score".to_string(), json!(20));
        beta.insert("tag".to_string(), json!("ops"));
        repo.insert("posts", beta).expect("insert beta");

        // Gamma's `tag` is a number, giving the column mixed types.
        let mut gamma = Row::new();
        gamma.insert("title".to_string(), json!("Gamma"));
        gamma.insert("score".to_string(), json!(15));
        gamma.insert("tag".to_string(), json!(123));
        repo.insert("posts", gamma).expect("insert gamma");

        let eq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::eq("title", json!("Alpha"))),
            )
            .expect("eq filter");
        assert_eq!(eq_rows.len(), 1);
        assert_eq!(eq_rows[0].data.get("title"), Some(&json!("Alpha")));

        let neq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(crate::Filter {
                    field: "title".to_string(),
                    op: FilterOperator::Neq,
                    value: json!("Alpha"),
                }),
            )
            .expect("neq filter");
        assert_eq!(neq_rows.len(), 2);

        let contains_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("title", "mm")),
            )
            .expect("contains filter");
        assert_eq!(contains_rows.len(), 1);
        assert_eq!(contains_rows[0].data.get("title"), Some(&json!("Gamma")));

        // `contains` only applies to strings, so Gamma's numeric 123 must not
        // match "2", and neither "core" nor "ops" contains it.
        let contains_non_string_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("tag", "2")),
            )
            .expect("contains on mixed type");
        assert!(contains_non_string_rows.is_empty());

        // Scores are 10 / 20 / 15; every comparison is against 15.
        for (op, expected_titles) in [
            (FilterOperator::Gt, vec!["Beta"]),
            (FilterOperator::Gte, vec!["Beta", "Gamma"]),
            (FilterOperator::Lt, vec!["Alpha"]),
            (FilterOperator::Lte, vec!["Alpha", "Gamma"]),
        ] {
            let rows = repo
                .list(
                    "posts",
                    &Query::new().where_filter(crate::Filter {
                        field: "score".to_string(),
                        op,
                        value: json!(15),
                    }),
                )
                .expect("numeric filter");
            let titles: Vec<&str> = rows
                .iter()
                .map(|row| {
                    row.data
                        .get("title")
                        .and_then(|value| value.as_str())
                        .expect("title")
                })
                .collect();
            assert_eq!(titles, expected_titles);
        }

        // No row has `missing`, so ordering falls back to row id. Descending
        // gives 3, 2, 1, and page 1 of size 2 keeps ids 3 and 2.
        let unknown_field_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("missing", SortDirection::Desc)
                    .paginate(1, 2),
            )
            .expect("fallback sort");
        assert_eq!(unknown_field_sort.len(), 2);
        assert_eq!(unknown_field_sort[0].id, 3);
        assert_eq!(unknown_field_sort[1].id, 2);

        // Primary key: score descending. Scores are distinct, so the secondary
        // `title` key never decides the order here.
        let score_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("score", SortDirection::Desc)
                    .order_by("title", SortDirection::Asc),
            )
            .expect("score sort");
        let score_titles: Vec<&str> = score_sort
            .iter()
            .map(|row| {
                row.data
                    .get("title")
                    .and_then(|value| value.as_str())
                    .expect("title")
            })
            .collect();
        assert_eq!(score_titles, vec!["Beta", "Gamma", "Alpha"]);
    }
}