Skip to main content

shelly_data/
repo.rs

1use crate::{
2    adapter::{AdapterKind, DatabaseConfig},
3    error::{DataError, DataResult},
4    query::{FilterOperator, Query, SortDirection},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::{cmp::Ordering, collections::BTreeMap};
9
10pub type Row = BTreeMap<String, Value>;
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub struct StoredRow {
14    pub id: u64,
15    pub data: Row,
16}
17
18pub trait AdapterDriver: Send + Sync {
19    fn kind(&self) -> AdapterKind;
20}
21
22#[derive(Debug, Default, Clone, Copy)]
23pub struct PostgresAdapter;
24
25impl AdapterDriver for PostgresAdapter {
26    fn kind(&self) -> AdapterKind {
27        AdapterKind::Postgres
28    }
29}
30
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MySqlAdapter;
33
34impl AdapterDriver for MySqlAdapter {
35    fn kind(&self) -> AdapterKind {
36        AdapterKind::MySql
37    }
38}
39
40#[derive(Debug, Default, Clone, Copy)]
41pub struct SqliteAdapter;
42
43impl AdapterDriver for SqliteAdapter {
44    fn kind(&self) -> AdapterKind {
45        AdapterKind::Sqlite
46    }
47}
48
49pub fn adapter_for(config: &DatabaseConfig) -> DataResult<Box<dyn AdapterDriver>> {
50    match config.adapter {
51        AdapterKind::Postgres => Ok(Box::new(PostgresAdapter)),
52        AdapterKind::MySql => Ok(Box::new(MySqlAdapter)),
53        AdapterKind::Sqlite => Ok(Box::new(SqliteAdapter)),
54        AdapterKind::None => Err(DataError::Adapter(
55            "database adapter is `none`; select postgres/mysql/sqlite in shelly.data.toml"
56                .to_string(),
57        )),
58    }
59}
60
61pub trait Repo {
62    fn adapter_kind(&self) -> AdapterKind;
63    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow>;
64    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow>;
65    fn delete(&mut self, table: &str, id: u64) -> DataResult<()>;
66    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>>;
67    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>>;
68}
69
70pub struct MemoryRepo {
71    driver: Box<dyn AdapterDriver>,
72    tables: BTreeMap<String, Vec<StoredRow>>,
73    next_id: u64,
74}
75
76impl MemoryRepo {
77    pub fn new(driver: Box<dyn AdapterDriver>) -> Self {
78        Self {
79            driver,
80            tables: BTreeMap::new(),
81            next_id: 1,
82        }
83    }
84}
85
86impl Repo for MemoryRepo {
87    fn adapter_kind(&self) -> AdapterKind {
88        self.driver.kind()
89    }
90
91    fn insert(&mut self, table: &str, data: Row) -> DataResult<StoredRow> {
92        let entry = self.tables.entry(table.to_string()).or_default();
93        let row = StoredRow {
94            id: self.next_id,
95            data,
96        };
97        self.next_id += 1;
98        entry.push(row.clone());
99        Ok(row)
100    }
101
102    fn update(&mut self, table: &str, id: u64, data: Row) -> DataResult<StoredRow> {
103        let rows = self.tables.entry(table.to_string()).or_default();
104        let Some(existing) = rows.iter_mut().find(|row| row.id == id) else {
105            return Err(DataError::Query(format!(
106                "row id {id} not found in table `{table}`"
107            )));
108        };
109        existing.data = data;
110        Ok(existing.clone())
111    }
112
113    fn delete(&mut self, table: &str, id: u64) -> DataResult<()> {
114        let rows = self.tables.entry(table.to_string()).or_default();
115        let initial_len = rows.len();
116        rows.retain(|row| row.id != id);
117        if rows.len() == initial_len {
118            return Err(DataError::Query(format!(
119                "row id {id} not found in table `{table}`"
120            )));
121        }
122        Ok(())
123    }
124
125    fn find(&self, table: &str, id: u64) -> DataResult<Option<StoredRow>> {
126        Ok(self
127            .tables
128            .get(table)
129            .and_then(|rows| rows.iter().find(|row| row.id == id))
130            .cloned())
131    }
132
133    fn list(&self, table: &str, query: &Query) -> DataResult<Vec<StoredRow>> {
134        let mut rows = self.tables.get(table).cloned().unwrap_or_default();
135
136        if !query.filters.is_empty() {
137            rows.retain(|row| {
138                query
139                    .filters
140                    .iter()
141                    .all(|filter| matches_filter(row, filter))
142            });
143        }
144
145        for sort in query.sorts.iter().rev() {
146            rows.sort_by(|left, right| compare_for_sort(left, right, sort.field.as_str()));
147            if sort.direction == SortDirection::Desc {
148                rows.reverse();
149            }
150        }
151
152        if let Some(pagination) = query.pagination {
153            let offset = (pagination.page.saturating_sub(1)) * pagination.per_page;
154            rows = rows
155                .into_iter()
156                .skip(offset)
157                .take(pagination.per_page)
158                .collect();
159        }
160
161        Ok(rows)
162    }
163}
164
165fn matches_filter(row: &StoredRow, filter: &crate::query::Filter) -> bool {
166    let Some(candidate) = row.data.get(&filter.field) else {
167        return false;
168    };
169    match filter.op {
170        FilterOperator::Eq => candidate == &filter.value,
171        FilterOperator::Neq => candidate != &filter.value,
172        FilterOperator::Contains => candidate
173            .as_str()
174            .zip(filter.value.as_str())
175            .is_some_and(|(left, right)| left.contains(right)),
176        FilterOperator::Gt => {
177            compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Greater)
178        }
179        FilterOperator::Gte => compare_numbers(candidate, &filter.value)
180            .is_some_and(|ord| ord == Ordering::Greater || ord == Ordering::Equal),
181        FilterOperator::Lt => {
182            compare_numbers(candidate, &filter.value).is_some_and(|ord| ord == Ordering::Less)
183        }
184        FilterOperator::Lte => compare_numbers(candidate, &filter.value)
185            .is_some_and(|ord| ord == Ordering::Less || ord == Ordering::Equal),
186    }
187}
188
189fn compare_for_sort(left: &StoredRow, right: &StoredRow, field: &str) -> Ordering {
190    let left_value = left.data.get(field);
191    let right_value = right.data.get(field);
192    match (left_value, right_value) {
193        (Some(Value::Number(left_num)), Some(Value::Number(right_num))) => left_num
194            .as_f64()
195            .partial_cmp(&right_num.as_f64())
196            .unwrap_or(Ordering::Equal),
197        (Some(Value::String(left_text)), Some(Value::String(right_text))) => {
198            left_text.cmp(right_text)
199        }
200        _ => left.id.cmp(&right.id),
201    }
202}
203
204fn compare_numbers(left: &Value, right: &Value) -> Option<Ordering> {
205    left.as_f64()
206        .zip(right.as_f64())
207        .and_then(|(left, right)| left.partial_cmp(&right))
208}
209
#[cfg(test)]
mod tests {
    use super::{adapter_for, DatabaseConfig, MemoryRepo, Repo, Row, StoredRow};
    use crate::{AdapterKind, DataError, Filter, FilterOperator, Query, SortDirection};
    use serde_json::{json, Value};

    /// Builds a config that selects `adapter` and sets no connection URL.
    fn config_for(adapter: AdapterKind) -> DatabaseConfig {
        DatabaseConfig {
            adapter,
            url: None,
            url_env: None,
        }
    }

    /// Creates an empty repository backed by the SQLite driver.
    fn sqlite_repo() -> MemoryRepo {
        MemoryRepo::new(Box::new(super::SqliteAdapter))
    }

    /// Builds a row from `(field, value)` pairs.
    fn row_of(fields: Vec<(&str, Value)>) -> Row {
        fields
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Extracts the `title` of every row, in order; panics on a missing title.
    fn titles(rows: &[StoredRow]) -> Vec<&str> {
        rows.iter()
            .map(|row| {
                row.data
                    .get("title")
                    .and_then(|value| value.as_str())
                    .expect("title")
            })
            .collect()
    }

    #[test]
    fn memory_repo_works_for_adapter_selection() {
        let driver = adapter_for(&config_for(AdapterKind::Sqlite)).unwrap();
        let mut repo = MemoryRepo::new(driver);
        // The repo must report the kind of the driver it was built with.
        assert_eq!(repo.adapter_kind(), AdapterKind::Sqlite);

        let row = row_of(vec![("title", json!("Alpha")), ("score", json!(10))]);
        repo.insert("posts", row).unwrap();

        let rows = repo
            .list(
                "posts",
                &Query::new()
                    .where_filter(Filter::contains("title", "Al"))
                    .order_by("score", SortDirection::Desc),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].data.get("title"), Some(&json!("Alpha")));
    }

    #[test]
    fn adapter_for_rejects_none_and_selects_expected_driver() {
        let none_result = adapter_for(&config_for(AdapterKind::None));
        assert!(matches!(none_result, Err(DataError::Adapter(_))));

        for kind in [
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ] {
            let driver = adapter_for(&config_for(kind)).expect("driver should be created");
            assert_eq!(driver.kind(), kind);
        }
    }

    #[test]
    fn update_delete_and_find_cover_missing_rows() {
        let mut repo = sqlite_repo();

        let inserted = repo
            .insert("posts", row_of(vec![("title", json!("Draft"))]))
            .expect("insert should work");

        assert_eq!(
            repo.find("posts", inserted.id)
                .expect("find should not fail")
                .map(|it| it.id),
            Some(inserted.id)
        );
        assert!(repo
            .find("posts", 999)
            .expect("find should not fail")
            .is_none());
        assert!(repo
            .find("missing_table", inserted.id)
            .expect("find should not fail")
            .is_none());

        let updated_row = repo
            .update(
                "posts",
                inserted.id,
                row_of(vec![("title", json!("Published"))]),
            )
            .expect("update should work");
        assert_eq!(updated_row.data.get("title"), Some(&json!("Published")));

        let update_err = repo
            .update("posts", 404, Row::new())
            .expect_err("missing row should fail update");
        assert!(matches!(update_err, DataError::Query(_)));

        // A table that was never written to behaves like a missing row.
        let missing_table_update = repo
            .update("missing_table", inserted.id, Row::new())
            .expect_err("missing table should fail update");
        assert!(matches!(missing_table_update, DataError::Query(_)));
        let missing_table_delete = repo
            .delete("missing_table", inserted.id)
            .expect_err("missing table should fail delete");
        assert!(matches!(missing_table_delete, DataError::Query(_)));

        repo.delete("posts", inserted.id)
            .expect("delete should remove row");
        let delete_err = repo
            .delete("posts", inserted.id)
            .expect_err("deleting missing row should fail");
        assert!(matches!(delete_err, DataError::Query(_)));
    }

    #[test]
    fn list_applies_filters_sorts_and_pagination() {
        let mut repo = sqlite_repo();

        let fixtures = vec![
            ("Alpha", json!(10), json!("core")),
            ("Beta", json!(20), json!("ops")),
            // Gamma carries a non-string tag on purpose, for the contains case.
            ("Gamma", json!(15), json!(123)),
        ];
        for (title, score, tag) in fixtures {
            let row = row_of(vec![("title", json!(title)), ("score", score), ("tag", tag)]);
            repo.insert("posts", row).expect("insert fixture");
        }

        let unknown_table_rows = repo
            .list("missing_table", &Query::new())
            .expect("unknown table");
        assert!(unknown_table_rows.is_empty());

        let eq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::eq("title", json!("Alpha"))),
            )
            .expect("eq filter");
        assert_eq!(eq_rows.len(), 1);
        assert_eq!(eq_rows[0].data.get("title"), Some(&json!("Alpha")));

        let neq_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(crate::Filter {
                    field: "title".to_string(),
                    op: FilterOperator::Neq,
                    value: json!("Alpha"),
                }),
            )
            .expect("neq filter");
        assert_eq!(neq_rows.len(), 2);

        let contains_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("title", "mm")),
            )
            .expect("contains filter");
        assert_eq!(contains_rows.len(), 1);
        assert_eq!(contains_rows[0].data.get("title"), Some(&json!("Gamma")));

        let contains_non_string_rows = repo
            .list(
                "posts",
                &Query::new().where_filter(Filter::contains("tag", "2")),
            )
            .expect("contains on mixed type");
        assert!(contains_non_string_rows.is_empty());

        for (op, expected_titles) in [
            (FilterOperator::Gt, vec!["Beta"]),
            (FilterOperator::Gte, vec!["Beta", "Gamma"]),
            (FilterOperator::Lt, vec!["Alpha"]),
            (FilterOperator::Lte, vec!["Alpha", "Gamma"]),
        ] {
            let rows = repo
                .list(
                    "posts",
                    &Query::new().where_filter(crate::Filter {
                        field: "score".to_string(),
                        op,
                        value: json!(15),
                    }),
                )
                .expect("numeric filter");
            assert_eq!(titles(&rows), expected_titles);
        }

        let unknown_field_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("missing", SortDirection::Desc)
                    .paginate(1, 2),
            )
            .expect("fallback sort");
        assert_eq!(unknown_field_sort.len(), 2);
        assert_eq!(unknown_field_sort[0].id, 3);
        assert_eq!(unknown_field_sort[1].id, 2);

        // The second page holds the one remaining row.
        let second_page = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("missing", SortDirection::Desc)
                    .paginate(2, 2),
            )
            .expect("second page");
        assert_eq!(second_page.len(), 1);
        assert_eq!(second_page[0].id, 1);

        // A page past the end is empty rather than an error.
        let past_the_end = repo
            .list("posts", &Query::new().paginate(3, 2))
            .expect("page past the end");
        assert!(past_the_end.is_empty());

        let score_sort = repo
            .list(
                "posts",
                &Query::new()
                    .order_by("score", SortDirection::Desc)
                    .order_by("title", SortDirection::Asc),
            )
            .expect("score sort");
        assert_eq!(titles(&score_sort), vec!["Beta", "Gamma", "Alpha"]);
    }
}