1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8 AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction, InsertCommand, Record,
9 RecoverCommand, SelectQuery, SmartList, SortDirection, UpdateCommand, Value,
10};
11
12use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
13
/// Errors produced by the in-memory executor itself (as opposed to metadata or entity errors).
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the in-memory tables reported poisoning when locked.
    Poisoned,
    /// An expression the in-memory evaluator cannot handle; carries a human-readable detail.
    UnsupportedExpression(String),
    /// An aggregate the in-memory evaluator cannot compute; carries a human-readable detail.
    UnsupportedAggregate(String),
}

impl std::fmt::Display for MemoryRepositoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two "unsupported" variants share a `<category>: <detail>` layout.
        let (category, detail) = match self {
            Self::Poisoned => return f.write_str("memory repository lock poisoned"),
            Self::UnsupportedExpression(detail) => ("unsupported memory expression", detail),
            Self::UnsupportedAggregate(detail) => ("unsupported memory aggregate", detail),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for MemoryRepositoryError {}
36
/// In-memory repository that evaluates [`SelectQuery`]s and mutation commands against rows kept
/// in a map keyed by entity name.
///
/// The map lives behind an `Arc<Mutex<_>>`, so clones of a repository share the same rows.
#[derive(Debug, Clone)]
pub struct MemoryRepository<M = InMemoryMetadataStore> {
    /// Metadata used to validate entity names and resolve id/version property names.
    metadata: M,
    /// Stored rows, keyed by entity name.
    data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
}
42
43impl<M> MemoryRepository<M>
44where
45 M: MetadataStore,
46{
47 pub fn new(metadata: M) -> Self {
48 Self {
49 metadata,
50 data: Arc::new(Mutex::new(BTreeMap::new())),
51 }
52 }
53
54 pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
55 self.seed(entity, rows);
56 self
57 }
58
59 pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
60 if let Ok(mut data) = self.data.lock() {
61 data.insert(entity.into(), rows);
62 }
63 }
64
65 pub fn fetch_all(
66 &self,
67 query: &SelectQuery,
68 ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
69 self.require_entity(&query.entity)?;
70 let data = self
71 .data
72 .lock()
73 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
74 let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
75 drop(data);
76
77 if let Some(filter) = &query.filter {
78 rows = rows
79 .into_iter()
80 .filter_map(|row| match eval_filter(filter, &row) {
81 Ok(true) => Some(Ok(row)),
82 Ok(false) => None,
83 Err(err) => Some(Err(err)),
84 })
85 .collect::<Result<Vec<_>, _>>()
86 .map_err(RepositoryError::Executor)?;
87 }
88
89 if !query.aggregates.is_empty() {
90 return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
91 }
92
93 apply_ordering(&mut rows, query);
94 rows = apply_slice(rows, query);
95 if !query.projection.is_empty() || !query.expr_projection.is_empty() {
96 rows = rows
97 .into_iter()
98 .map(|row| project_row(row, query))
99 .collect::<Result<Vec<_>, _>>()
100 .map_err(RepositoryError::Executor)?;
101 }
102 Ok(rows)
103 }
104
105 pub fn fetch_smart_list(
106 &self,
107 query: &SelectQuery,
108 ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
109 self.fetch_all(query).map(SmartList::from)
110 }
111
112 pub fn fetch_entities<T>(
113 &self,
114 query: &SelectQuery,
115 ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
116 where
117 T: Entity,
118 {
119 self.fetch_all(query)?
120 .into_iter()
121 .map(T::from_record)
122 .collect::<Result<Vec<_>, _>>()
123 .map(SmartList::from)
124 .map_err(RepositoryError::Entity)
125 }
126
127 pub fn insert(
128 &self,
129 command: &InsertCommand,
130 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
131 self.require_entity(&command.entity)?;
132 let mut data = self
133 .data
134 .lock()
135 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
136 data.entry(command.entity.clone())
137 .or_default()
138 .push(command.values.clone());
139 Ok(1)
140 }
141
142 pub fn update(
143 &self,
144 command: &UpdateCommand,
145 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
146 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
147 let mut data = self
148 .data
149 .lock()
150 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
151 let rows = data.entry(command.entity.clone()).or_default();
152 let Some(row) = rows
153 .iter_mut()
154 .find(|row| row.get(id_property) == Some(&command.id))
155 else {
156 return self.maybe_optimistic_conflict(
157 command.expected_version,
158 &command.entity,
159 &command.id,
160 );
161 };
162
163 if let Some(expected_version) = command.expected_version {
164 if row.get(version_property) != Some(&Value::I64(expected_version)) {
165 return Err(RepositoryError::Runtime(
166 RuntimeError::OptimisticLockConflict {
167 entity: command.entity.clone(),
168 id: format!("{:?}", command.id),
169 },
170 ));
171 }
172 row.insert(
173 version_property.to_owned(),
174 Value::I64(expected_version + 1),
175 );
176 }
177
178 for (key, value) in &command.values {
179 row.insert(key.clone(), value.clone());
180 }
181 Ok(1)
182 }
183
184 pub fn delete(
185 &self,
186 command: &DeleteCommand,
187 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
188 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
189 let mut data = self
190 .data
191 .lock()
192 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
193 let rows = data.entry(command.entity.clone()).or_default();
194 let Some(index) = rows
195 .iter()
196 .position(|row| row.get(id_property) == Some(&command.id))
197 else {
198 return self.maybe_optimistic_conflict(
199 command.expected_version,
200 &command.entity,
201 &command.id,
202 );
203 };
204
205 if let Some(expected_version) = command.expected_version {
206 if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
207 return Err(RepositoryError::Runtime(
208 RuntimeError::OptimisticLockConflict {
209 entity: command.entity.clone(),
210 id: format!("{:?}", command.id),
211 },
212 ));
213 }
214 }
215
216 if command.soft_delete {
217 let next_version = command
218 .expected_version
219 .or_else(|| read_i64(rows[index].get(version_property)))
220 .map(|version| -(version.abs() + 1))
221 .unwrap_or(-1);
222 rows[index].insert(version_property.to_owned(), Value::I64(next_version));
223 } else {
224 rows.remove(index);
225 }
226 Ok(1)
227 }
228
229 pub fn recover(
230 &self,
231 command: &RecoverCommand,
232 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
233 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
234 let mut data = self
235 .data
236 .lock()
237 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
238 let rows = data.entry(command.entity.clone()).or_default();
239 let Some(row) = rows
240 .iter_mut()
241 .find(|row| row.get(id_property) == Some(&command.id))
242 else {
243 return Err(RepositoryError::Runtime(
244 RuntimeError::OptimisticLockConflict {
245 entity: command.entity.clone(),
246 id: format!("{:?}", command.id),
247 },
248 ));
249 };
250
251 if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
252 return Err(RepositoryError::Runtime(
253 RuntimeError::OptimisticLockConflict {
254 entity: command.entity.clone(),
255 id: format!("{:?}", command.id),
256 },
257 ));
258 }
259
260 row.insert(
261 version_property.to_owned(),
262 Value::I64(command.expected_version.abs() + 1),
263 );
264 Ok(1)
265 }
266
267 fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
268 self.metadata
269 .entity(entity)
270 .map(|_| ())
271 .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
272 }
273
274 fn id_and_version_properties(
275 &self,
276 entity: &str,
277 ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
278 let descriptor = self.metadata.entity(entity).ok_or_else(|| {
279 RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
280 })?;
281 let id = descriptor
282 .id_property()
283 .map(|property| property.name.as_str())
284 .unwrap_or("id");
285 let version = descriptor
286 .version_property()
287 .map(|property| property.name.as_str())
288 .unwrap_or("version");
289 Ok((id, version))
290 }
291
292 fn maybe_optimistic_conflict(
293 &self,
294 expected_version: Option<i64>,
295 entity: &str,
296 id: &Value,
297 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
298 if expected_version.is_some() {
299 Err(RepositoryError::Runtime(
300 RuntimeError::OptimisticLockConflict {
301 entity: entity.to_owned(),
302 id: format!("{id:?}"),
303 },
304 ))
305 } else {
306 Ok(0)
307 }
308 }
309}
310
311fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
312 match expr {
313 Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
314 value_truthy(&eval_value(expr, row)?)
315 }
316 Expr::Binary { left, op, right } => {
317 let left = eval_value(left, row)?;
318 let right = eval_value(right, row)?;
319 eval_binary(&left, *op, &right)
320 }
321 Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
322 "subquery filters require a SQL executor".to_owned(),
323 )),
324 Expr::Between { expr, lower, upper } => {
325 let value = eval_value(expr, row)?;
326 let lower = eval_value(lower, row)?;
327 let upper = eval_value(upper, row)?;
328 Ok(compare_values(&value, &lower) != Some(Ordering::Less)
329 && compare_values(&value, &upper) != Some(Ordering::Greater))
330 }
331 Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
332 Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
333 Expr::And(parts) => {
334 for part in parts {
335 if !eval_filter(part, row)? {
336 return Ok(false);
337 }
338 }
339 Ok(true)
340 }
341 Expr::Or(parts) => {
342 for part in parts {
343 if eval_filter(part, row)? {
344 return Ok(true);
345 }
346 }
347 Ok(false)
348 }
349 Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
350 }
351}
352
353fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
354 match expr {
355 Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
356 Expr::Value(value) => Ok(value.clone()),
357 Expr::Function { function, args } => eval_function(*function, args, row),
358 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
359 "cannot evaluate {other:?} as a scalar value"
360 ))),
361 }
362}
363
364fn eval_function(
365 function: ExprFunction,
366 args: &[Expr],
367 row: &Record,
368) -> Result<Value, MemoryRepositoryError> {
369 match function {
370 ExprFunction::Soundex => {
371 let [arg] = args else {
372 return Err(MemoryRepositoryError::UnsupportedExpression(
373 "SOUNDEX expects exactly one argument".to_owned(),
374 ));
375 };
376 match eval_value(arg, row)? {
377 Value::Text(value) => Ok(Value::Text(soundex(&value))),
378 Value::Null => Ok(Value::Null),
379 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
380 "SOUNDEX expects text, got {other:?}"
381 ))),
382 }
383 }
384 ExprFunction::Gbk => {
385 let [arg] = args else {
386 return Err(MemoryRepositoryError::UnsupportedExpression(
387 "GBK expects exactly one argument".to_owned(),
388 ));
389 };
390 eval_value(arg, row)
391 }
392 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
393 "function {other:?} is only supported by SQL execution"
394 ))),
395 }
396}
397
398fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
399 match op {
400 BinaryOp::Eq => Ok(values_equal(left, right)),
401 BinaryOp::Ne => Ok(!values_equal(left, right)),
402 BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
403 BinaryOp::Gte => Ok(matches!(
404 compare_values(left, right),
405 Some(Ordering::Greater | Ordering::Equal)
406 )),
407 BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
408 BinaryOp::Lte => Ok(matches!(
409 compare_values(left, right),
410 Some(Ordering::Less | Ordering::Equal)
411 )),
412 BinaryOp::Like => match (left, right) {
413 (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
414 _ => Ok(false),
415 },
416 BinaryOp::NotLike => match (left, right) {
417 (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
418 _ => Ok(true),
419 },
420 BinaryOp::In | BinaryOp::InLarge => match right {
421 Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
422 _ => Err(MemoryRepositoryError::UnsupportedExpression(
423 "IN expects a list value".to_owned(),
424 )),
425 },
426 BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
427 Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
428 _ => Err(MemoryRepositoryError::UnsupportedExpression(
429 "NOT IN expects a list value".to_owned(),
430 )),
431 },
432 }
433}
434
435fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
436 match value {
437 Value::Bool(value) => Ok(*value),
438 Value::Null => Ok(false),
439 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
440 "non-boolean expression result: {other:?}"
441 ))),
442 }
443}
444
445fn values_equal(left: &Value, right: &Value) -> bool {
446 match (left, right) {
447 (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
448 (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
449 _ => left == right,
450 }
451}
452
453fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
454 match (left, right) {
455 (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
456 (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
457 (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
458 (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
459 (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
460 (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
461 (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
462 (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
463 (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
464 _ => None,
465 }
466}
467
/// SQL `LIKE` matching of `value` against `pattern`.
///
/// `%` matches any run of characters (including none), and `_` matches exactly one character.
/// Wildcards may appear anywhere in the pattern. Every other character must match literally and
/// case-sensitively. No escape character is recognised.
///
/// Works on `char`s, so `_` consumes one Unicode scalar value. It uses the greedy wildcard
/// algorithm with a single backtrack point (worst case O(|value| * |pattern|)).
fn like_matches(value: &str, pattern: &str) -> bool {
    let value: Vec<char> = value.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();
    let (mut v, mut p) = (0usize, 0usize);
    // (pattern index just after the most recent `%`, value index that `%` has consumed up to).
    let mut backtrack: Option<(usize, usize)> = None;

    while v < value.len() {
        match pattern.get(p) {
            Some(&'%') => {
                p += 1;
                backtrack = Some((p, v));
            }
            Some(&c) if c == '_' || c == value[v] => {
                p += 1;
                v += 1;
            }
            _ => match backtrack {
                // Let the last `%` swallow one more character and retry after it.
                Some((after_wildcard, consumed)) => {
                    p = after_wildcard;
                    v = consumed + 1;
                    backtrack = Some((after_wildcard, consumed + 1));
                }
                None => return false,
            },
        }
    }
    // The value is exhausted; any remaining pattern must consist solely of `%`.
    pattern[p..].iter().all(|&c| c == '%')
}
479
/// Computes the four-character Soundex code of `value`.
///
/// Characters that are not ASCII letters are dropped, and letters are upper-cased. The first
/// letter is kept verbatim. Each later letter contributes its digit unless that digit is `0`
/// (vowels, `H`, `W`, `Y`) or equals the digit of the letter immediately before it. The result is
/// cut or right-padded with `0` to four characters. Input without letters yields `"0000"`.
fn soundex(value: &str) -> String {
    let mut letters = value
        .chars()
        .filter(char::is_ascii_alphabetic)
        .map(|letter| letter.to_ascii_uppercase());
    let Some(initial) = letters.next() else {
        return "0000".to_owned();
    };

    // Track the previous letter's digit and emit only non-zero digits that differ from it.
    let digits = letters
        .scan(soundex_code(initial), |previous, letter| {
            let digit = soundex_code(letter);
            let emit = digit != '0' && digit != *previous;
            *previous = digit;
            Some(emit.then_some(digit))
        })
        .flatten()
        .take(3);

    let code: String = std::iter::once(initial).chain(digits).collect();
    // Left-align and fill with '0' up to four characters (all characters here are ASCII).
    format!("{code:0<4}")
}

/// Maps an upper-case ASCII letter to its Soundex digit; any other character maps to `'0'`.
fn soundex_code(ch: char) -> char {
    const GROUPS: [(&str, char); 6] = [
        ("BFPV", '1'),
        ("CGJKQSXZ", '2'),
        ("DT", '3'),
        ("L", '4'),
        ("MN", '5'),
        ("R", '6'),
    ];
    GROUPS
        .iter()
        .find(|(letters, _)| letters.contains(ch))
        .map_or('0', |&(_, digit)| digit)
}
521
522fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
523 for order in query.order_by.iter().rev() {
524 rows.sort_by(|left, right| {
525 let left_value = if let Some(expr) = &order.expr {
526 eval_value(expr, left).ok()
527 } else {
528 left.get(&order.field).cloned()
529 };
530 let right_value = if let Some(expr) = &order.expr {
531 eval_value(expr, right).ok()
532 } else {
533 right.get(&order.field).cloned()
534 };
535 let ordering = match (left_value.as_ref(), right_value.as_ref()) {
536 (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
537 (None, Some(_)) => Ordering::Less,
538 (Some(_), None) => Ordering::Greater,
539 (None, None) => Ordering::Equal,
540 };
541 match order.direction {
542 SortDirection::Asc => ordering,
543 SortDirection::Desc => ordering.reverse(),
544 }
545 });
546 }
547}
548
549fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
550 let Some(slice) = query.slice else {
551 return rows;
552 };
553 let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
554 let limit = slice
555 .limit
556 .and_then(|limit| usize::try_from(limit).ok())
557 .unwrap_or(usize::MAX);
558 rows.into_iter().skip(offset).take(limit).collect()
559}
560
561fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
562 let mut output: Record = query
563 .projection
564 .iter()
565 .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
566 .collect();
567 for projection in &query.expr_projection {
568 output.insert(
569 projection.alias.clone(),
570 eval_value(&projection.expr, &row)?,
571 );
572 }
573 Ok(output)
574}
575
576fn aggregate_rows(
577 query: &SelectQuery,
578 rows: &[Record],
579) -> Result<Vec<Record>, MemoryRepositoryError> {
580 let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
581 if query.group_by.is_empty() {
582 groups.insert(Vec::new(), rows.iter().collect());
583 } else {
584 for row in rows {
585 let key = query
586 .group_by
587 .iter()
588 .map(|field| row.get(field).map(value_key).unwrap_or_default())
589 .collect::<Vec<_>>();
590 groups.entry(key).or_default().push(row);
591 }
592 }
593
594 let rows = groups
595 .into_values()
596 .map(|rows| {
597 let mut output = Record::new();
598 if let Some(first) = rows.first() {
599 for field in &query.group_by {
600 if let Some(value) = first.get(field) {
601 output.insert(field.clone(), value.clone());
602 }
603 }
604 }
605 for aggregate in &query.aggregates {
606 let value = match aggregate.function {
607 AggregateFunction::Count => {
608 if aggregate.field == "*" {
609 Value::U64(rows.len() as u64)
610 } else {
611 Value::U64(
612 rows.iter()
613 .filter(|row| {
614 !matches!(
615 row.get(&aggregate.field),
616 None | Some(Value::Null)
617 )
618 })
619 .count() as u64,
620 )
621 }
622 }
623 AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
624 AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
625 AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
626 AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
627 AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
628 AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
629 AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
630 AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
631 AggregateFunction::BitAnd => {
632 bit_aggregate(&rows, &aggregate.field, BitOp::And)?
633 }
634 AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
635 AggregateFunction::BitXor => {
636 bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
637 }
638 };
639 output.insert(aggregate.alias.clone(), value);
640 }
641 for projection in &query.expr_projection {
642 output.insert(
643 projection.alias.clone(),
644 eval_value(&projection.expr, &output)?,
645 );
646 }
647 Ok(output)
648 })
649 .collect::<Result<Vec<_>, _>>()?;
650 if let Some(having) = &query.having {
651 rows.into_iter()
652 .filter_map(|row| match eval_filter(having, &row) {
653 Ok(true) => Some(Ok(row)),
654 Ok(false) => None,
655 Err(err) => Some(Err(err)),
656 })
657 .collect()
658 } else {
659 Ok(rows)
660 }
661}
662
663fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
664 let mut decimal_sum = Decimal::ZERO;
665 let mut integer_sum: i128 = 0;
666 let mut saw_decimal = false;
667 for value in rows.iter().filter_map(|row| row.get(field)) {
668 match value {
669 Value::I64(value) => {
670 integer_sum += i128::from(*value);
671 decimal_sum += Decimal::from(*value);
672 }
673 Value::U64(value) => {
674 integer_sum += i128::from(*value);
675 decimal_sum += Decimal::from(*value);
676 }
677 Value::F64(value) => {
678 saw_decimal = true;
679 decimal_sum += decimal_from_f64(*value);
680 }
681 Value::Decimal(value) => {
682 saw_decimal = true;
683 decimal_sum += *value;
684 }
685 Value::Null => {}
686 other => {
687 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
688 "SUM does not support {other:?}"
689 )));
690 }
691 }
692 }
693 if saw_decimal {
694 Ok(Value::Decimal(decimal_sum))
695 } else if integer_sum >= 0 {
696 Ok(Value::U64(integer_sum as u64))
697 } else {
698 Ok(Value::I64(integer_sum as i64))
699 }
700}
701
702fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
703 let mut sum = Decimal::ZERO;
704 let mut count: u64 = 0;
705 for value in rows.iter().filter_map(|row| row.get(field)) {
706 match value {
707 Value::I64(value) => {
708 sum += Decimal::from(*value);
709 count += 1;
710 }
711 Value::U64(value) => {
712 sum += Decimal::from(*value);
713 count += 1;
714 }
715 Value::F64(value) => {
716 sum += decimal_from_f64(*value);
717 count += 1;
718 }
719 Value::Decimal(value) => {
720 sum += *value;
721 count += 1;
722 }
723 Value::Null => {}
724 other => {
725 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
726 "AVG does not support {other:?}"
727 )));
728 }
729 }
730 }
731 Ok(if count == 0 {
732 Value::Null
733 } else {
734 Value::Decimal(sum / Decimal::from(count))
735 })
736}
737
738fn decimal_from_f64(value: f64) -> Decimal {
739 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
740}
741
742fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
743 rows.iter()
744 .filter_map(|row| row.get(field))
745 .filter(|value| !matches!(value, Value::Null))
746 .map(|value| match value {
747 Value::I64(value) => Ok(*value as f64),
748 Value::U64(value) => Ok(*value as f64),
749 Value::F64(value) => Ok(*value),
750 Value::Decimal(value) => value.to_f64().ok_or_else(|| {
751 MemoryRepositoryError::UnsupportedAggregate(format!(
752 "cannot convert decimal {value} to f64 for statistical aggregate"
753 ))
754 }),
755 other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
756 "numeric aggregate does not support {other:?}"
757 ))),
758 })
759 .collect()
760}
761
762fn numeric_variance(
763 rows: &[&Record],
764 field: &str,
765 sample: bool,
766) -> Result<Value, MemoryRepositoryError> {
767 let values = numeric_values(rows, field)?;
768 let count = values.len();
769 if count == 0 || (sample && count < 2) {
770 return Ok(Value::Null);
771 }
772 let mean = values.iter().sum::<f64>() / count as f64;
773 let sum = values
774 .iter()
775 .map(|value| {
776 let diff = value - mean;
777 diff * diff
778 })
779 .sum::<f64>();
780 let denominator = if sample { count - 1 } else { count } as f64;
781 Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
782}
783
784fn numeric_stddev(
785 rows: &[&Record],
786 field: &str,
787 sample: bool,
788) -> Result<Value, MemoryRepositoryError> {
789 Ok(match numeric_variance(rows, field, sample)? {
790 Value::Decimal(value) => {
791 Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
792 }
793 Value::Null => Value::Null,
794 other => other,
795 })
796}
797
798#[derive(Debug, Clone, Copy)]
799enum BitOp {
800 And,
801 Or,
802 Xor,
803}
804
805fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
806 let mut selected: Option<i64> = None;
807 for value in rows.iter().filter_map(|row| row.get(field)) {
808 let value = match value {
809 Value::I64(value) => *value,
810 Value::U64(value) => i64::try_from(*value).map_err(|_| {
811 MemoryRepositoryError::UnsupportedAggregate(format!(
812 "BIT aggregate u64 {value} exceeds i64 range"
813 ))
814 })?,
815 Value::Null => continue,
816 other => {
817 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
818 "BIT aggregate does not support {other:?}"
819 )));
820 }
821 };
822 selected = Some(match (selected, op) {
823 (None, _) => value,
824 (Some(current), BitOp::And) => current & value,
825 (Some(current), BitOp::Or) => current | value,
826 (Some(current), BitOp::Xor) => current ^ value,
827 });
828 }
829 Ok(selected.map(Value::I64).unwrap_or(Value::Null))
830}
831
832fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
833 let mut selected: Option<Value> = None;
834 for value in rows.iter().filter_map(|row| row.get(field)) {
835 if matches!(value, Value::Null) {
836 continue;
837 }
838 match &selected {
839 None => selected = Some(value.clone()),
840 Some(current) => {
841 let Some(ordering) = compare_values(value, current) else {
842 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
843 "MIN/MAX does not support {value:?}"
844 )));
845 };
846 if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
847 selected = Some(value.clone());
848 }
849 }
850 }
851 }
852 Ok(selected.unwrap_or(Value::Null))
853}
854
855fn value_key(value: &Value) -> String {
856 match value {
857 Value::Null => "null".to_owned(),
858 Value::Bool(value) => format!("b:{value}"),
859 Value::I64(value) => format!("i:{value}"),
860 Value::U64(value) => format!("u:{value}"),
861 Value::F64(value) => format!("f:{value}"),
862 Value::Decimal(value) => format!("d:{value}"),
863 Value::Text(value) => format!("t:{value}"),
864 Value::Json(value) => format!("j:{value}"),
865 Value::Date(value) => format!("d:{value}"),
866 Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
867 Value::Object(_) => "object".to_owned(),
868 Value::List(_) => "list".to_owned(),
869 }
870}
871
872fn read_i64(value: Option<&Value>) -> Option<i64> {
873 match value {
874 Some(Value::I64(value)) => Some(*value),
875 _ => None,
876 }
877}