Skip to main content

teaql_runtime/
memory.rs

1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8    Aggregate, AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction,
9    InsertCommand, Record, RecoverCommand, RelationAggregate, SelectQuery, SmartList, SortDirection,
10    UpdateCommand, Value,
11};
12
13use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
14
/// Errors produced by the in-memory repository itself, as opposed to metadata
/// or entity-mapping failures carried by the other `RepositoryError` variants.
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the stored rows was poisoned and could not be locked.
    Poisoned,
    /// An expression could not be evaluated in memory; the payload is the
    /// detail text appended to the `Display` message.
    UnsupportedExpression(String),
    /// An aggregate could not be computed in memory; the payload is the
    /// detail text appended to the `Display` message.
    UnsupportedAggregate(String),
}
21
22impl std::fmt::Display for MemoryRepositoryError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Poisoned => write!(f, "memory repository lock poisoned"),
26            Self::UnsupportedExpression(message) => {
27                write!(f, "unsupported memory expression: {message}")
28            }
29            Self::UnsupportedAggregate(message) => {
30                write!(f, "unsupported memory aggregate: {message}")
31            }
32        }
33    }
34}
35
// Marker impl: relies on the `Debug` and `Display` impls above and keeps the
// default `source()` (no underlying cause is exposed).
impl std::error::Error for MemoryRepositoryError {}
37
/// Repository that keeps rows in process memory, keyed by entity name.
///
/// The row storage sits behind `Arc<Mutex<..>>`, so a derived `Clone` copies
/// the handle and every clone observes the same stored rows.
#[derive(Debug, Clone)]
pub struct MemoryRepository<M = InMemoryMetadataStore> {
    /// Metadata store consulted for entity descriptors.
    metadata: M,
    /// Stored rows: entity name -> rows held for that entity.
    data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
}
43
44impl<M> MemoryRepository<M>
45where
46    M: MetadataStore,
47{
48    pub fn new(metadata: M) -> Self {
49        Self {
50            metadata,
51            data: Arc::new(Mutex::new(BTreeMap::new())),
52        }
53    }
54
55    pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
56        self.seed(entity, rows);
57        self
58    }
59
60    pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
61        if let Ok(mut data) = self.data.lock() {
62            data.insert(entity.into(), rows);
63        }
64    }
65
66    pub fn fetch_all(
67        &self,
68        query: &SelectQuery,
69    ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
70        self.require_entity(&query.entity)?;
71        let data = self
72            .data
73            .lock()
74            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
75        let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
76        drop(data);
77
78        if let Some(filter) = &query.filter {
79            rows = rows
80                .into_iter()
81                .filter_map(|row| match eval_filter(filter, &row) {
82                    Ok(true) => Some(Ok(row)),
83                    Ok(false) => None,
84                    Err(err) => Some(Err(err)),
85                })
86                .collect::<Result<Vec<_>, _>>()
87                .map_err(RepositoryError::Executor)?;
88        }
89
90        if !query.aggregates.is_empty() {
91            return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
92        }
93
94        apply_ordering(&mut rows, query);
95        rows = apply_slice(rows, query);
96        if !query.projection.is_empty() || !query.expr_projection.is_empty() {
97            rows = rows
98                .into_iter()
99                .map(|row| project_row(row, query))
100                .collect::<Result<Vec<_>, _>>()
101                .map_err(RepositoryError::Executor)?;
102        }
103        Ok(rows)
104    }
105
106    pub fn fetch_smart_list(
107        &self,
108        query: &SelectQuery,
109    ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
110        self.fetch_all(query).map(SmartList::from)
111    }
112
113    pub fn fetch_entities<T>(
114        &self,
115        query: &SelectQuery,
116    ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
117    where
118        T: Entity,
119    {
120        self.fetch_all(query)?
121            .into_iter()
122            .map(T::from_record)
123            .collect::<Result<Vec<_>, _>>()
124            .map(SmartList::from)
125            .map_err(RepositoryError::Entity)
126    }
127
128    pub fn fetch_all_with_relation_aggregates(
129        &self,
130        query: &SelectQuery,
131        relation_aggregates: &[RelationAggregate],
132    ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
133        let mut rows = self.fetch_all(query)?;
134        self.enhance_relation_aggregates(&query.entity, &mut rows, relation_aggregates)?;
135        Ok(rows)
136    }
137
138    pub fn fetch_smart_list_with_relation_aggregates(
139        &self,
140        query: &SelectQuery,
141        relation_aggregates: &[RelationAggregate],
142    ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
143        self.fetch_all_with_relation_aggregates(query, relation_aggregates)
144            .map(SmartList::from)
145    }
146
147    pub fn fetch_entities_with_relation_aggregates<T>(
148        &self,
149        query: &SelectQuery,
150        relation_aggregates: &[RelationAggregate],
151    ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
152    where
153        T: Entity,
154    {
155        self.fetch_all_with_relation_aggregates(query, relation_aggregates)?
156            .into_iter()
157            .map(T::from_record)
158            .collect::<Result<Vec<_>, _>>()
159            .map(SmartList::from)
160            .map_err(RepositoryError::Entity)
161    }
162
163    pub fn enhance_relation_aggregates(
164        &self,
165        parent_entity: &str,
166        parent_rows: &mut [Record],
167        relation_aggregates: &[RelationAggregate],
168    ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
169        for aggregate in relation_aggregates {
170            self.enhance_relation_aggregate(parent_entity, parent_rows, aggregate)?;
171        }
172        Ok(())
173    }
174
    /// Computes one relation aggregate and stores it on every parent row under
    /// `aggregate.alias`.
    ///
    /// The aggregate's query is redirected at the relation's target entity,
    /// grouped by the relation's foreign key and restricted to the local-key
    /// values present in `parent_rows`; the grouped results are then matched
    /// back to parents by key.
    ///
    /// # Errors
    /// * `RuntimeError::MissingEntity` if `parent_entity` has no descriptor.
    /// * `RuntimeError::MissingRelation` if the descriptor lacks the relation.
    /// * Any error raised by [`Self::fetch_all`] for the derived query.
    fn enhance_relation_aggregate(
        &self,
        parent_entity: &str,
        parent_rows: &mut [Record],
        aggregate: &RelationAggregate,
    ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
        let descriptor = self
            .metadata
            .entity(parent_entity)
            .ok_or_else(|| {
                RepositoryError::Runtime(RuntimeError::MissingEntity(parent_entity.to_owned()))
            })?;

        let relation = descriptor
            .relation_by_name(&aggregate.relation_name)
            .ok_or_else(|| {
                RepositoryError::Runtime(RuntimeError::MissingRelation {
                    entity: parent_entity.to_owned(),
                    relation: aggregate.relation_name.clone(),
                })
            })?;

        // Local-key values of the parents; rows without the key contribute nothing.
        let ids = parent_rows
            .iter()
            .filter_map(|row| row.get(&relation.local_key).cloned())
            .collect::<Vec<_>>();

        // No keys to look up: give every parent the "empty" value and skip the query.
        if ids.is_empty() {
            let value = if aggregate.single_result {
                Value::U64(0)
            } else {
                Value::List(Vec::new())
            };
            for parent in parent_rows.iter_mut() {
                parent.insert(aggregate.alias.clone(), value.clone());
            }
            return Ok(());
        }

        // Start from the caller's query but drop everything except its filter,
        // aggregates, grouping and having-style settings that survive the resets below.
        let mut query = aggregate.query.clone();
        query.entity = relation.target_entity.clone();
        query.projection.clear();
        query.expr_projection.clear();
        query.order_by.clear();
        query.slice = None;
        query.relations.clear();
        // Default to a COUNT when the caller specified no aggregate of its own.
        if query.aggregates.is_empty() {
            let alias = if aggregate.single_result {
                aggregate.alias.clone()
            } else {
                "count".to_owned()
            };
            query = query.aggregate(Aggregate::count(alias));
        }
        // The foreign key must be a grouping column so results can be matched to parents.
        if !query
            .group_by
            .iter()
            .any(|field| field == &relation.foreign_key)
        {
            query = query.group_by(relation.foreign_key.clone());
        }
        query = query.and_filter(Expr::in_list(relation.foreign_key.clone(), ids));

        let aggregate_rows = self.fetch_all(&query)?;

        // Bucket the grouped rows by foreign-key identity; the foreign key itself
        // is removed from each row, and rows lacking it are discarded.
        let mut buckets: BTreeMap<String, Vec<Record>> = BTreeMap::new();
        for mut row in aggregate_rows {
            if let Some(key) = row.remove(&relation.foreign_key) {
                let bucket_key = local_graph_identity_key(&key);
                buckets
                    .entry(bucket_key)
                    .or_default()
                    .push(row);
            }
        }

        for parent in parent_rows {
            let value = parent
                .get(&relation.local_key)
                .and_then(|local_value| buckets.get(&local_graph_identity_key(local_value)))
                .map(|rows| {
                    if aggregate.single_result {
                        // Only the first bucketed row is used; a single remaining
                        // column is unwrapped to its bare value.
                        rows.first()
                            .map(|row| {
                                if row.len() == 1 {
                                    row.values().next().cloned().unwrap_or(Value::Null)
                                } else {
                                    Value::object(row.clone())
                                }
                            })
                            .unwrap_or(Value::U64(0))
                    } else {
                        Value::List(rows.iter().cloned().map(Value::object).collect())
                    }
                })
                // Parent without a local key, or without any matching bucket.
                .unwrap_or_else(|| {
                    if aggregate.single_result {
                        Value::U64(0)
                    } else {
                        Value::List(Vec::new())
                    }
                });
            parent.insert(aggregate.alias.clone(), value);
        }

        Ok(())
    }
282
283    pub fn insert(
284        &self,
285        command: &InsertCommand,
286    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
287        self.require_entity(&command.entity)?;
288        let mut data = self
289            .data
290            .lock()
291            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
292        data.entry(command.entity.clone())
293            .or_default()
294            .push(command.values.clone());
295        Ok(1)
296    }
297
298    pub fn update(
299        &self,
300        command: &UpdateCommand,
301    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
302        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
303        let mut data = self
304            .data
305            .lock()
306            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
307        let rows = data.entry(command.entity.clone()).or_default();
308        let Some(row) = rows
309            .iter_mut()
310            .find(|row| row.get(id_property) == Some(&command.id))
311        else {
312            return self.maybe_optimistic_conflict(
313                command.expected_version,
314                &command.entity,
315                &command.id,
316            );
317        };
318
319        if let Some(expected_version) = command.expected_version {
320            if row.get(version_property) != Some(&Value::I64(expected_version)) {
321                return Err(RepositoryError::Runtime(
322                    RuntimeError::OptimisticLockConflict {
323                        entity: command.entity.clone(),
324                        id: format!("{:?}", command.id),
325                    },
326                ));
327            }
328            row.insert(
329                version_property.to_owned(),
330                Value::I64(expected_version + 1),
331            );
332        }
333
334        for (key, value) in &command.values {
335            row.insert(key.clone(), value.clone());
336        }
337        Ok(1)
338    }
339
340    pub fn delete(
341        &self,
342        command: &DeleteCommand,
343    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
344        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
345        let mut data = self
346            .data
347            .lock()
348            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
349        let rows = data.entry(command.entity.clone()).or_default();
350        let Some(index) = rows
351            .iter()
352            .position(|row| row.get(id_property) == Some(&command.id))
353        else {
354            return self.maybe_optimistic_conflict(
355                command.expected_version,
356                &command.entity,
357                &command.id,
358            );
359        };
360
361        if let Some(expected_version) = command.expected_version {
362            if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
363                return Err(RepositoryError::Runtime(
364                    RuntimeError::OptimisticLockConflict {
365                        entity: command.entity.clone(),
366                        id: format!("{:?}", command.id),
367                    },
368                ));
369            }
370        }
371
372        if command.soft_delete {
373            let next_version = command
374                .expected_version
375                .or_else(|| read_i64(rows[index].get(version_property)))
376                .map(|version| -(version.abs() + 1))
377                .unwrap_or(-1);
378            rows[index].insert(version_property.to_owned(), Value::I64(next_version));
379        } else {
380            rows.remove(index);
381        }
382        Ok(1)
383    }
384
385    pub fn recover(
386        &self,
387        command: &RecoverCommand,
388    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
389        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
390        let mut data = self
391            .data
392            .lock()
393            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
394        let rows = data.entry(command.entity.clone()).or_default();
395        let Some(row) = rows
396            .iter_mut()
397            .find(|row| row.get(id_property) == Some(&command.id))
398        else {
399            return Err(RepositoryError::Runtime(
400                RuntimeError::OptimisticLockConflict {
401                    entity: command.entity.clone(),
402                    id: format!("{:?}", command.id),
403                },
404            ));
405        };
406
407        if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
408            return Err(RepositoryError::Runtime(
409                RuntimeError::OptimisticLockConflict {
410                    entity: command.entity.clone(),
411                    id: format!("{:?}", command.id),
412                },
413            ));
414        }
415
416        row.insert(
417            version_property.to_owned(),
418            Value::I64(command.expected_version.abs() + 1),
419        );
420        Ok(1)
421    }
422
423    fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
424        self.metadata
425            .entity(entity)
426            .map(|_| ())
427            .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
428    }
429
430    fn id_and_version_properties(
431        &self,
432        entity: &str,
433    ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
434        let descriptor = self.metadata.entity(entity).ok_or_else(|| {
435            RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
436        })?;
437        let id = descriptor
438            .id_property()
439            .map(|property| property.name.as_str())
440            .unwrap_or("id");
441        let version = descriptor
442            .version_property()
443            .map(|property| property.name.as_str())
444            .unwrap_or("version");
445        Ok((id, version))
446    }
447
448    fn maybe_optimistic_conflict(
449        &self,
450        expected_version: Option<i64>,
451        entity: &str,
452        id: &Value,
453    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
454        if expected_version.is_some() {
455            Err(RepositoryError::Runtime(
456                RuntimeError::OptimisticLockConflict {
457                    entity: entity.to_owned(),
458                    id: format!("{id:?}"),
459                },
460            ))
461        } else {
462            Ok(0)
463        }
464    }
465}
466
467fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
468    match expr {
469        Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
470            value_truthy(&eval_value(expr, row)?)
471        }
472        Expr::Binary { left, op, right } => {
473            let left = eval_value(left, row)?;
474            let right = eval_value(right, row)?;
475            eval_binary(&left, *op, &right)
476        }
477        Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
478            "subquery filters require a SQL executor".to_owned(),
479        )),
480        Expr::Between { expr, lower, upper } => {
481            let value = eval_value(expr, row)?;
482            let lower = eval_value(lower, row)?;
483            let upper = eval_value(upper, row)?;
484            Ok(compare_values(&value, &lower) != Some(Ordering::Less)
485                && compare_values(&value, &upper) != Some(Ordering::Greater))
486        }
487        Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
488        Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
489        Expr::And(parts) => {
490            for part in parts {
491                if !eval_filter(part, row)? {
492                    return Ok(false);
493                }
494            }
495            Ok(true)
496        }
497        Expr::Or(parts) => {
498            for part in parts {
499                if eval_filter(part, row)? {
500                    return Ok(true);
501                }
502            }
503            Ok(false)
504        }
505        Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
506    }
507}
508
509fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
510    match expr {
511        Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
512        Expr::Value(value) => Ok(value.clone()),
513        Expr::Function { function, args } => eval_function(*function, args, row),
514        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
515            "cannot evaluate {other:?} as a scalar value"
516        ))),
517    }
518}
519
520fn eval_function(
521    function: ExprFunction,
522    args: &[Expr],
523    row: &Record,
524) -> Result<Value, MemoryRepositoryError> {
525    match function {
526        ExprFunction::Soundex => {
527            let [arg] = args else {
528                return Err(MemoryRepositoryError::UnsupportedExpression(
529                    "SOUNDEX expects exactly one argument".to_owned(),
530                ));
531            };
532            match eval_value(arg, row)? {
533                Value::Text(value) => Ok(Value::Text(soundex(&value))),
534                Value::Null => Ok(Value::Null),
535                other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
536                    "SOUNDEX expects text, got {other:?}"
537                ))),
538            }
539        }
540        ExprFunction::Gbk => {
541            let [arg] = args else {
542                return Err(MemoryRepositoryError::UnsupportedExpression(
543                    "GBK expects exactly one argument".to_owned(),
544                ));
545            };
546            eval_value(arg, row)
547        }
548        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
549            "function {other:?} is only supported by SQL execution"
550        ))),
551    }
552}
553
554fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
555    match op {
556        BinaryOp::Eq => Ok(values_equal(left, right)),
557        BinaryOp::Ne => Ok(!values_equal(left, right)),
558        BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
559        BinaryOp::Gte => Ok(matches!(
560            compare_values(left, right),
561            Some(Ordering::Greater | Ordering::Equal)
562        )),
563        BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
564        BinaryOp::Lte => Ok(matches!(
565            compare_values(left, right),
566            Some(Ordering::Less | Ordering::Equal)
567        )),
568        BinaryOp::Like => match (left, right) {
569            (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
570            _ => Ok(false),
571        },
572        BinaryOp::NotLike => match (left, right) {
573            (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
574            _ => Ok(true),
575        },
576        BinaryOp::In | BinaryOp::InLarge => match right {
577            Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
578            _ => Err(MemoryRepositoryError::UnsupportedExpression(
579                "IN expects a list value".to_owned(),
580            )),
581        },
582        BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
583            Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
584            _ => Err(MemoryRepositoryError::UnsupportedExpression(
585                "NOT IN expects a list value".to_owned(),
586            )),
587        },
588    }
589}
590
591fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
592    match value {
593        Value::Bool(value) => Ok(*value),
594        Value::Null => Ok(false),
595        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
596            "non-boolean expression result: {other:?}"
597        ))),
598    }
599}
600
601fn values_equal(left: &Value, right: &Value) -> bool {
602    match (left, right) {
603        (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
604        (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
605        _ => left == right,
606    }
607}
608
609fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
610    match (left, right) {
611        (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
612        (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
613        (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
614        (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
615        (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
616        (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
617        (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
618        (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
619        (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
620        _ => None,
621    }
622}
623
/// Matches `value` against a SQL `LIKE` pattern.
///
/// `%` matches any run of characters (including none) and `_` matches exactly
/// one character; both may appear anywhere in the pattern. Every other
/// character must match literally and case-sensitively. No escape character
/// is supported.
fn like_matches(value: &str, pattern: &str) -> bool {
    // Work on chars so `_` consumes one character rather than one byte.
    let value: Vec<char> = value.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut value_pos = 0;
    let mut pattern_pos = 0;
    // Position just after the most recent `%`, and the value position that
    // `%` is currently assumed to extend to.
    let mut backtrack: Option<(usize, usize)> = None;

    while value_pos < value.len() {
        match pattern.get(pattern_pos).copied() {
            Some('%') => {
                // Tentatively let `%` match nothing; widen on later mismatch.
                pattern_pos += 1;
                backtrack = Some((pattern_pos, value_pos));
            }
            Some(ch) if ch == '_' || ch == value[value_pos] => {
                value_pos += 1;
                pattern_pos += 1;
            }
            _ => match backtrack {
                // Mismatch: let the last `%` swallow one more character.
                Some((after_percent, matched_until)) => {
                    pattern_pos = after_percent;
                    value_pos = matched_until + 1;
                    backtrack = Some((after_percent, matched_until + 1));
                }
                None => return false,
            },
        }
    }

    // The value is exhausted; only trailing `%` may remain in the pattern.
    pattern[pattern_pos..].iter().all(|&ch| ch == '%')
}
635
636fn soundex(value: &str) -> String {
637    let mut letters = value
638        .chars()
639        .filter(|ch| ch.is_ascii_alphabetic())
640        .map(|ch| ch.to_ascii_uppercase());
641    let Some(first) = letters.next() else {
642        return "0000".to_owned();
643    };
644
645    let mut output = String::with_capacity(4);
646    output.push(first);
647    let mut previous_code = soundex_code(first);
648
649    for ch in letters {
650        let code = soundex_code(ch);
651        if code != '0' && code != previous_code {
652            output.push(code);
653            if output.len() == 4 {
654                return output;
655            }
656        }
657        previous_code = code;
658    }
659
660    while output.len() < 4 {
661        output.push('0');
662    }
663    output
664}
665
/// Maps an upper-case ASCII letter to its Soundex digit; any character not
/// listed in a digit group (vowels, lower-case letters, ...) maps to `'0'`.
fn soundex_code(ch: char) -> char {
    const GROUPS: [(&str, char); 6] = [
        ("BFPV", '1'),
        ("CGJKQSXZ", '2'),
        ("DT", '3'),
        ("L", '4'),
        ("MN", '5'),
        ("R", '6'),
    ];
    GROUPS
        .iter()
        .find(|(letters, _)| letters.contains(ch))
        .map_or('0', |&(_, code)| code)
}
677
678fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
679    for order in query.order_by.iter().rev() {
680        rows.sort_by(|left, right| {
681            let left_value = if let Some(expr) = &order.expr {
682                eval_value(expr, left).ok()
683            } else {
684                left.get(&order.field).cloned()
685            };
686            let right_value = if let Some(expr) = &order.expr {
687                eval_value(expr, right).ok()
688            } else {
689                right.get(&order.field).cloned()
690            };
691            let ordering = match (left_value.as_ref(), right_value.as_ref()) {
692                (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
693                (None, Some(_)) => Ordering::Less,
694                (Some(_), None) => Ordering::Greater,
695                (None, None) => Ordering::Equal,
696            };
697            match order.direction {
698                SortDirection::Asc => ordering,
699                SortDirection::Desc => ordering.reverse(),
700            }
701        });
702    }
703}
704
705fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
706    let Some(slice) = query.slice else {
707        return rows;
708    };
709    let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
710    let limit = slice
711        .limit
712        .and_then(|limit| usize::try_from(limit).ok())
713        .unwrap_or(usize::MAX);
714    rows.into_iter().skip(offset).take(limit).collect()
715}
716
717fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
718    let mut output: Record = query
719        .projection
720        .iter()
721        .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
722        .collect();
723    for projection in &query.expr_projection {
724        output.insert(
725            projection.alias.clone(),
726            eval_value(&projection.expr, &row)?,
727        );
728    }
729    Ok(output)
730}
731
/// Groups `rows` by `query.group_by`, computes every aggregate in
/// `query.aggregates` per group, and applies `query.having` to the result.
///
/// Each output record contains the group-by fields (taken from the group's
/// first row), one entry per aggregate alias, and one entry per expression
/// projection. With no `group_by` all rows form a single group, so an empty
/// input still produces one output record.
///
/// # Errors
/// Propagates errors from the individual aggregate helpers, from evaluating
/// expression projections, and from evaluating the `having` filter.
fn aggregate_rows(
    query: &SelectQuery,
    rows: &[Record],
) -> Result<Vec<Record>, MemoryRepositoryError> {
    // Group key: one string per group-by field, produced by `value_key`.
    let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
    if query.group_by.is_empty() {
        groups.insert(Vec::new(), rows.iter().collect());
    } else {
        for row in rows {
            let key = query
                .group_by
                .iter()
                // A row lacking the field contributes an empty key component.
                .map(|field| row.get(field).map(value_key).unwrap_or_default())
                .collect::<Vec<_>>();
            groups.entry(key).or_default().push(row);
        }
    }

    let rows = groups
        .into_values()
        .map(|rows| {
            let mut output = Record::new();
            // Carry the grouping columns over from the first row of the group.
            if let Some(first) = rows.first() {
                for field in &query.group_by {
                    if let Some(value) = first.get(field) {
                        output.insert(field.clone(), value.clone());
                    }
                }
            }
            for aggregate in &query.aggregates {
                let value = match aggregate.function {
                    AggregateFunction::Count => {
                        // COUNT(*) counts rows; COUNT(field) skips missing and Null values.
                        if aggregate.field == "*" {
                            Value::U64(rows.len() as u64)
                        } else {
                            Value::U64(
                                rows.iter()
                                    .filter(|row| {
                                        !matches!(
                                            row.get(&aggregate.field),
                                            None | Some(Value::Null)
                                        )
                                    })
                                    .count() as u64,
                            )
                        }
                    }
                    AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
                    AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
                    AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
                    AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
                    AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
                    AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
                    AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
                    AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
                    AggregateFunction::BitAnd => {
                        bit_aggregate(&rows, &aggregate.field, BitOp::And)?
                    }
                    AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
                    AggregateFunction::BitXor => {
                        bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
                    }
                };
                output.insert(aggregate.alias.clone(), value);
            }
            // Expression projections see the aggregated record (group columns
            // and aggregate aliases), not the underlying source rows.
            for projection in &query.expr_projection {
                output.insert(
                    projection.alias.clone(),
                    eval_value(&projection.expr, &output)?,
                );
            }
            Ok(output)
        })
        .collect::<Result<Vec<_>, _>>()?;
    // HAVING filters the aggregated records; the first evaluation error aborts.
    if let Some(having) = &query.having {
        rows.into_iter()
            .filter_map(|row| match eval_filter(having, &row) {
                Ok(true) => Some(Ok(row)),
                Ok(false) => None,
                Err(err) => Some(Err(err)),
            })
            .collect()
    } else {
        Ok(rows)
    }
}
818
819fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
820    let mut decimal_sum = Decimal::ZERO;
821    let mut integer_sum: i128 = 0;
822    let mut saw_decimal = false;
823    for value in rows.iter().filter_map(|row| row.get(field)) {
824        match value {
825            Value::I64(value) => {
826                integer_sum += i128::from(*value);
827                decimal_sum += Decimal::from(*value);
828            }
829            Value::U64(value) => {
830                integer_sum += i128::from(*value);
831                decimal_sum += Decimal::from(*value);
832            }
833            Value::F64(value) => {
834                saw_decimal = true;
835                decimal_sum += decimal_from_f64(*value);
836            }
837            Value::Decimal(value) => {
838                saw_decimal = true;
839                decimal_sum += *value;
840            }
841            Value::Null => {}
842            other => {
843                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
844                    "SUM does not support {other:?}"
845                )));
846            }
847        }
848    }
849    if saw_decimal {
850        Ok(Value::Decimal(decimal_sum))
851    } else if integer_sum >= 0 {
852        Ok(Value::U64(integer_sum as u64))
853    } else {
854        Ok(Value::I64(integer_sum as i64))
855    }
856}
857
858fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
859    let mut sum = Decimal::ZERO;
860    let mut count: u64 = 0;
861    for value in rows.iter().filter_map(|row| row.get(field)) {
862        match value {
863            Value::I64(value) => {
864                sum += Decimal::from(*value);
865                count += 1;
866            }
867            Value::U64(value) => {
868                sum += Decimal::from(*value);
869                count += 1;
870            }
871            Value::F64(value) => {
872                sum += decimal_from_f64(*value);
873                count += 1;
874            }
875            Value::Decimal(value) => {
876                sum += *value;
877                count += 1;
878            }
879            Value::Null => {}
880            other => {
881                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
882                    "AVG does not support {other:?}"
883                )));
884            }
885        }
886    }
887    Ok(if count == 0 {
888        Value::Null
889    } else {
890        Value::Decimal(sum / Decimal::from(count))
891    })
892}
893
894fn decimal_from_f64(value: f64) -> Decimal {
895    Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
896}
897
898fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
899    rows.iter()
900        .filter_map(|row| row.get(field))
901        .filter(|value| !matches!(value, Value::Null))
902        .map(|value| match value {
903            Value::I64(value) => Ok(*value as f64),
904            Value::U64(value) => Ok(*value as f64),
905            Value::F64(value) => Ok(*value),
906            Value::Decimal(value) => value.to_f64().ok_or_else(|| {
907                MemoryRepositoryError::UnsupportedAggregate(format!(
908                    "cannot convert decimal {value} to f64 for statistical aggregate"
909                ))
910            }),
911            other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
912                "numeric aggregate does not support {other:?}"
913            ))),
914        })
915        .collect()
916}
917
918fn numeric_variance(
919    rows: &[&Record],
920    field: &str,
921    sample: bool,
922) -> Result<Value, MemoryRepositoryError> {
923    let values = numeric_values(rows, field)?;
924    let count = values.len();
925    if count == 0 || (sample && count < 2) {
926        return Ok(Value::Null);
927    }
928    let mean = values.iter().sum::<f64>() / count as f64;
929    let sum = values
930        .iter()
931        .map(|value| {
932            let diff = value - mean;
933            diff * diff
934        })
935        .sum::<f64>();
936    let denominator = if sample { count - 1 } else { count } as f64;
937    Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
938}
939
940fn numeric_stddev(
941    rows: &[&Record],
942    field: &str,
943    sample: bool,
944) -> Result<Value, MemoryRepositoryError> {
945    Ok(match numeric_variance(rows, field, sample)? {
946        Value::Decimal(value) => {
947            Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
948        }
949        Value::Null => Value::Null,
950        other => other,
951    })
952}
953
/// Bitwise operation applied by `bit_aggregate` when folding values together.
#[derive(Clone, Copy, Debug)]
enum BitOp {
    /// Bitwise AND (`&`).
    And,
    /// Bitwise OR (`|`).
    Or,
    /// Bitwise XOR (`^`).
    Xor,
}
960
961fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
962    let mut selected: Option<i64> = None;
963    for value in rows.iter().filter_map(|row| row.get(field)) {
964        let value = match value {
965            Value::I64(value) => *value,
966            Value::U64(value) => i64::try_from(*value).map_err(|_| {
967                MemoryRepositoryError::UnsupportedAggregate(format!(
968                    "BIT aggregate u64 {value} exceeds i64 range"
969                ))
970            })?,
971            Value::Null => continue,
972            other => {
973                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
974                    "BIT aggregate does not support {other:?}"
975                )));
976            }
977        };
978        selected = Some(match (selected, op) {
979            (None, _) => value,
980            (Some(current), BitOp::And) => current & value,
981            (Some(current), BitOp::Or) => current | value,
982            (Some(current), BitOp::Xor) => current ^ value,
983        });
984    }
985    Ok(selected.map(Value::I64).unwrap_or(Value::Null))
986}
987
988fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
989    let mut selected: Option<Value> = None;
990    for value in rows.iter().filter_map(|row| row.get(field)) {
991        if matches!(value, Value::Null) {
992            continue;
993        }
994        match &selected {
995            None => selected = Some(value.clone()),
996            Some(current) => {
997                let Some(ordering) = compare_values(value, current) else {
998                    return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
999                        "MIN/MAX does not support {value:?}"
1000                    )));
1001                };
1002                if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
1003                    selected = Some(value.clone());
1004                }
1005            }
1006        }
1007    }
1008    Ok(selected.unwrap_or(Value::Null))
1009}
1010
1011fn value_key(value: &Value) -> String {
1012    match value {
1013        Value::Null => "null".to_owned(),
1014        Value::Bool(value) => format!("b:{value}"),
1015        Value::I64(value) => format!("i:{value}"),
1016        Value::U64(value) => format!("u:{value}"),
1017        Value::F64(value) => format!("f:{value}"),
1018        Value::Decimal(value) => format!("d:{value}"),
1019        Value::Text(value) => format!("t:{value}"),
1020        Value::Json(value) => format!("j:{value}"),
1021        Value::Date(value) => format!("d:{value}"),
1022        Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
1023        Value::Object(_) => "object".to_owned(),
1024        Value::List(_) => "list".to_owned(),
1025    }
1026}
1027
1028fn read_i64(value: Option<&Value>) -> Option<i64> {
1029    match value {
1030        Some(Value::I64(value)) => Some(*value),
1031        _ => None,
1032    }
1033}
1034
1035fn local_graph_identity_key(value: &Value) -> String {
1036    match value {
1037        Value::I64(val) if *val >= 0 => format!("u:{}", *val as u64),
1038        Value::U64(val) => format!("u:{val}"),
1039        Value::Null => "null".to_owned(),
1040        Value::Bool(v) => format!("b:{v}"),
1041        Value::I64(v) => format!("i:{v}"),
1042        Value::F64(v) => format!("f:{v}"),
1043        Value::Decimal(v) => format!("d:{v}"),
1044        Value::Text(v) => format!("t:{v}"),
1045        Value::Json(v) => format!("j:{v}"),
1046        Value::Date(v) => format!("d:{v}"),
1047        Value::Timestamp(v) => format!("ts:{}", v.to_rfc3339()),
1048        Value::Object(_) => "o".to_owned(),
1049        Value::List(_) => "l".to_owned(),
1050    }
1051}
1052