1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::sync::{Arc, Mutex};
4
5use rust_decimal::Decimal;
6use rust_decimal::prelude::ToPrimitive;
7use teaql_core::{
8 Aggregate, AggregateFunction, BinaryOp, DeleteCommand, Entity, Expr, ExprFunction,
9 InsertCommand, Record, RecoverCommand, RelationAggregate, SelectQuery, SmartList, SortDirection,
10 UpdateCommand, Value,
11};
12
13use crate::{InMemoryMetadataStore, MetadataStore, RepositoryError, RuntimeError};
14
/// Failures produced by the in-memory repository backend itself.
#[derive(Debug)]
pub enum MemoryRepositoryError {
    /// The mutex guarding the stored rows was poisoned.
    Poisoned,
    /// An expression cannot be evaluated in memory; the payload describes why.
    UnsupportedExpression(String),
    /// An aggregate cannot be computed in memory; the payload describes why.
    UnsupportedAggregate(String),
}
21
22impl std::fmt::Display for MemoryRepositoryError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Poisoned => write!(f, "memory repository lock poisoned"),
26 Self::UnsupportedExpression(message) => {
27 write!(f, "unsupported memory expression: {message}")
28 }
29 Self::UnsupportedAggregate(message) => {
30 write!(f, "unsupported memory aggregate: {message}")
31 }
32 }
33 }
34}
35
36impl std::error::Error for MemoryRepositoryError {}
37
38#[derive(Debug, Clone)]
39pub struct MemoryRepository<M = InMemoryMetadataStore> {
40 metadata: M,
41 data: Arc<Mutex<BTreeMap<String, Vec<Record>>>>,
42}
43
44impl<M> MemoryRepository<M>
45where
46 M: MetadataStore,
47{
48 pub fn new(metadata: M) -> Self {
49 Self {
50 metadata,
51 data: Arc::new(Mutex::new(BTreeMap::new())),
52 }
53 }
54
55 pub fn with_rows(mut self, entity: impl Into<String>, rows: Vec<Record>) -> Self {
56 self.seed(entity, rows);
57 self
58 }
59
60 pub fn seed(&mut self, entity: impl Into<String>, rows: Vec<Record>) {
61 if let Ok(mut data) = self.data.lock() {
62 data.insert(entity.into(), rows);
63 }
64 }
65
66 pub fn fetch_all(
67 &self,
68 query: &SelectQuery,
69 ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
70 self.require_entity(&query.entity)?;
71 let data = self
72 .data
73 .lock()
74 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
75 let mut rows = data.get(&query.entity).cloned().unwrap_or_default();
76 drop(data);
77
78 if let Some(filter) = &query.filter {
79 rows = rows
80 .into_iter()
81 .filter_map(|row| match eval_filter(filter, &row) {
82 Ok(true) => Some(Ok(row)),
83 Ok(false) => None,
84 Err(err) => Some(Err(err)),
85 })
86 .collect::<Result<Vec<_>, _>>()
87 .map_err(RepositoryError::Executor)?;
88 }
89
90 if !query.aggregates.is_empty() {
91 return aggregate_rows(query, &rows).map_err(RepositoryError::Executor);
92 }
93
94 apply_ordering(&mut rows, query);
95 rows = apply_slice(rows, query);
96 if !query.projection.is_empty() || !query.expr_projection.is_empty() {
97 rows = rows
98 .into_iter()
99 .map(|row| project_row(row, query))
100 .collect::<Result<Vec<_>, _>>()
101 .map_err(RepositoryError::Executor)?;
102 }
103 Ok(rows)
104 }
105
106 pub fn fetch_smart_list(
107 &self,
108 query: &SelectQuery,
109 ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
110 self.fetch_all(query).map(SmartList::from)
111 }
112
113 pub fn fetch_entities<T>(
114 &self,
115 query: &SelectQuery,
116 ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
117 where
118 T: Entity,
119 {
120 self.fetch_all(query)?
121 .into_iter()
122 .map(T::from_record)
123 .collect::<Result<Vec<_>, _>>()
124 .map(SmartList::from)
125 .map_err(RepositoryError::Entity)
126 }
127
128 pub fn fetch_all_with_relation_aggregates(
129 &self,
130 query: &SelectQuery,
131 relation_aggregates: &[RelationAggregate],
132 ) -> Result<Vec<Record>, RepositoryError<MemoryRepositoryError>> {
133 let mut rows = self.fetch_all(query)?;
134 self.enhance_relation_aggregates(&query.entity, &mut rows, relation_aggregates)?;
135 Ok(rows)
136 }
137
138 pub fn fetch_smart_list_with_relation_aggregates(
139 &self,
140 query: &SelectQuery,
141 relation_aggregates: &[RelationAggregate],
142 ) -> Result<SmartList<Record>, RepositoryError<MemoryRepositoryError>> {
143 self.fetch_all_with_relation_aggregates(query, relation_aggregates)
144 .map(SmartList::from)
145 }
146
147 pub fn fetch_entities_with_relation_aggregates<T>(
148 &self,
149 query: &SelectQuery,
150 relation_aggregates: &[RelationAggregate],
151 ) -> Result<SmartList<T>, RepositoryError<MemoryRepositoryError>>
152 where
153 T: Entity,
154 {
155 self.fetch_all_with_relation_aggregates(query, relation_aggregates)?
156 .into_iter()
157 .map(T::from_record)
158 .collect::<Result<Vec<_>, _>>()
159 .map(SmartList::from)
160 .map_err(RepositoryError::Entity)
161 }
162
163 pub fn enhance_relation_aggregates(
164 &self,
165 parent_entity: &str,
166 parent_rows: &mut [Record],
167 relation_aggregates: &[RelationAggregate],
168 ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
169 for aggregate in relation_aggregates {
170 self.enhance_relation_aggregate(parent_entity, parent_rows, aggregate)?;
171 }
172 Ok(())
173 }
174
175 fn enhance_relation_aggregate(
176 &self,
177 parent_entity: &str,
178 parent_rows: &mut [Record],
179 aggregate: &RelationAggregate,
180 ) -> Result<(), RepositoryError<MemoryRepositoryError>> {
181 let descriptor = self
182 .metadata
183 .entity(parent_entity)
184 .ok_or_else(|| {
185 RepositoryError::Runtime(RuntimeError::MissingEntity(parent_entity.to_owned()))
186 })?;
187
188 let relation = descriptor
189 .relation_by_name(&aggregate.relation_name)
190 .ok_or_else(|| {
191 RepositoryError::Runtime(RuntimeError::MissingRelation {
192 entity: parent_entity.to_owned(),
193 relation: aggregate.relation_name.clone(),
194 })
195 })?;
196
197 let ids = parent_rows
198 .iter()
199 .filter_map(|row| row.get(&relation.local_key).cloned())
200 .collect::<Vec<_>>();
201
202 if ids.is_empty() {
203 let value = if aggregate.single_result {
204 Value::U64(0)
205 } else {
206 Value::List(Vec::new())
207 };
208 for parent in parent_rows.iter_mut() {
209 parent.insert(aggregate.alias.clone(), value.clone());
210 }
211 return Ok(());
212 }
213
214 let mut query = aggregate.query.clone();
215 query.entity = relation.target_entity.clone();
216 query.projection.clear();
217 query.expr_projection.clear();
218 query.order_by.clear();
219 query.slice = None;
220 query.relations.clear();
221 if query.aggregates.is_empty() {
222 let alias = if aggregate.single_result {
223 aggregate.alias.clone()
224 } else {
225 "count".to_owned()
226 };
227 query = query.aggregate(Aggregate::count(alias));
228 }
229 if !query
230 .group_by
231 .iter()
232 .any(|field| field == &relation.foreign_key)
233 {
234 query = query.group_by(relation.foreign_key.clone());
235 }
236 query = query.and_filter(Expr::in_list(relation.foreign_key.clone(), ids));
237
238 let aggregate_rows = self.fetch_all(&query)?;
239
240 let mut buckets: BTreeMap<String, Vec<Record>> = BTreeMap::new();
241 for mut row in aggregate_rows {
242 if let Some(key) = row.remove(&relation.foreign_key) {
243 let bucket_key = local_graph_identity_key(&key);
244 buckets
245 .entry(bucket_key)
246 .or_default()
247 .push(row);
248 }
249 }
250
251 for parent in parent_rows {
252 let value = parent
253 .get(&relation.local_key)
254 .and_then(|local_value| buckets.get(&local_graph_identity_key(local_value)))
255 .map(|rows| {
256 if aggregate.single_result {
257 rows.first()
258 .map(|row| {
259 if row.len() == 1 {
260 row.values().next().cloned().unwrap_or(Value::Null)
261 } else {
262 Value::object(row.clone())
263 }
264 })
265 .unwrap_or(Value::U64(0))
266 } else {
267 Value::List(rows.iter().cloned().map(Value::object).collect())
268 }
269 })
270 .unwrap_or_else(|| {
271 if aggregate.single_result {
272 Value::U64(0)
273 } else {
274 Value::List(Vec::new())
275 }
276 });
277 parent.insert(aggregate.alias.clone(), value);
278 }
279
280 Ok(())
281 }
282
283 pub fn insert(
284 &self,
285 command: &InsertCommand,
286 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
287 self.require_entity(&command.entity)?;
288 let mut data = self
289 .data
290 .lock()
291 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
292 data.entry(command.entity.clone())
293 .or_default()
294 .push(command.values.clone());
295 Ok(1)
296 }
297
298 pub fn update(
299 &self,
300 command: &UpdateCommand,
301 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
302 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
303 let mut data = self
304 .data
305 .lock()
306 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
307 let rows = data.entry(command.entity.clone()).or_default();
308 let Some(row) = rows
309 .iter_mut()
310 .find(|row| row.get(id_property) == Some(&command.id))
311 else {
312 return self.maybe_optimistic_conflict(
313 command.expected_version,
314 &command.entity,
315 &command.id,
316 );
317 };
318
319 if let Some(expected_version) = command.expected_version {
320 if row.get(version_property) != Some(&Value::I64(expected_version)) {
321 return Err(RepositoryError::Runtime(
322 RuntimeError::OptimisticLockConflict {
323 entity: command.entity.clone(),
324 id: format!("{:?}", command.id),
325 },
326 ));
327 }
328 row.insert(
329 version_property.to_owned(),
330 Value::I64(expected_version + 1),
331 );
332 }
333
334 for (key, value) in &command.values {
335 row.insert(key.clone(), value.clone());
336 }
337 Ok(1)
338 }
339
340 pub fn delete(
341 &self,
342 command: &DeleteCommand,
343 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
344 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
345 let mut data = self
346 .data
347 .lock()
348 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
349 let rows = data.entry(command.entity.clone()).or_default();
350 let Some(index) = rows
351 .iter()
352 .position(|row| row.get(id_property) == Some(&command.id))
353 else {
354 return self.maybe_optimistic_conflict(
355 command.expected_version,
356 &command.entity,
357 &command.id,
358 );
359 };
360
361 if let Some(expected_version) = command.expected_version {
362 if rows[index].get(version_property) != Some(&Value::I64(expected_version)) {
363 return Err(RepositoryError::Runtime(
364 RuntimeError::OptimisticLockConflict {
365 entity: command.entity.clone(),
366 id: format!("{:?}", command.id),
367 },
368 ));
369 }
370 }
371
372 if command.soft_delete {
373 let next_version = command
374 .expected_version
375 .or_else(|| read_i64(rows[index].get(version_property)))
376 .map(|version| -(version.abs() + 1))
377 .unwrap_or(-1);
378 rows[index].insert(version_property.to_owned(), Value::I64(next_version));
379 } else {
380 rows.remove(index);
381 }
382 Ok(1)
383 }
384
385 pub fn recover(
386 &self,
387 command: &RecoverCommand,
388 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
389 let (id_property, version_property) = self.id_and_version_properties(&command.entity)?;
390 let mut data = self
391 .data
392 .lock()
393 .map_err(|_| RepositoryError::Executor(MemoryRepositoryError::Poisoned))?;
394 let rows = data.entry(command.entity.clone()).or_default();
395 let Some(row) = rows
396 .iter_mut()
397 .find(|row| row.get(id_property) == Some(&command.id))
398 else {
399 return Err(RepositoryError::Runtime(
400 RuntimeError::OptimisticLockConflict {
401 entity: command.entity.clone(),
402 id: format!("{:?}", command.id),
403 },
404 ));
405 };
406
407 if row.get(version_property) != Some(&Value::I64(command.expected_version)) {
408 return Err(RepositoryError::Runtime(
409 RuntimeError::OptimisticLockConflict {
410 entity: command.entity.clone(),
411 id: format!("{:?}", command.id),
412 },
413 ));
414 }
415
416 row.insert(
417 version_property.to_owned(),
418 Value::I64(command.expected_version.abs() + 1),
419 );
420 Ok(1)
421 }
422
423 fn require_entity(&self, entity: &str) -> Result<(), RepositoryError<MemoryRepositoryError>> {
424 self.metadata
425 .entity(entity)
426 .map(|_| ())
427 .ok_or_else(|| RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned())))
428 }
429
430 fn id_and_version_properties(
431 &self,
432 entity: &str,
433 ) -> Result<(&str, &str), RepositoryError<MemoryRepositoryError>> {
434 let descriptor = self.metadata.entity(entity).ok_or_else(|| {
435 RepositoryError::Runtime(RuntimeError::MissingEntity(entity.to_owned()))
436 })?;
437 let id = descriptor
438 .id_property()
439 .map(|property| property.name.as_str())
440 .unwrap_or("id");
441 let version = descriptor
442 .version_property()
443 .map(|property| property.name.as_str())
444 .unwrap_or("version");
445 Ok((id, version))
446 }
447
448 fn maybe_optimistic_conflict(
449 &self,
450 expected_version: Option<i64>,
451 entity: &str,
452 id: &Value,
453 ) -> Result<u64, RepositoryError<MemoryRepositoryError>> {
454 if expected_version.is_some() {
455 Err(RepositoryError::Runtime(
456 RuntimeError::OptimisticLockConflict {
457 entity: entity.to_owned(),
458 id: format!("{id:?}"),
459 },
460 ))
461 } else {
462 Ok(0)
463 }
464 }
465}
466
467fn eval_filter(expr: &Expr, row: &Record) -> Result<bool, MemoryRepositoryError> {
468 match expr {
469 Expr::Column(_) | Expr::Value(_) | Expr::Function { .. } => {
470 value_truthy(&eval_value(expr, row)?)
471 }
472 Expr::Binary { left, op, right } => {
473 let left = eval_value(left, row)?;
474 let right = eval_value(right, row)?;
475 eval_binary(&left, *op, &right)
476 }
477 Expr::SubQuery { .. } => Err(MemoryRepositoryError::UnsupportedExpression(
478 "subquery filters require a SQL executor".to_owned(),
479 )),
480 Expr::Between { expr, lower, upper } => {
481 let value = eval_value(expr, row)?;
482 let lower = eval_value(lower, row)?;
483 let upper = eval_value(upper, row)?;
484 Ok(compare_values(&value, &lower) != Some(Ordering::Less)
485 && compare_values(&value, &upper) != Some(Ordering::Greater))
486 }
487 Expr::IsNull(expr) => Ok(matches!(eval_value(expr, row)?, Value::Null)),
488 Expr::IsNotNull(expr) => Ok(!matches!(eval_value(expr, row)?, Value::Null)),
489 Expr::And(parts) => {
490 for part in parts {
491 if !eval_filter(part, row)? {
492 return Ok(false);
493 }
494 }
495 Ok(true)
496 }
497 Expr::Or(parts) => {
498 for part in parts {
499 if eval_filter(part, row)? {
500 return Ok(true);
501 }
502 }
503 Ok(false)
504 }
505 Expr::Not(expr) => Ok(!eval_filter(expr, row)?),
506 }
507}
508
509fn eval_value(expr: &Expr, row: &Record) -> Result<Value, MemoryRepositoryError> {
510 match expr {
511 Expr::Column(column) => Ok(row.get(column).cloned().unwrap_or(Value::Null)),
512 Expr::Value(value) => Ok(value.clone()),
513 Expr::Function { function, args } => eval_function(*function, args, row),
514 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
515 "cannot evaluate {other:?} as a scalar value"
516 ))),
517 }
518}
519
520fn eval_function(
521 function: ExprFunction,
522 args: &[Expr],
523 row: &Record,
524) -> Result<Value, MemoryRepositoryError> {
525 match function {
526 ExprFunction::Soundex => {
527 let [arg] = args else {
528 return Err(MemoryRepositoryError::UnsupportedExpression(
529 "SOUNDEX expects exactly one argument".to_owned(),
530 ));
531 };
532 match eval_value(arg, row)? {
533 Value::Text(value) => Ok(Value::Text(soundex(&value))),
534 Value::Null => Ok(Value::Null),
535 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
536 "SOUNDEX expects text, got {other:?}"
537 ))),
538 }
539 }
540 ExprFunction::Gbk => {
541 let [arg] = args else {
542 return Err(MemoryRepositoryError::UnsupportedExpression(
543 "GBK expects exactly one argument".to_owned(),
544 ));
545 };
546 eval_value(arg, row)
547 }
548 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
549 "function {other:?} is only supported by SQL execution"
550 ))),
551 }
552}
553
554fn eval_binary(left: &Value, op: BinaryOp, right: &Value) -> Result<bool, MemoryRepositoryError> {
555 match op {
556 BinaryOp::Eq => Ok(values_equal(left, right)),
557 BinaryOp::Ne => Ok(!values_equal(left, right)),
558 BinaryOp::Gt => Ok(compare_values(left, right) == Some(Ordering::Greater)),
559 BinaryOp::Gte => Ok(matches!(
560 compare_values(left, right),
561 Some(Ordering::Greater | Ordering::Equal)
562 )),
563 BinaryOp::Lt => Ok(compare_values(left, right) == Some(Ordering::Less)),
564 BinaryOp::Lte => Ok(matches!(
565 compare_values(left, right),
566 Some(Ordering::Less | Ordering::Equal)
567 )),
568 BinaryOp::Like => match (left, right) {
569 (Value::Text(value), Value::Text(pattern)) => Ok(like_matches(value, pattern)),
570 _ => Ok(false),
571 },
572 BinaryOp::NotLike => match (left, right) {
573 (Value::Text(value), Value::Text(pattern)) => Ok(!like_matches(value, pattern)),
574 _ => Ok(true),
575 },
576 BinaryOp::In | BinaryOp::InLarge => match right {
577 Value::List(values) => Ok(values.iter().any(|value| values_equal(left, value))),
578 _ => Err(MemoryRepositoryError::UnsupportedExpression(
579 "IN expects a list value".to_owned(),
580 )),
581 },
582 BinaryOp::NotIn | BinaryOp::NotInLarge => match right {
583 Value::List(values) => Ok(!values.iter().any(|value| values_equal(left, value))),
584 _ => Err(MemoryRepositoryError::UnsupportedExpression(
585 "NOT IN expects a list value".to_owned(),
586 )),
587 },
588 }
589}
590
591fn value_truthy(value: &Value) -> Result<bool, MemoryRepositoryError> {
592 match value {
593 Value::Bool(value) => Ok(*value),
594 Value::Null => Ok(false),
595 other => Err(MemoryRepositoryError::UnsupportedExpression(format!(
596 "non-boolean expression result: {other:?}"
597 ))),
598 }
599}
600
601fn values_equal(left: &Value, right: &Value) -> bool {
602 match (left, right) {
603 (Value::I64(left), Value::U64(right)) if *left >= 0 => *left as u64 == *right,
604 (Value::U64(left), Value::I64(right)) if *right >= 0 => *left == *right as u64,
605 _ => left == right,
606 }
607}
608
609fn compare_values(left: &Value, right: &Value) -> Option<Ordering> {
610 match (left, right) {
611 (Value::I64(left), Value::I64(right)) => Some(left.cmp(right)),
612 (Value::U64(left), Value::U64(right)) => Some(left.cmp(right)),
613 (Value::I64(left), Value::U64(right)) if *left >= 0 => Some((*left as u64).cmp(right)),
614 (Value::U64(left), Value::I64(right)) if *right >= 0 => Some(left.cmp(&(*right as u64))),
615 (Value::F64(left), Value::F64(right)) => left.partial_cmp(right),
616 (Value::Decimal(left), Value::Decimal(right)) => Some(left.cmp(right)),
617 (Value::Text(left), Value::Text(right)) => Some(left.cmp(right)),
618 (Value::Date(left), Value::Date(right)) => Some(left.cmp(right)),
619 (Value::Timestamp(left), Value::Timestamp(right)) => Some(left.cmp(right)),
620 _ => None,
621 }
622}
623
/// Case-sensitive SQL `LIKE` match of `value` against `pattern`.
///
/// `%` matches any run of characters (including none) and `_` matches exactly one
/// character, wherever they occur in the pattern. All other characters match
/// themselves. There is no escape character.
///
/// Matching works on `char`s, so `_` consumes one Unicode scalar value rather than
/// one byte.
fn like_matches(value: &str, pattern: &str) -> bool {
    let value: Vec<char> = value.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut v = 0; // next unmatched position in `value`
    let mut p = 0; // next unmatched position in `pattern`
    // Most recent `%`: (pattern position after it, value position it currently ends at).
    let mut resume: Option<(usize, usize)> = None;

    while v < value.len() {
        match pattern.get(p) {
            Some('%') => {
                // Start by letting `%` match nothing; widen on a later mismatch.
                resume = Some((p + 1, v));
                p += 1;
            }
            Some(&expected) if expected == '_' || expected == value[v] => {
                p += 1;
                v += 1;
            }
            _ => match resume {
                // Mismatch: let the last `%` swallow one more character and retry.
                Some((after_percent, covered)) => {
                    p = after_percent;
                    v = covered + 1;
                    resume = Some((after_percent, covered + 1));
                }
                None => return false,
            },
        }
    }

    // The value is exhausted; whatever pattern is left may only be `%`.
    pattern[p..].iter().all(|&ch| ch == '%')
}
635
/// Encodes `value` as a four-character Soundex code.
///
/// Only ASCII letters are considered (uppercased). The first letter is kept verbatim;
/// each later letter contributes its digit unless the digit is `'0'` or equals the
/// digit of the letter immediately before it. The result is padded with `'0'` to four
/// characters. Input without any ASCII letter yields `"0000"`.
fn soundex(value: &str) -> String {
    let mut letters = value
        .chars()
        .filter(char::is_ascii_alphabetic)
        .map(|ch| ch.to_ascii_uppercase());

    let Some(initial) = letters.next() else {
        return "0000".to_owned();
    };

    let mut encoded = String::from(initial);
    let mut last_code = soundex_code(initial);
    for code in letters.map(soundex_code) {
        if encoded.len() == 4 {
            break;
        }
        if code != '0' && code != last_code {
            encoded.push(code);
        }
        // Uncoded letters also reset the "previous digit", so a repeated digit after
        // one of them is emitted again.
        last_code = code;
    }

    let padding = 4 - encoded.len();
    encoded.extend(std::iter::repeat('0').take(padding));
    encoded
}

/// Soundex digit for an uppercase ASCII letter; `'0'` for every other character.
fn soundex_code(ch: char) -> char {
    const GROUPS: [(&str, char); 6] = [
        ("BFPV", '1'),
        ("CGJKQSXZ", '2'),
        ("DT", '3'),
        ("L", '4'),
        ("MN", '5'),
        ("R", '6'),
    ];
    GROUPS
        .iter()
        .find(|(letters, _)| letters.contains(ch))
        .map_or('0', |&(_, code)| code)
}
677
678fn apply_ordering(rows: &mut [Record], query: &SelectQuery) {
679 for order in query.order_by.iter().rev() {
680 rows.sort_by(|left, right| {
681 let left_value = if let Some(expr) = &order.expr {
682 eval_value(expr, left).ok()
683 } else {
684 left.get(&order.field).cloned()
685 };
686 let right_value = if let Some(expr) = &order.expr {
687 eval_value(expr, right).ok()
688 } else {
689 right.get(&order.field).cloned()
690 };
691 let ordering = match (left_value.as_ref(), right_value.as_ref()) {
692 (Some(left), Some(right)) => compare_values(left, right).unwrap_or(Ordering::Equal),
693 (None, Some(_)) => Ordering::Less,
694 (Some(_), None) => Ordering::Greater,
695 (None, None) => Ordering::Equal,
696 };
697 match order.direction {
698 SortDirection::Asc => ordering,
699 SortDirection::Desc => ordering.reverse(),
700 }
701 });
702 }
703}
704
705fn apply_slice(rows: Vec<Record>, query: &SelectQuery) -> Vec<Record> {
706 let Some(slice) = query.slice else {
707 return rows;
708 };
709 let offset = usize::try_from(slice.offset).unwrap_or(usize::MAX);
710 let limit = slice
711 .limit
712 .and_then(|limit| usize::try_from(limit).ok())
713 .unwrap_or(usize::MAX);
714 rows.into_iter().skip(offset).take(limit).collect()
715}
716
717fn project_row(row: Record, query: &SelectQuery) -> Result<Record, MemoryRepositoryError> {
718 let mut output: Record = query
719 .projection
720 .iter()
721 .filter_map(|field| row.get(field).cloned().map(|value| (field.clone(), value)))
722 .collect();
723 for projection in &query.expr_projection {
724 output.insert(
725 projection.alias.clone(),
726 eval_value(&projection.expr, &row)?,
727 );
728 }
729 Ok(output)
730}
731
732fn aggregate_rows(
733 query: &SelectQuery,
734 rows: &[Record],
735) -> Result<Vec<Record>, MemoryRepositoryError> {
736 let mut groups: BTreeMap<Vec<String>, Vec<&Record>> = BTreeMap::new();
737 if query.group_by.is_empty() {
738 groups.insert(Vec::new(), rows.iter().collect());
739 } else {
740 for row in rows {
741 let key = query
742 .group_by
743 .iter()
744 .map(|field| row.get(field).map(value_key).unwrap_or_default())
745 .collect::<Vec<_>>();
746 groups.entry(key).or_default().push(row);
747 }
748 }
749
750 let rows = groups
751 .into_values()
752 .map(|rows| {
753 let mut output = Record::new();
754 if let Some(first) = rows.first() {
755 for field in &query.group_by {
756 if let Some(value) = first.get(field) {
757 output.insert(field.clone(), value.clone());
758 }
759 }
760 }
761 for aggregate in &query.aggregates {
762 let value = match aggregate.function {
763 AggregateFunction::Count => {
764 if aggregate.field == "*" {
765 Value::U64(rows.len() as u64)
766 } else {
767 Value::U64(
768 rows.iter()
769 .filter(|row| {
770 !matches!(
771 row.get(&aggregate.field),
772 None | Some(Value::Null)
773 )
774 })
775 .count() as u64,
776 )
777 }
778 }
779 AggregateFunction::Sum => numeric_sum(&rows, &aggregate.field)?,
780 AggregateFunction::Avg => numeric_avg(&rows, &aggregate.field)?,
781 AggregateFunction::Min => min_max(&rows, &aggregate.field, false)?,
782 AggregateFunction::Max => min_max(&rows, &aggregate.field, true)?,
783 AggregateFunction::Stddev => numeric_stddev(&rows, &aggregate.field, true)?,
784 AggregateFunction::StddevPop => numeric_stddev(&rows, &aggregate.field, false)?,
785 AggregateFunction::VarSamp => numeric_variance(&rows, &aggregate.field, true)?,
786 AggregateFunction::VarPop => numeric_variance(&rows, &aggregate.field, false)?,
787 AggregateFunction::BitAnd => {
788 bit_aggregate(&rows, &aggregate.field, BitOp::And)?
789 }
790 AggregateFunction::BitOr => bit_aggregate(&rows, &aggregate.field, BitOp::Or)?,
791 AggregateFunction::BitXor => {
792 bit_aggregate(&rows, &aggregate.field, BitOp::Xor)?
793 }
794 };
795 output.insert(aggregate.alias.clone(), value);
796 }
797 for projection in &query.expr_projection {
798 output.insert(
799 projection.alias.clone(),
800 eval_value(&projection.expr, &output)?,
801 );
802 }
803 Ok(output)
804 })
805 .collect::<Result<Vec<_>, _>>()?;
806 if let Some(having) = &query.having {
807 rows.into_iter()
808 .filter_map(|row| match eval_filter(having, &row) {
809 Ok(true) => Some(Ok(row)),
810 Ok(false) => None,
811 Err(err) => Some(Err(err)),
812 })
813 .collect()
814 } else {
815 Ok(rows)
816 }
817}
818
819fn numeric_sum(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
820 let mut decimal_sum = Decimal::ZERO;
821 let mut integer_sum: i128 = 0;
822 let mut saw_decimal = false;
823 for value in rows.iter().filter_map(|row| row.get(field)) {
824 match value {
825 Value::I64(value) => {
826 integer_sum += i128::from(*value);
827 decimal_sum += Decimal::from(*value);
828 }
829 Value::U64(value) => {
830 integer_sum += i128::from(*value);
831 decimal_sum += Decimal::from(*value);
832 }
833 Value::F64(value) => {
834 saw_decimal = true;
835 decimal_sum += decimal_from_f64(*value);
836 }
837 Value::Decimal(value) => {
838 saw_decimal = true;
839 decimal_sum += *value;
840 }
841 Value::Null => {}
842 other => {
843 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
844 "SUM does not support {other:?}"
845 )));
846 }
847 }
848 }
849 if saw_decimal {
850 Ok(Value::Decimal(decimal_sum))
851 } else if integer_sum >= 0 {
852 Ok(Value::U64(integer_sum as u64))
853 } else {
854 Ok(Value::I64(integer_sum as i64))
855 }
856}
857
858fn numeric_avg(rows: &[&Record], field: &str) -> Result<Value, MemoryRepositoryError> {
859 let mut sum = Decimal::ZERO;
860 let mut count: u64 = 0;
861 for value in rows.iter().filter_map(|row| row.get(field)) {
862 match value {
863 Value::I64(value) => {
864 sum += Decimal::from(*value);
865 count += 1;
866 }
867 Value::U64(value) => {
868 sum += Decimal::from(*value);
869 count += 1;
870 }
871 Value::F64(value) => {
872 sum += decimal_from_f64(*value);
873 count += 1;
874 }
875 Value::Decimal(value) => {
876 sum += *value;
877 count += 1;
878 }
879 Value::Null => {}
880 other => {
881 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
882 "AVG does not support {other:?}"
883 )));
884 }
885 }
886 }
887 Ok(if count == 0 {
888 Value::Null
889 } else {
890 Value::Decimal(sum / Decimal::from(count))
891 })
892}
893
894fn decimal_from_f64(value: f64) -> Decimal {
895 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
896}
897
898fn numeric_values(rows: &[&Record], field: &str) -> Result<Vec<f64>, MemoryRepositoryError> {
899 rows.iter()
900 .filter_map(|row| row.get(field))
901 .filter(|value| !matches!(value, Value::Null))
902 .map(|value| match value {
903 Value::I64(value) => Ok(*value as f64),
904 Value::U64(value) => Ok(*value as f64),
905 Value::F64(value) => Ok(*value),
906 Value::Decimal(value) => value.to_f64().ok_or_else(|| {
907 MemoryRepositoryError::UnsupportedAggregate(format!(
908 "cannot convert decimal {value} to f64 for statistical aggregate"
909 ))
910 }),
911 other => Err(MemoryRepositoryError::UnsupportedAggregate(format!(
912 "numeric aggregate does not support {other:?}"
913 ))),
914 })
915 .collect()
916}
917
918fn numeric_variance(
919 rows: &[&Record],
920 field: &str,
921 sample: bool,
922) -> Result<Value, MemoryRepositoryError> {
923 let values = numeric_values(rows, field)?;
924 let count = values.len();
925 if count == 0 || (sample && count < 2) {
926 return Ok(Value::Null);
927 }
928 let mean = values.iter().sum::<f64>() / count as f64;
929 let sum = values
930 .iter()
931 .map(|value| {
932 let diff = value - mean;
933 diff * diff
934 })
935 .sum::<f64>();
936 let denominator = if sample { count - 1 } else { count } as f64;
937 Ok(Value::Decimal(decimal_from_f64(sum / denominator)))
938}
939
940fn numeric_stddev(
941 rows: &[&Record],
942 field: &str,
943 sample: bool,
944) -> Result<Value, MemoryRepositoryError> {
945 Ok(match numeric_variance(rows, field, sample)? {
946 Value::Decimal(value) => {
947 Value::Decimal(decimal_from_f64(value.to_f64().unwrap_or(0.0).sqrt()))
948 }
949 Value::Null => Value::Null,
950 other => other,
951 })
952}
953
/// Bitwise operation folded over a column by `bit_aggregate`.
#[derive(Debug, Clone, Copy)]
enum BitOp {
    /// Bitwise AND of all values.
    And,
    /// Bitwise OR of all values.
    Or,
    /// Bitwise XOR of all values.
    Xor,
}
960
961fn bit_aggregate(rows: &[&Record], field: &str, op: BitOp) -> Result<Value, MemoryRepositoryError> {
962 let mut selected: Option<i64> = None;
963 for value in rows.iter().filter_map(|row| row.get(field)) {
964 let value = match value {
965 Value::I64(value) => *value,
966 Value::U64(value) => i64::try_from(*value).map_err(|_| {
967 MemoryRepositoryError::UnsupportedAggregate(format!(
968 "BIT aggregate u64 {value} exceeds i64 range"
969 ))
970 })?,
971 Value::Null => continue,
972 other => {
973 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
974 "BIT aggregate does not support {other:?}"
975 )));
976 }
977 };
978 selected = Some(match (selected, op) {
979 (None, _) => value,
980 (Some(current), BitOp::And) => current & value,
981 (Some(current), BitOp::Or) => current | value,
982 (Some(current), BitOp::Xor) => current ^ value,
983 });
984 }
985 Ok(selected.map(Value::I64).unwrap_or(Value::Null))
986}
987
988fn min_max(rows: &[&Record], field: &str, max: bool) -> Result<Value, MemoryRepositoryError> {
989 let mut selected: Option<Value> = None;
990 for value in rows.iter().filter_map(|row| row.get(field)) {
991 if matches!(value, Value::Null) {
992 continue;
993 }
994 match &selected {
995 None => selected = Some(value.clone()),
996 Some(current) => {
997 let Some(ordering) = compare_values(value, current) else {
998 return Err(MemoryRepositoryError::UnsupportedAggregate(format!(
999 "MIN/MAX does not support {value:?}"
1000 )));
1001 };
1002 if (max && ordering == Ordering::Greater) || (!max && ordering == Ordering::Less) {
1003 selected = Some(value.clone());
1004 }
1005 }
1006 }
1007 }
1008 Ok(selected.unwrap_or(Value::Null))
1009}
1010
1011fn value_key(value: &Value) -> String {
1012 match value {
1013 Value::Null => "null".to_owned(),
1014 Value::Bool(value) => format!("b:{value}"),
1015 Value::I64(value) => format!("i:{value}"),
1016 Value::U64(value) => format!("u:{value}"),
1017 Value::F64(value) => format!("f:{value}"),
1018 Value::Decimal(value) => format!("d:{value}"),
1019 Value::Text(value) => format!("t:{value}"),
1020 Value::Json(value) => format!("j:{value}"),
1021 Value::Date(value) => format!("d:{value}"),
1022 Value::Timestamp(value) => format!("ts:{}", value.to_rfc3339()),
1023 Value::Object(_) => "object".to_owned(),
1024 Value::List(_) => "list".to_owned(),
1025 }
1026}
1027
1028fn read_i64(value: Option<&Value>) -> Option<i64> {
1029 match value {
1030 Some(Value::I64(value)) => Some(*value),
1031 _ => None,
1032 }
1033}
1034
1035fn local_graph_identity_key(value: &Value) -> String {
1036 match value {
1037 Value::I64(val) if *val >= 0 => format!("u:{}", *val as u64),
1038 Value::U64(val) => format!("u:{val}"),
1039 Value::Null => "null".to_owned(),
1040 Value::Bool(v) => format!("b:{v}"),
1041 Value::I64(v) => format!("i:{v}"),
1042 Value::F64(v) => format!("f:{v}"),
1043 Value::Decimal(v) => format!("d:{v}"),
1044 Value::Text(v) => format!("t:{v}"),
1045 Value::Json(v) => format!("j:{v}"),
1046 Value::Date(v) => format!("d:{v}"),
1047 Value::Timestamp(v) => format!("ts:{}", v.to_rfc3339()),
1048 Value::Object(_) => "o".to_owned(),
1049 Value::List(_) => "l".to_owned(),
1050 }
1051}
1052