Skip to main content

teaql_runtime/
memory.rs

1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8    AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction, InsertCommand, Record,
9    RecoverCommand, SelectQuery, SmartList, SortDirection, UpdateCommand, Value,
10};
11
12use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
13
/// Errors raised by [`MemoryRepository`] while executing queries and commands in memory.
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the shared row store reported poisoning.
    Poisoned,
    /// An expression the in-memory evaluator cannot handle; carries a description.
    UnsupportedExpression(String),
    /// An aggregate that cannot be computed over the stored values; carries a description.
    UnsupportedAggregate(String),
}

impl std::fmt::Display for MemoryRepositoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedExpression(message) => {
                write!(f, "unsupported memory expression: {message}")
            }
            Self::UnsupportedAggregate(message) => {
                write!(f, "unsupported memory aggregate: {message}")
            }
            Self::Poisoned => f.write_str("memory repository lock poisoned"),
        }
    }
}

// No wrapped source error: every variant is self-describing via `Display`.
impl std::error::Error for MemoryRepositoryError {}
36
37#[derive(Debug, Clone)]
38pub struct MemoryRepository<M = InMemoryMetadataStore> {
39    metadata: M,
40    data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
41}
42
43impl<M> MemoryRepository<M>
44where
45    M: MetadataStore,
46{
47    pub fn new(metadata: M) -> Self {
48        Self {
49            metadata,
50            data: Arc::new(Mutex::new(BTreeMap::new())),
51        }
52    }
53
54    pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
55        self.seed(entity, rows);
56        self
57    }
58
59    pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
60        if let Ok(mut data) = self.data.lock() {
61            data.insert(entity.into(), rows);
62        }
63    }
64
65    pub fn fetch_all(
66        &self,
67        query: &SelectQuery,
68    ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
69        self.require_entity(&query.entity)?;
70        let data = self
71            .data
72            .lock()
73            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
74        let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
75        drop(data);
76
77        if let Some(filter) = &query.filter {
78            rows = rows
79                .into_iter()
80                .filter_map(|row| match eval_filter(filter, &row) {
81                    Ok(true) => Some(Ok(row)),
82                    Ok(false) => None,
83                    Err(err) => Some(Err(err)),
84                })
85                .collect::<Result<Vec<_>, _>>()
86                .map_err(RepositoryError::Executor)?;
87        }
88
89        if !query.aggregates.is_empty() {
90            return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
91        }
92
93        apply_ordering(&mut rows, query);
94        rows = apply_slice(rows, query);
95        if !query.projection.is_empty() || !query.expr_projection.is_empty() {
96            rows = rows
97                .into_iter()
98                .map(|row| project_row(row, query))
99                .collect::<Result<Vec<_>, _>>()
100                .map_err(RepositoryError::Executor)?;
101        }
102        Ok(rows)
103    }
104
105    pub fn fetch_smart_list(
106        &self,
107        query: &SelectQuery,
108    ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
109        self.fetch_all(query).map(SmartList::from)
110    }
111
112    pub fn fetch_entities<T>(
113        &self,
114        query: &SelectQuery,
115    ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
116    where
117        T: Entity,
118    {
119        self.fetch_all(query)?
120            .into_iter()
121            .map(T::from_record)
122            .collect::<Result<Vec<_>, _>>()
123            .map(SmartList::from)
124            .map_err(RepositoryError::Entity)
125    }
126
127    pub fn insert(
128        &self,
129        command: &InsertCommand,
130    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
131        self.require_entity(&command.entity)?;
132        let mut data = self
133            .data
134            .lock()
135            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
136        data.entry(command.entity.clone())
137            .or_default()
138            .push(command.values.clone());
139        Ok(1)
140    }
141
142    pub fn update(
143        &self,
144        command: &UpdateCommand,
145    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
146        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
147        let mut data = self
148            .data
149            .lock()
150            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
151        let rows = data.entry(command.entity.clone()).or_default();
152        let Some(row) = rows
153            .iter_mut()
154            .find(|row| row.get(id_property) == Some(&command.id))
155        else {
156            return self.maybe_optimistic_conflict(
157                command.expected_version,
158                &command.entity,
159                &command.id,
160            );
161        };
162
163        if let Some(expected_version) = command.expected_version {
164            if row.get(version_property) != Some(&Value::I64(expected_version)) {
165                return Err(RepositoryError::Runtime(
166                    RuntimeError::OptimisticLockConflict {
167                        entity: command.entity.clone(),
168                        id: format!("{:?}", command.id),
169                    },
170                ));
171            }
172            row.insert(
173                version_property.to_owned(),
174                Value::I64(expected_version + 1),
175            );
176        }
177
178        for (key, value) in &command.values {
179            row.insert(key.clone(), value.clone());
180        }
181        Ok(1)
182    }
183
184    pub fn delete(
185        &self,
186        command: &DeleteCommand,
187    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
188        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
189        let mut data = self
190            .data
191            .lock()
192            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
193        let rows = data.entry(command.entity.clone()).or_default();
194        let Some(index) = rows
195            .iter()
196            .position(|row| row.get(id_property) == Some(&command.id))
197        else {
198            return self.maybe_optimistic_conflict(
199                command.expected_version,
200                &command.entity,
201                &command.id,
202            );
203        };
204
205        if let Some(expected_version) = command.expected_version {
206            if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
207                return Err(RepositoryError::Runtime(
208                    RuntimeError::OptimisticLockConflict {
209                        entity: command.entity.clone(),
210                        id: format!("{:?}", command.id),
211                    },
212                ));
213            }
214        }
215
216        if command.soft_delete {
217            let next_version = command
218                .expected_version
219                .or_else(|| read_i64(rows[index].get(version_property)))
220                .map(|version| -(version.abs() + 1))
221                .unwrap_or(-1);
222            rows[index].insert(version_property.to_owned(), Value::I64(next_version));
223        } else {
224            rows.remove(index);
225        }
226        Ok(1)
227    }
228
229    pub fn recover(
230        &self,
231        command: &RecoverCommand,
232    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
233        let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
234        let mut data = self
235            .data
236            .lock()
237            .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
238        let rows = data.entry(command.entity.clone()).or_default();
239        let Some(row) = rows
240            .iter_mut()
241            .find(|row| row.get(id_property) == Some(&command.id))
242        else {
243            return Err(RepositoryError::Runtime(
244                RuntimeError::OptimisticLockConflict {
245                    entity: command.entity.clone(),
246                    id: format!("{:?}", command.id),
247                },
248            ));
249        };
250
251        if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
252            return Err(RepositoryError::Runtime(
253                RuntimeError::OptimisticLockConflict {
254                    entity: command.entity.clone(),
255                    id: format!("{:?}", command.id),
256                },
257            ));
258        }
259
260        row.insert(
261            version_property.to_owned(),
262            Value::I64(command.expected_version.abs() + 1),
263        );
264        Ok(1)
265    }
266
267    fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
268        self.metadata
269            .entity(entity)
270            .map(|_| ())
271            .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
272    }
273
274    fn id_and_version_properties(
275        &self,
276        entity: &str,
277    ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
278        let descriptor = self.metadata.entity(entity).ok_or_else(|| {
279            RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
280        })?;
281        let id = descriptor
282            .id_property()
283            .map(|property| property.name.as_str())
284            .unwrap_or("id");
285        let version = descriptor
286            .version_property()
287            .map(|property| property.name.as_str())
288            .unwrap_or("version");
289        Ok((id, version))
290    }
291
292    fn maybe_optimistic_conflict(
293        &self,
294        expected_version: Option<i64>,
295        entity: &str,
296        id: &Value,
297    ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
298        if expected_version.is_some() {
299            Err(RepositoryError::Runtime(
300                RuntimeError::OptimisticLockConflict {
301                    entity: entity.to_owned(),
302                    id: format!("{id:?}"),
303                },
304            ))
305        } else {
306            Ok(0)
307        }
308    }
309}
310
311fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
312    match expr {
313        Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
314            value_truthy(&eval_value(expr, row)?)
315        }
316        Expr::Binary { left, op, right } => {
317            let left = eval_value(left, row)?;
318            let right = eval_value(right, row)?;
319            eval_binary(&left, *op, &right)
320        }
321        Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
322            "subquery filters require a SQL executor".to_owned(),
323        )),
324        Expr::Between { expr, lower, upper } => {
325            let value = eval_value(expr, row)?;
326            let lower = eval_value(lower, row)?;
327            let upper = eval_value(upper, row)?;
328            Ok(compare_values(&value, &lower) != Some(Ordering::Less)
329                && compare_values(&value, &upper) != Some(Ordering::Greater))
330        }
331        Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
332        Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
333        Expr::And(parts) => {
334            for part in parts {
335                if !eval_filter(part, row)? {
336                    return Ok(false);
337                }
338            }
339            Ok(true)
340        }
341        Expr::Or(parts) => {
342            for part in parts {
343                if eval_filter(part, row)? {
344                    return Ok(true);
345                }
346            }
347            Ok(false)
348        }
349        Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
350    }
351}
352
353fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
354    match expr {
355        Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
356        Expr::Value(value) => Ok(value.clone()),
357        Expr::Function { function, args } => eval_function(*function, args, row),
358        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
359            "cannot evaluate {other:?} as a scalar value"
360        ))),
361    }
362}
363
364fn eval_function(
365    function: ExprFunction,
366    args: &[Expr],
367    row: &Record,
368) -> Result<Value, MemoryRepositoryError> {
369    match function {
370        ExprFunction::Soundex => {
371            let [arg] = args else {
372                return Err(MemoryRepositoryError::UnsupportedExpression(
373                    "SOUNDEX expects exactly one argument".to_owned(),
374                ));
375            };
376            match eval_value(arg, row)? {
377                Value::Text(value) => Ok(Value::Text(soundex(&value))),
378                Value::Null => Ok(Value::Null),
379                other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
380                    "SOUNDEX expects text, got {other:?}"
381                ))),
382            }
383        }
384        ExprFunction::Gbk => {
385            let [arg] = args else {
386                return Err(MemoryRepositoryError::UnsupportedExpression(
387                    "GBK expects exactly one argument".to_owned(),
388                ));
389            };
390            eval_value(arg, row)
391        }
392        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
393            "function {other:?} is only supported by SQL execution"
394        ))),
395    }
396}
397
398fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
399    match op {
400        BinaryOp::Eq => Ok(values_equal(left, right)),
401        BinaryOp::Ne => Ok(!values_equal(left, right)),
402        BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
403        BinaryOp::Gte => Ok(matches!(
404            compare_values(left, right),
405            Some(Ordering::Greater | Ordering::Equal)
406        )),
407        BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
408        BinaryOp::Lte => Ok(matches!(
409            compare_values(left, right),
410            Some(Ordering::Less | Ordering::Equal)
411        )),
412        BinaryOp::Like => match (left, right) {
413            (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
414            _ => Ok(false),
415        },
416        BinaryOp::NotLike => match (left, right) {
417            (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
418            _ => Ok(true),
419        },
420        BinaryOp::In | BinaryOp::InLarge => match right {
421            Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
422            _ => Err(MemoryRepositoryError::UnsupportedExpression(
423                "IN expects a list value".to_owned(),
424            )),
425        },
426        BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
427            Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
428            _ => Err(MemoryRepositoryError::UnsupportedExpression(
429                "NOT IN expects a list value".to_owned(),
430            )),
431        },
432    }
433}
434
435fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
436    match value {
437        Value::Bool(value) => Ok(*value),
438        Value::Null => Ok(false),
439        other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
440            "non-boolean expression result: {other:?}"
441        ))),
442    }
443}
444
445fn values_equal(left: &Value, right: &Value) -> bool {
446    match (left, right) {
447        (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
448        (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
449        _ => left == right,
450    }
451}
452
453fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
454    match (left, right) {
455        (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
456        (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
457        (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
458        (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
459        (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
460        (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
461        (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
462        (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
463        (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
464        _ => None,
465    }
466}
467
/// SQL `LIKE` matching of `value` against `pattern`, matched per Unicode
/// scalar value (char).
///
/// `%` matches any sequence (including empty) and `_` matches exactly one
/// character; both may appear anywhere in the pattern. A backslash makes the
/// following character literal, which is the default LIKE escape in MySQL and
/// PostgreSQL; a trailing lone backslash matches a literal backslash.
/// Matching is case-sensitive.
fn like_matches(value: &str, pattern: &str) -> bool {
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Token {
        AnySequence,
        AnyChar,
        Literal(char),
    }

    let mut tokens: Vec<Token> = Vec::with_capacity(pattern.len());
    let mut chars = pattern.chars();
    while let Some(ch) = chars.next() {
        let token = match ch {
            '%' => Token::AnySequence,
            '_' => Token::AnyChar,
            '\\' => Token::Literal(chars.next().unwrap_or('\\')),
            other => Token::Literal(other),
        };
        // Consecutive `%` are equivalent to one; collapsing them bounds backtracking.
        if token == Token::AnySequence && tokens.last() == Some(&Token::AnySequence) {
            continue;
        }
        tokens.push(token);
    }

    let text: Vec<char> = value.chars().collect();
    let (mut t, mut p) = (0_usize, 0_usize);
    // (token index just after the most recent `%`, text index that `%` currently absorbs up to)
    let mut resume: Option<(usize, usize)> = None;
    while t < text.len() {
        match tokens.get(p) {
            Some(Token::AnySequence) => {
                resume = Some((p + 1, t));
                p += 1;
            }
            Some(Token::AnyChar) => {
                p += 1;
                t += 1;
            }
            Some(Token::Literal(expected)) if *expected == text[t] => {
                p += 1;
                t += 1;
            }
            // Mismatch: let the last `%` swallow one more character, or fail.
            _ => match resume {
                Some((after_percent, absorbed)) => {
                    resume = Some((after_percent, absorbed + 1));
                    p = after_percent;
                    t = absorbed + 1;
                }
                None => return false,
            },
        }
    }
    // The text is exhausted; whatever remains of the pattern must be `%` only.
    tokens[p..].iter().all(|token| *token == Token::AnySequence)
}
479
480fn soundex(value: &str) -> String {
481    let mut letters = value
482        .chars()
483        .filter(|ch| ch.is_ascii_alphabetic())
484        .map(|ch| ch.to_ascii_uppercase());
485    let Some(first) = letters.next() else {
486        return "0000".to_owned();
487    };
488
489    let mut output = String::with_capacity(4);
490    output.push(first);
491    let mut previous_code = soundex_code(first);
492
493    for ch in letters {
494        let code = soundex_code(ch);
495        if code != '0' && code != previous_code {
496            output.push(code);
497            if output.len() == 4 {
498                return output;
499            }
500        }
501        previous_code = code;
502    }
503
504    while output.len() < 4 {
505        output.push('0');
506    }
507    output
508}
509
/// Soundex digit for an uppercase ASCII letter.
///
/// Any character outside the six consonant groups, including lowercase
/// letters and non-ASCII characters, maps to `'0'`.
fn soundex_code(ch: char) -> char {
    const GROUPS: [(&str, char); 6] = [
        ("BFPV", '1'),
        ("CGJKQSXZ", '2'),
        ("DT", '3'),
        ("L", '4'),
        ("MN", '5'),
        ("R", '6'),
    ];
    GROUPS
        .iter()
        .find(|(letters, _)| letters.contains(ch))
        .map_or('0', |&(_, digit)| digit)
}
521
522fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
523    for order in query.order_by.iter().rev() {
524        rows.sort_by(|left, right| {
525            let left_value = if let Some(expr) = &order.expr {
526                eval_value(expr, left).ok()
527            } else {
528                left.get(&order.field).cloned()
529            };
530            let right_value = if let Some(expr) = &order.expr {
531                eval_value(expr, right).ok()
532            } else {
533                right.get(&order.field).cloned()
534            };
535            let ordering = match (left_value.as_ref(), right_value.as_ref()) {
536                (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
537                (None, Some(_)) => Ordering::Less,
538                (Some(_), None) => Ordering::Greater,
539                (None, None) => Ordering::Equal,
540            };
541            match order.direction {
542                SortDirection::Asc => ordering,
543                SortDirection::Desc => ordering.reverse(),
544            }
545        });
546    }
547}
548
549fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
550    let Some(slice) = query.slice else {
551        return rows;
552    };
553    let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
554    let limit = slice
555        .limit
556        .and_then(|limit| usize::try_from(limit).ok())
557        .unwrap_or(usize::MAX);
558    rows.into_iter().skip(offset).take(limit).collect()
559}
560
561fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
562    let mut output: Record = query
563        .projection
564        .iter()
565        .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
566        .collect();
567    for projection in &query.expr_projection {
568        output.insert(
569            projection.alias.clone(),
570            eval_value(&projection.expr, &row)?,
571        );
572    }
573    Ok(output)
574}
575
576fn aggregate_rows(
577    query: &SelectQuery,
578    rows: &[Record],
579) -> Result<Vec<Record>, MemoryRepositoryError> {
580    let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
581    if query.group_by.is_empty() {
582        groups.insert(Vec::new(), rows.iter().collect());
583    } else {
584        for row in rows {
585            let key = query
586                .group_by
587                .iter()
588                .map(|field| row.get(field).map(value_key).unwrap_or_default())
589                .collect::<Vec<_>>();
590            groups.entry(key).or_default().push(row);
591        }
592    }
593
594    let rows = groups
595        .into_values()
596        .map(|rows| {
597            let mut output = Record::new();
598            if let Some(first) = rows.first() {
599                for field in &query.group_by {
600                    if let Some(value) = first.get(field) {
601                        output.insert(field.clone(), value.clone());
602                    }
603                }
604            }
605            for aggregate in &query.aggregates {
606                let value = match aggregate.function {
607                    AggregateFunction::Count => {
608                        if aggregate.field == "*" {
609                            Value::U64(rows.len() as u64)
610                        } else {
611                            Value::U64(
612                                rows.iter()
613                                    .filter(|row| {
614                                        !matches!(
615                                            row.get(&aggregate.field),
616                                            None | Some(Value::Null)
617                                        )
618                                    })
619                                    .count() as u64,
620                            )
621                        }
622                    }
623                    AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
624                    AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
625                    AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
626                    AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
627                    AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
628                    AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
629                    AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
630                    AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
631                    AggregateFunction::BitAnd => {
632                        bit_aggregate(&rows, &aggregate.field, BitOp::And)?
633                    }
634                    AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
635                    AggregateFunction::BitXor => {
636                        bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
637                    }
638                };
639                output.insert(aggregate.alias.clone(), value);
640            }
641            for projection in &query.expr_projection {
642                output.insert(
643                    projection.alias.clone(),
644                    eval_value(&projection.expr, &output)?,
645                );
646            }
647            Ok(output)
648        })
649        .collect::<Result<Vec<_>, _>>()?;
650    if let Some(having) = &query.having {
651        rows.into_iter()
652            .filter_map(|row| match eval_filter(having, &row) {
653                Ok(true) => Some(Ok(row)),
654                Ok(false) => None,
655                Err(err) => Some(Err(err)),
656            })
657            .collect()
658    } else {
659        Ok(rows)
660    }
661}
662
663fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
664    let mut decimal_sum = Decimal::ZERO;
665    let mut integer_sum: i128 = 0;
666    let mut saw_decimal = false;
667    for value in rows.iter().filter_map(|row| row.get(field)) {
668        match value {
669            Value::I64(value) => {
670                integer_sum += i128::from(*value);
671                decimal_sum += Decimal::from(*value);
672            }
673            Value::U64(value) => {
674                integer_sum += i128::from(*value);
675                decimal_sum += Decimal::from(*value);
676            }
677            Value::F64(value) => {
678                saw_decimal = true;
679                decimal_sum += decimal_from_f64(*value);
680            }
681            Value::Decimal(value) => {
682                saw_decimal = true;
683                decimal_sum += *value;
684            }
685            Value::Null => {}
686            other => {
687                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
688                    "SUM does not support {other:?}"
689                )));
690            }
691        }
692    }
693    if saw_decimal {
694        Ok(Value::Decimal(decimal_sum))
695    } else if integer_sum >= 0 {
696        Ok(Value::U64(integer_sum as u64))
697    } else {
698        Ok(Value::I64(integer_sum as i64))
699    }
700}
701
702fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
703    let mut sum = Decimal::ZERO;
704    let mut count: u64 = 0;
705    for value in rows.iter().filter_map(|row| row.get(field)) {
706        match value {
707            Value::I64(value) => {
708                sum += Decimal::from(*value);
709                count += 1;
710            }
711            Value::U64(value) => {
712                sum += Decimal::from(*value);
713                count += 1;
714            }
715            Value::F64(value) => {
716                sum += decimal_from_f64(*value);
717                count += 1;
718            }
719            Value::Decimal(value) => {
720                sum += *value;
721                count += 1;
722            }
723            Value::Null => {}
724            other => {
725                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
726                    "AVG does not support {other:?}"
727                )));
728            }
729        }
730    }
731    Ok(if count == 0 {
732        Value::Null
733    } else {
734        Value::Decimal(sum / Decimal::from(count))
735    })
736}
737
738fn decimal_from_f64(value: f64) -> Decimal {
739    Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
740}
741
742fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
743    rows.iter()
744        .filter_map(|row| row.get(field))
745        .filter(|value| !matches!(value, Value::Null))
746        .map(|value| match value {
747            Value::I64(value) => Ok(*value as f64),
748            Value::U64(value) => Ok(*value as f64),
749            Value::F64(value) => Ok(*value),
750            Value::Decimal(value) => value.to_f64().ok_or_else(|| {
751                MemoryRepositoryError::UnsupportedAggregate(format!(
752                    "cannot convert decimal {value} to f64 for statistical aggregate"
753                ))
754            }),
755            other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
756                "numeric aggregate does not support {other:?}"
757            ))),
758        })
759        .collect()
760}
761
762fn numeric_variance(
763    rows: &[&Record],
764    field: &str,
765    sample: bool,
766) -> Result<Value, MemoryRepositoryError> {
767    let values = numeric_values(rows, field)?;
768    let count = values.len();
769    if count == 0 || (sample && count < 2) {
770        return Ok(Value::Null);
771    }
772    let mean = values.iter().sum::<f64>() / count as f64;
773    let sum = values
774        .iter()
775        .map(|value| {
776            let diff = value - mean;
777            diff * diff
778        })
779        .sum::<f64>();
780    let denominator = if sample { count - 1 } else { count } as f64;
781    Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
782}
783
784fn numeric_stddev(
785    rows: &[&Record],
786    field: &str,
787    sample: bool,
788) -> Result<Value, MemoryRepositoryError> {
789    Ok(match numeric_variance(rows, field, sample)? {
790        Value::Decimal(value) => {
791            Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
792        }
793        Value::Null => Value::Null,
794        other => other,
795    })
796}
797
/// Bitwise combiner applied by `bit_aggregate` (BIT_AND / BIT_OR / BIT_XOR).
#[derive(Clone, Copy, Debug)]
enum BitOp {
    /// Bitwise AND of every non-null value.
    And,
    /// Bitwise OR of every non-null value.
    Or,
    /// Bitwise XOR of every non-null value.
    Xor,
}
804
805fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
806    let mut selected: Option<i64> = None;
807    for value in rows.iter().filter_map(|row| row.get(field)) {
808        let value = match value {
809            Value::I64(value) => *value,
810            Value::U64(value) => i64::try_from(*value).map_err(|_| {
811                MemoryRepositoryError::UnsupportedAggregate(format!(
812                    "BIT aggregate u64 {value} exceeds i64 range"
813                ))
814            })?,
815            Value::Null => continue,
816            other => {
817                return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
818                    "BIT aggregate does not support {other:?}"
819                )));
820            }
821        };
822        selected = Some(match (selected, op) {
823            (None, _) => value,
824            (Some(current), BitOp::And) => current & value,
825            (Some(current), BitOp::Or) => current | value,
826            (Some(current), BitOp::Xor) => current ^ value,
827        });
828    }
829    Ok(selected.map(Value::I64).unwrap_or(Value::Null))
830}
831
832fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
833    let mut selected: Option<Value> = None;
834    for value in rows.iter().filter_map(|row| row.get(field)) {
835        if matches!(value, Value::Null) {
836            continue;
837        }
838        match &selected {
839            None => selected = Some(value.clone()),
840            Some(current) => {
841                let Some(ordering) = compare_values(value, current) else {
842                    return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
843                        "MIN/MAX does not support {value:?}"
844                    )));
845                };
846                if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
847                    selected = Some(value.clone());
848                }
849            }
850        }
851    }
852    Ok(selected.unwrap_or(Value::Null))
853}
854
855fn value_key(value: &Value) -> String {
856    match value {
857        Value::Null => "null".to_owned(),
858        Value::Bool(value) => format!("b:{value}"),
859        Value::I64(value) => format!("i:{value}"),
860        Value::U64(value) => format!("u:{value}"),
861        Value::F64(value) => format!("f:{value}"),
862        Value::Decimal(value) => format!("d:{value}"),
863        Value::Text(value) => format!("t:{value}"),
864        Value::Json(value) => format!("j:{value}"),
865        Value::Date(value) => format!("d:{value}"),
866        Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
867        Value::Object(_) => "object".to_owned(),
868        Value::List(_) => "list".to_owned(),
869    }
870}
871
872fn read_i64(value: Option<&Value>) -> Option<i64> {
873    match value {
874        Some(Value::I64(value)) => Some(*value),
875        _ => None,
876    }
877}