1use std::cell::RefCell;
6use std::collections::HashMap;
7
8use chrono::Datelike;
9use regex::Regex;
10use rust_decimal::Decimal;
11use rustledger_core::{
12 Amount, Directive, InternedStr, Inventory, NaiveDate, Position, Transaction,
13};
14
15use crate::ast::{
16 BalancesQuery, BinaryOp, BinaryOperator, Expr, FromClause, FunctionCall, JournalQuery, Literal,
17 OrderSpec, PrintQuery, Query, SelectQuery, SortDirection, Target, UnaryOp, UnaryOperator,
18 WindowFunction,
19};
20use crate::error::QueryError;
21
/// A single dynamically-typed value produced or consumed while evaluating a query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// Text, e.g. account names, narrations, payees and flags.
    String(String),
    /// An arbitrary-precision decimal number.
    Number(Decimal),
    /// A signed integer, e.g. the result of `YEAR`, `MONTH` or `LENGTH`.
    Integer(i64),
    /// A calendar date.
    Date(NaiveDate),
    /// A boolean, e.g. the result of a predicate.
    Boolean(bool),
    /// A number together with its currency.
    Amount(Amount),
    /// A position: units with an optional cost.
    Position(Position),
    /// A collection of positions, e.g. an account balance.
    Inventory(Inventory),
    /// A set of strings; used for transaction tags and links.
    StringSet(Vec<String>),
    /// The absence of a value (e.g. a transaction without a payee).
    Null,
}
46
/// One result row: one [`Value`] per output column.
pub type Row = Vec<Value>;
49
/// The tabular result of executing a query.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Output column names, in display order.
    pub columns: Vec<String>,
    /// Result rows, in output order.
    pub rows: Vec<Row>,
}
58
59impl QueryResult {
60 pub const fn new(columns: Vec<String>) -> Self {
62 Self {
63 columns,
64 rows: Vec::new(),
65 }
66 }
67
68 pub fn add_row(&mut self, row: Row) {
70 self.rows.push(row);
71 }
72
73 pub fn len(&self) -> usize {
75 self.rows.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.rows.is_empty()
81 }
82}
83
/// Everything needed to evaluate an expression for a single posting.
#[derive(Debug)]
pub struct PostingContext<'a> {
    /// The transaction that contains the posting.
    pub transaction: &'a Transaction,
    /// Index of the posting within `transaction.postings`.
    pub posting_index: usize,
    /// Running balance of the posting's account, up to and including this
    /// posting, or `None` when no balance is tracked (e.g. while evaluating
    /// a FROM-clause filter).
    pub balance: Option<Inventory>,
}
94
/// Per-row ranking values used when evaluating window functions.
///
/// NOTE(review): these are filled in by `compute_window_contexts`, which is
/// defined elsewhere; the numbering base (0 or 1) is not visible here.
#[derive(Debug, Clone)]
pub struct WindowContext {
    /// Sequential number of the row.
    pub row_number: usize,
    /// Rank of the row; presumably leaves gaps after ties — TODO confirm.
    pub rank: usize,
    /// Rank of the row; presumably without gaps after ties — TODO confirm.
    pub dense_rank: usize,
}
105
/// Executes parsed queries (`SELECT`, `JOURNAL`, `BALANCES`, `PRINT`) against
/// a borrowed slice of directives.
pub struct Executor<'a> {
    /// The directives being queried.
    directives: &'a [Directive],
    /// Per-account inventories built by the balance-building query paths.
    balances: HashMap<InternedStr, Inventory>,
    /// Price database built from `directives` when the executor is created.
    price_db: crate::price::PriceDatabase,
    /// Optional target currency, set via `set_target_currency`.
    /// NOTE(review): not read in the visible code; presumably used for
    /// conversions elsewhere.
    target_currency: Option<String>,
    /// Compiled regexes keyed by pattern; `None` records a pattern that failed
    /// to compile. A `RefCell` lets `&self` methods populate the cache.
    regex_cache: RefCell<HashMap<String, Option<Regex>>>,
}
119
120impl<'a> Executor<'a> {
121 pub fn new(directives: &'a [Directive]) -> Self {
123 let price_db = crate::price::PriceDatabase::from_directives(directives);
124 Self {
125 directives,
126 balances: HashMap::new(),
127 price_db,
128 target_currency: None,
129 regex_cache: RefCell::new(HashMap::new()),
130 }
131 }
132
133 fn get_or_compile_regex(&self, pattern: &str) -> Option<Regex> {
138 let mut cache = self.regex_cache.borrow_mut();
139 if let Some(cached) = cache.get(pattern) {
140 return cached.clone();
141 }
142 let compiled = Regex::new(pattern).ok();
143 cache.insert(pattern.to_string(), compiled.clone());
144 compiled
145 }
146
147 fn require_regex(&self, pattern: &str) -> Result<Regex, QueryError> {
149 self.get_or_compile_regex(pattern)
150 .ok_or_else(|| QueryError::Type(format!("invalid regex: {pattern}")))
151 }
152
153 pub fn set_target_currency(&mut self, currency: impl Into<String>) {
155 self.target_currency = Some(currency.into());
156 }
157
158 pub fn execute(&mut self, query: &Query) -> Result<QueryResult, QueryError> {
171 match query {
172 Query::Select(select) => self.execute_select(select),
173 Query::Journal(journal) => self.execute_journal(journal),
174 Query::Balances(balances) => self.execute_balances(balances),
175 Query::Print(print) => self.execute_print(print),
176 }
177 }
178
179 fn execute_select(&self, query: &SelectQuery) -> Result<QueryResult, QueryError> {
181 if let Some(from) = &query.from {
183 if let Some(subquery) = &from.subquery {
184 return self.execute_select_from_subquery(query, subquery);
185 }
186 }
187
188 let column_names = self.resolve_column_names(&query.targets)?;
190 let mut result = QueryResult::new(column_names.clone());
191
192 let postings = self.collect_postings(query.from.as_ref(), query.where_clause.as_ref())?;
194
195 let is_aggregate = query
197 .targets
198 .iter()
199 .any(|t| Self::is_aggregate_expr(&t.expr));
200
201 if is_aggregate {
202 let grouped = self.group_postings(&postings, query.group_by.as_ref())?;
204 for (_, group) in grouped {
205 let row = self.evaluate_aggregate_row(&query.targets, &group)?;
206
207 if let Some(having_expr) = &query.having {
209 if !self.evaluate_having_filter(
210 having_expr,
211 &row,
212 &column_names,
213 &query.targets,
214 &group,
215 )? {
216 continue;
217 }
218 }
219
220 result.add_row(row);
221 }
222 } else {
223 let has_windows = Self::has_window_functions(&query.targets);
225 let window_contexts = if has_windows {
226 if let Some(wf) = Self::find_window_function(&query.targets) {
227 Some(self.compute_window_contexts(&postings, wf)?)
228 } else {
229 None
230 }
231 } else {
232 None
233 };
234
235 for (i, ctx) in postings.iter().enumerate() {
237 let row = if let Some(ref wctxs) = window_contexts {
238 self.evaluate_row_with_window(&query.targets, ctx, Some(&wctxs[i]))?
239 } else {
240 self.evaluate_row(&query.targets, ctx)?
241 };
242 if query.distinct {
243 if !result.rows.contains(&row) {
245 result.add_row(row);
246 }
247 } else {
248 result.add_row(row);
249 }
250 }
251 }
252
253 if let Some(pivot_exprs) = &query.pivot_by {
255 result = self.apply_pivot(&result, pivot_exprs, &query.targets)?;
256 }
257
258 if let Some(order_by) = &query.order_by {
260 self.sort_results(&mut result, order_by)?;
261 }
262
263 if let Some(limit) = query.limit {
265 result.rows.truncate(limit as usize);
266 }
267
268 Ok(result)
269 }
270
271 fn execute_select_from_subquery(
273 &self,
274 outer_query: &SelectQuery,
275 inner_query: &SelectQuery,
276 ) -> Result<QueryResult, QueryError> {
277 let inner_result = self.execute_select(inner_query)?;
279
280 let inner_column_map: HashMap<String, usize> = inner_result
282 .columns
283 .iter()
284 .enumerate()
285 .map(|(i, name)| (name.to_lowercase(), i))
286 .collect();
287
288 let outer_column_names =
290 self.resolve_subquery_column_names(&outer_query.targets, &inner_result.columns)?;
291 let mut result = QueryResult::new(outer_column_names);
292
293 for inner_row in &inner_result.rows {
295 if let Some(where_expr) = &outer_query.where_clause {
297 if !self.evaluate_subquery_filter(where_expr, inner_row, &inner_column_map)? {
298 continue;
299 }
300 }
301
302 let outer_row =
304 self.evaluate_subquery_row(&outer_query.targets, inner_row, &inner_column_map)?;
305
306 if outer_query.distinct {
307 if !result.rows.contains(&outer_row) {
308 result.add_row(outer_row);
309 }
310 } else {
311 result.add_row(outer_row);
312 }
313 }
314
315 if let Some(order_by) = &outer_query.order_by {
317 self.sort_results(&mut result, order_by)?;
318 }
319
320 if let Some(limit) = outer_query.limit {
322 result.rows.truncate(limit as usize);
323 }
324
325 Ok(result)
326 }
327
328 fn resolve_subquery_column_names(
330 &self,
331 targets: &[Target],
332 inner_columns: &[String],
333 ) -> Result<Vec<String>, QueryError> {
334 let mut names = Vec::new();
335 for (i, target) in targets.iter().enumerate() {
336 if let Some(alias) = &target.alias {
337 names.push(alias.clone());
338 } else if matches!(target.expr, Expr::Wildcard) {
339 names.extend(inner_columns.iter().cloned());
341 } else {
342 names.push(self.expr_to_name(&target.expr, i));
343 }
344 }
345 Ok(names)
346 }
347
348 fn evaluate_subquery_filter(
350 &self,
351 expr: &Expr,
352 row: &[Value],
353 column_map: &HashMap<String, usize>,
354 ) -> Result<bool, QueryError> {
355 let val = self.evaluate_subquery_expr(expr, row, column_map)?;
356 self.to_bool(&val)
357 }
358
359 fn evaluate_subquery_expr(
361 &self,
362 expr: &Expr,
363 row: &[Value],
364 column_map: &HashMap<String, usize>,
365 ) -> Result<Value, QueryError> {
366 match expr {
367 Expr::Wildcard => Err(QueryError::Evaluation(
368 "Wildcard not allowed in expression context".to_string(),
369 )),
370 Expr::Column(name) => {
371 let lower = name.to_lowercase();
372 if let Some(&idx) = column_map.get(&lower) {
373 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
374 } else {
375 Err(QueryError::Evaluation(format!(
376 "Unknown column '{name}' in subquery result"
377 )))
378 }
379 }
380 Expr::Literal(lit) => self.evaluate_literal(lit),
381 Expr::Function(func) => {
382 let args: Vec<Value> = func
384 .args
385 .iter()
386 .map(|a| self.evaluate_subquery_expr(a, row, column_map))
387 .collect::<Result<Vec<_>, _>>()?;
388 self.evaluate_function_on_values(&func.name, &args)
389 }
390 Expr::BinaryOp(op) => {
391 let left = self.evaluate_subquery_expr(&op.left, row, column_map)?;
392 let right = self.evaluate_subquery_expr(&op.right, row, column_map)?;
393 self.binary_op_on_values(op.op, &left, &right)
394 }
395 Expr::UnaryOp(op) => {
396 let val = self.evaluate_subquery_expr(&op.operand, row, column_map)?;
397 self.unary_op_on_value(op.op, &val)
398 }
399 Expr::Paren(inner) => self.evaluate_subquery_expr(inner, row, column_map),
400 Expr::Window(_) => Err(QueryError::Evaluation(
401 "Window functions not supported in subquery expressions".to_string(),
402 )),
403 }
404 }
405
406 fn evaluate_subquery_row(
408 &self,
409 targets: &[Target],
410 inner_row: &[Value],
411 column_map: &HashMap<String, usize>,
412 ) -> Result<Row, QueryError> {
413 let mut row = Vec::new();
414 for target in targets {
415 if matches!(target.expr, Expr::Wildcard) {
416 row.extend(inner_row.iter().cloned());
418 } else {
419 row.push(self.evaluate_subquery_expr(&target.expr, inner_row, column_map)?);
420 }
421 }
422 Ok(row)
423 }
424
425 fn execute_journal(&mut self, query: &JournalQuery) -> Result<QueryResult, QueryError> {
427 let account_pattern = &query.account_pattern;
429
430 let account_regex = self.get_or_compile_regex(account_pattern);
432
433 let columns = vec![
434 "date".to_string(),
435 "flag".to_string(),
436 "payee".to_string(),
437 "narration".to_string(),
438 "account".to_string(),
439 "position".to_string(),
440 "balance".to_string(),
441 ];
442 let mut result = QueryResult::new(columns);
443
444 for directive in self.directives {
446 if let Directive::Transaction(txn) = directive {
447 if let Some(from) = &query.from {
449 if let Some(filter) = &from.filter {
450 if !self.evaluate_from_filter(filter, txn)? {
451 continue;
452 }
453 }
454 }
455
456 for posting in &txn.postings {
457 let matches = if let Some(ref regex) = account_regex {
459 regex.is_match(&posting.account)
460 } else {
461 posting.account.contains(account_pattern)
462 };
463
464 if matches {
465 let balance = self.balances.entry(posting.account.clone()).or_default();
467
468 if let Some(units) = posting.amount() {
470 let pos = if let Some(cost_spec) = &posting.cost {
471 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
472 Position::with_cost(units.clone(), cost)
473 } else {
474 Position::simple(units.clone())
475 }
476 } else {
477 Position::simple(units.clone())
478 };
479 balance.add(pos.clone());
480 }
481
482 let position_value = if let Some(at_func) = &query.at_function {
484 match at_func.to_uppercase().as_str() {
485 "COST" => {
486 if let Some(units) = posting.amount() {
487 if let Some(cost_spec) = &posting.cost {
488 if let Some(cost) =
489 cost_spec.resolve(units.number, txn.date)
490 {
491 let total = units.number * cost.number;
492 Value::Amount(Amount::new(total, &cost.currency))
493 } else {
494 Value::Amount(units.clone())
495 }
496 } else {
497 Value::Amount(units.clone())
498 }
499 } else {
500 Value::Null
501 }
502 }
503 "UNITS" => posting
504 .amount()
505 .map_or(Value::Null, |u| Value::Amount(u.clone())),
506 _ => posting
507 .amount()
508 .map_or(Value::Null, |u| Value::Amount(u.clone())),
509 }
510 } else {
511 posting
512 .amount()
513 .map_or(Value::Null, |u| Value::Amount(u.clone()))
514 };
515
516 let row = vec![
517 Value::Date(txn.date),
518 Value::String(txn.flag.to_string()),
519 Value::String(
520 txn.payee
521 .as_ref()
522 .map_or_else(String::new, ToString::to_string),
523 ),
524 Value::String(txn.narration.to_string()),
525 Value::String(posting.account.to_string()),
526 position_value,
527 Value::Inventory(balance.clone()),
528 ];
529 result.add_row(row);
530 }
531 }
532 }
533 }
534
535 Ok(result)
536 }
537
538 fn execute_balances(&mut self, query: &BalancesQuery) -> Result<QueryResult, QueryError> {
540 self.build_balances_with_filter(query.from.as_ref())?;
542
543 let columns = vec!["account".to_string(), "balance".to_string()];
544 let mut result = QueryResult::new(columns);
545
546 let mut accounts: Vec<_> = self.balances.keys().collect();
548 accounts.sort();
549
550 for account in accounts {
551 let Some(balance) = self.balances.get(account) else {
553 continue; };
555
556 let balance_value = if let Some(at_func) = &query.at_function {
558 match at_func.to_uppercase().as_str() {
559 "COST" => {
560 let cost_inventory = balance.at_cost();
562 Value::Inventory(cost_inventory)
563 }
564 "UNITS" => {
565 let units_inventory = balance.at_units();
567 Value::Inventory(units_inventory)
568 }
569 _ => Value::Inventory(balance.clone()),
570 }
571 } else {
572 Value::Inventory(balance.clone())
573 };
574
575 let row = vec![Value::String(account.to_string()), balance_value];
576 result.add_row(row);
577 }
578
579 Ok(result)
580 }
581
582 fn execute_print(&self, query: &PrintQuery) -> Result<QueryResult, QueryError> {
584 let columns = vec!["directive".to_string()];
586 let mut result = QueryResult::new(columns);
587
588 for directive in self.directives {
589 if let Some(from) = &query.from {
591 if let Some(filter) = &from.filter {
592 if let Directive::Transaction(txn) = directive {
594 if !self.evaluate_from_filter(filter, txn)? {
595 continue;
596 }
597 }
598 }
599 }
600
601 let formatted = self.format_directive(directive);
603 result.add_row(vec![Value::String(formatted)]);
604 }
605
606 Ok(result)
607 }
608
609 fn format_directive(&self, directive: &Directive) -> String {
611 match directive {
612 Directive::Transaction(txn) => {
613 let mut out = format!("{} {} ", txn.date, txn.flag);
614 if let Some(payee) = &txn.payee {
615 out.push_str(&format!("\"{payee}\" "));
616 }
617 out.push_str(&format!("\"{}\"", txn.narration));
618
619 for tag in &txn.tags {
620 out.push_str(&format!(" #{tag}"));
621 }
622 for link in &txn.links {
623 out.push_str(&format!(" ^{link}"));
624 }
625 out.push('\n');
626
627 for posting in &txn.postings {
628 out.push_str(&format!(" {}", posting.account));
629 if let Some(units) = posting.amount() {
630 out.push_str(&format!(" {} {}", units.number, units.currency));
631 }
632 out.push('\n');
633 }
634 out
635 }
636 Directive::Balance(bal) => {
637 format!(
638 "{} balance {} {} {}\n",
639 bal.date, bal.account, bal.amount.number, bal.amount.currency
640 )
641 }
642 Directive::Open(open) => {
643 let mut out = format!("{} open {}", open.date, open.account);
644 if !open.currencies.is_empty() {
645 out.push_str(&format!(" {}", open.currencies.join(",")));
646 }
647 out.push('\n');
648 out
649 }
650 Directive::Close(close) => {
651 format!("{} close {}\n", close.date, close.account)
652 }
653 Directive::Commodity(comm) => {
654 format!("{} commodity {}\n", comm.date, comm.currency)
655 }
656 Directive::Pad(pad) => {
657 format!("{} pad {} {}\n", pad.date, pad.account, pad.source_account)
658 }
659 Directive::Event(event) => {
660 format!(
661 "{} event \"{}\" \"{}\"\n",
662 event.date, event.event_type, event.value
663 )
664 }
665 Directive::Query(query) => {
666 format!(
667 "{} query \"{}\" \"{}\"\n",
668 query.date, query.name, query.query
669 )
670 }
671 Directive::Note(note) => {
672 format!("{} note {} \"{}\"\n", note.date, note.account, note.comment)
673 }
674 Directive::Document(doc) => {
675 format!("{} document {} \"{}\"\n", doc.date, doc.account, doc.path)
676 }
677 Directive::Price(price) => {
678 format!(
679 "{} price {} {} {}\n",
680 price.date, price.currency, price.amount.number, price.amount.currency
681 )
682 }
683 Directive::Custom(custom) => {
684 format!("{} custom \"{}\"\n", custom.date, custom.custom_type)
685 }
686 }
687 }
688
689 fn build_balances_with_filter(&mut self, from: Option<&FromClause>) -> Result<(), QueryError> {
691 for directive in self.directives {
692 if let Directive::Transaction(txn) = directive {
693 if let Some(from_clause) = from {
695 if let Some(filter) = &from_clause.filter {
696 if !self.evaluate_from_filter(filter, txn)? {
697 continue;
698 }
699 }
700 }
701
702 for posting in &txn.postings {
703 if let Some(units) = posting.amount() {
704 let balance = self.balances.entry(posting.account.clone()).or_default();
705
706 let pos = if let Some(cost_spec) = &posting.cost {
707 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
708 Position::with_cost(units.clone(), cost)
709 } else {
710 Position::simple(units.clone())
711 }
712 } else {
713 Position::simple(units.clone())
714 };
715 balance.add(pos);
716 }
717 }
718 }
719 }
720 Ok(())
721 }
722
723 fn collect_postings(
725 &self,
726 from: Option<&FromClause>,
727 where_clause: Option<&Expr>,
728 ) -> Result<Vec<PostingContext<'a>>, QueryError> {
729 let mut postings = Vec::new();
730 let mut running_balances: HashMap<InternedStr, Inventory> = HashMap::new();
732
733 for directive in self.directives {
734 if let Directive::Transaction(txn) = directive {
735 if let Some(from) = from {
737 if let Some(open_date) = from.open_on {
739 if txn.date < open_date {
740 for posting in &txn.postings {
742 if let Some(units) = posting.amount() {
743 let balance = running_balances
744 .entry(posting.account.clone())
745 .or_default();
746 balance.add(Position::simple(units.clone()));
747 }
748 }
749 continue;
750 }
751 }
752 if let Some(close_date) = from.close_on {
753 if txn.date > close_date {
754 continue;
755 }
756 }
757 if let Some(filter) = &from.filter {
759 if !self.evaluate_from_filter(filter, txn)? {
760 continue;
761 }
762 }
763 }
764
765 for (i, posting) in txn.postings.iter().enumerate() {
767 if let Some(units) = posting.amount() {
769 let balance = running_balances.entry(posting.account.clone()).or_default();
770 balance.add(Position::simple(units.clone()));
771 }
772
773 let ctx = PostingContext {
774 transaction: txn,
775 posting_index: i,
776 balance: running_balances.get(&posting.account).cloned(),
777 };
778
779 if let Some(where_expr) = where_clause {
781 if self.evaluate_predicate(where_expr, &ctx)? {
782 postings.push(ctx);
783 }
784 } else {
785 postings.push(ctx);
786 }
787 }
788 }
789 }
790
791 Ok(postings)
792 }
793
794 fn evaluate_from_filter(&self, filter: &Expr, txn: &Transaction) -> Result<bool, QueryError> {
796 match filter {
798 Expr::Function(func) => {
799 if func.name.to_uppercase().as_str() == "HAS_ACCOUNT" {
800 if func.args.len() != 1 {
801 return Err(QueryError::InvalidArguments(
802 "has_account".to_string(),
803 "expected 1 argument".to_string(),
804 ));
805 }
806 let pattern = match &func.args[0] {
807 Expr::Literal(Literal::String(s)) => s.clone(),
808 Expr::Column(s) => s.clone(),
809 _ => {
810 return Err(QueryError::Type(
811 "has_account expects a string pattern".to_string(),
812 ));
813 }
814 };
815 let regex = self.require_regex(&pattern)?;
817 for posting in &txn.postings {
818 if regex.is_match(&posting.account) {
819 return Ok(true);
820 }
821 }
822 Ok(false)
823 } else {
824 let dummy_ctx = PostingContext {
826 transaction: txn,
827 posting_index: 0,
828 balance: None,
829 };
830 self.evaluate_predicate(filter, &dummy_ctx)
831 }
832 }
833 Expr::BinaryOp(op) => {
834 match (&op.left, &op.right) {
836 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "YEAR" => {
837 if let Literal::Integer(n) = lit {
838 let matches = txn.date.year() == *n as i32;
839 Ok(if op.op == BinaryOperator::Eq {
840 matches
841 } else {
842 !matches
843 })
844 } else {
845 Ok(false)
846 }
847 }
848 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "MONTH" => {
849 if let Literal::Integer(n) = lit {
850 let matches = txn.date.month() == *n as u32;
851 Ok(if op.op == BinaryOperator::Eq {
852 matches
853 } else {
854 !matches
855 })
856 } else {
857 Ok(false)
858 }
859 }
860 (Expr::Column(col), Expr::Literal(Literal::Date(d)))
861 if col.to_uppercase() == "DATE" =>
862 {
863 let matches = match op.op {
864 BinaryOperator::Eq => txn.date == *d,
865 BinaryOperator::Ne => txn.date != *d,
866 BinaryOperator::Lt => txn.date < *d,
867 BinaryOperator::Le => txn.date <= *d,
868 BinaryOperator::Gt => txn.date > *d,
869 BinaryOperator::Ge => txn.date >= *d,
870 _ => false,
871 };
872 Ok(matches)
873 }
874 _ => {
875 let dummy_ctx = PostingContext {
877 transaction: txn,
878 posting_index: 0,
879 balance: None,
880 };
881 self.evaluate_predicate(filter, &dummy_ctx)
882 }
883 }
884 }
885 _ => {
886 let dummy_ctx = PostingContext {
888 transaction: txn,
889 posting_index: 0,
890 balance: None,
891 };
892 self.evaluate_predicate(filter, &dummy_ctx)
893 }
894 }
895 }
896
897 fn evaluate_predicate(&self, expr: &Expr, ctx: &PostingContext) -> Result<bool, QueryError> {
899 let value = self.evaluate_expr(expr, ctx)?;
900 match value {
901 Value::Boolean(b) => Ok(b),
902 Value::Null => Ok(false),
903 _ => Err(QueryError::Type("expected boolean expression".to_string())),
904 }
905 }
906
907 fn evaluate_expr(&self, expr: &Expr, ctx: &PostingContext) -> Result<Value, QueryError> {
909 match expr {
910 Expr::Wildcard => Ok(Value::Null), Expr::Column(name) => self.evaluate_column(name, ctx),
912 Expr::Literal(lit) => self.evaluate_literal(lit),
913 Expr::Function(func) => self.evaluate_function(func, ctx),
914 Expr::Window(_) => {
915 Err(QueryError::Evaluation(
918 "Window function cannot be evaluated in posting context".to_string(),
919 ))
920 }
921 Expr::BinaryOp(op) => self.evaluate_binary_op(op, ctx),
922 Expr::UnaryOp(op) => self.evaluate_unary_op(op, ctx),
923 Expr::Paren(inner) => self.evaluate_expr(inner, ctx),
924 }
925 }
926
927 fn evaluate_column(&self, name: &str, ctx: &PostingContext) -> Result<Value, QueryError> {
929 let posting = &ctx.transaction.postings[ctx.posting_index];
930
931 match name {
932 "date" => Ok(Value::Date(ctx.transaction.date)),
933 "account" => Ok(Value::String(posting.account.to_string())),
934 "narration" => Ok(Value::String(ctx.transaction.narration.to_string())),
935 "payee" => Ok(ctx
936 .transaction
937 .payee
938 .as_ref()
939 .map_or(Value::Null, |p| Value::String(p.to_string()))),
940 "flag" => Ok(Value::String(ctx.transaction.flag.to_string())),
941 "tags" => Ok(Value::StringSet(
942 ctx.transaction
943 .tags
944 .iter()
945 .map(ToString::to_string)
946 .collect(),
947 )),
948 "links" => Ok(Value::StringSet(
949 ctx.transaction
950 .links
951 .iter()
952 .map(ToString::to_string)
953 .collect(),
954 )),
955 "position" | "units" => Ok(posting
956 .amount()
957 .map_or(Value::Null, |u| Value::Amount(u.clone()))),
958 "cost" => {
959 if let Some(units) = posting.amount() {
961 if let Some(cost) = &posting.cost {
962 if let Some(number_per) = &cost.number_per {
963 if let Some(currency) = &cost.currency {
964 let total = units.number.abs() * number_per;
965 return Ok(Value::Amount(Amount::new(total, currency.clone())));
966 }
967 }
968 }
969 }
970 Ok(Value::Null)
971 }
972 "weight" => {
973 if let Some(units) = posting.amount() {
977 if let Some(cost) = &posting.cost {
978 if let Some(number_per) = &cost.number_per {
979 if let Some(currency) = &cost.currency {
980 let total = units.number * number_per;
981 return Ok(Value::Amount(Amount::new(total, currency.clone())));
982 }
983 }
984 }
985 Ok(Value::Amount(units.clone()))
987 } else {
988 Ok(Value::Null)
989 }
990 }
991 "balance" => {
992 if let Some(ref balance) = ctx.balance {
994 Ok(Value::Inventory(balance.clone()))
995 } else {
996 Ok(Value::Null)
997 }
998 }
999 "year" => Ok(Value::Integer(ctx.transaction.date.year().into())),
1000 "month" => Ok(Value::Integer(ctx.transaction.date.month().into())),
1001 "day" => Ok(Value::Integer(ctx.transaction.date.day().into())),
1002 _ => Err(QueryError::UnknownColumn(name.to_string())),
1003 }
1004 }
1005
1006 fn evaluate_literal(&self, lit: &Literal) -> Result<Value, QueryError> {
1008 Ok(match lit {
1009 Literal::String(s) => Value::String(s.clone()),
1010 Literal::Number(n) => Value::Number(*n),
1011 Literal::Integer(i) => Value::Integer(*i),
1012 Literal::Date(d) => Value::Date(*d),
1013 Literal::Boolean(b) => Value::Boolean(*b),
1014 Literal::Null => Value::Null,
1015 })
1016 }
1017
1018 fn evaluate_function(
1022 &self,
1023 func: &FunctionCall,
1024 ctx: &PostingContext,
1025 ) -> Result<Value, QueryError> {
1026 let name = func.name.to_uppercase();
1027 match name.as_str() {
1028 "YEAR" | "MONTH" | "DAY" | "WEEKDAY" | "QUARTER" | "YMONTH" | "TODAY" => {
1030 self.eval_date_function(&name, func, ctx)
1031 }
1032 "LENGTH" | "UPPER" | "LOWER" | "SUBSTR" | "SUBSTRING" | "TRIM" | "STARTSWITH"
1034 | "ENDSWITH" => self.eval_string_function(&name, func, ctx),
1035 "PARENT" | "LEAF" | "ROOT" | "ACCOUNT_DEPTH" | "ACCOUNT_SORTKEY" => {
1037 self.eval_account_function(&name, func, ctx)
1038 }
1039 "ABS" | "NEG" | "ROUND" | "SAFEDIV" => self.eval_math_function(&name, func, ctx),
1041 "NUMBER" | "CURRENCY" | "GETITEM" | "GET" | "UNITS" | "COST" | "WEIGHT" | "VALUE" => {
1043 self.eval_position_function(&name, func, ctx)
1044 }
1045 "COALESCE" => self.eval_coalesce(func, ctx),
1047 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1050 _ => Err(QueryError::UnknownFunction(func.name.clone())),
1051 }
1052 }
1053
1054 fn eval_date_function(
1056 &self,
1057 name: &str,
1058 func: &FunctionCall,
1059 ctx: &PostingContext,
1060 ) -> Result<Value, QueryError> {
1061 if name == "TODAY" {
1062 if !func.args.is_empty() {
1063 return Err(QueryError::InvalidArguments(
1064 "TODAY".to_string(),
1065 "expected 0 arguments".to_string(),
1066 ));
1067 }
1068 return Ok(Value::Date(chrono::Local::now().date_naive()));
1069 }
1070
1071 if func.args.len() != 1 {
1073 return Err(QueryError::InvalidArguments(
1074 name.to_string(),
1075 "expected 1 argument".to_string(),
1076 ));
1077 }
1078
1079 let val = self.evaluate_expr(&func.args[0], ctx)?;
1080 let date = match val {
1081 Value::Date(d) => d,
1082 _ => return Err(QueryError::Type(format!("{name} expects a date"))),
1083 };
1084
1085 match name {
1086 "YEAR" => Ok(Value::Integer(date.year().into())),
1087 "MONTH" => Ok(Value::Integer(date.month().into())),
1088 "DAY" => Ok(Value::Integer(date.day().into())),
1089 "WEEKDAY" => Ok(Value::Integer(date.weekday().num_days_from_monday().into())),
1090 "QUARTER" => {
1091 let quarter = (date.month() - 1) / 3 + 1;
1092 Ok(Value::Integer(quarter.into()))
1093 }
1094 "YMONTH" => Ok(Value::String(format!(
1095 "{:04}-{:02}",
1096 date.year(),
1097 date.month()
1098 ))),
1099 _ => unreachable!(),
1100 }
1101 }
1102
1103 fn eval_string_function(
1105 &self,
1106 name: &str,
1107 func: &FunctionCall,
1108 ctx: &PostingContext,
1109 ) -> Result<Value, QueryError> {
1110 match name {
1111 "LENGTH" => {
1112 Self::require_args(name, func, 1)?;
1113 let val = self.evaluate_expr(&func.args[0], ctx)?;
1114 match val {
1115 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1116 Value::StringSet(s) => Ok(Value::Integer(s.len() as i64)),
1117 _ => Err(QueryError::Type(
1118 "LENGTH expects a string or set".to_string(),
1119 )),
1120 }
1121 }
1122 "UPPER" => {
1123 Self::require_args(name, func, 1)?;
1124 let val = self.evaluate_expr(&func.args[0], ctx)?;
1125 match val {
1126 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1127 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1128 }
1129 }
1130 "LOWER" => {
1131 Self::require_args(name, func, 1)?;
1132 let val = self.evaluate_expr(&func.args[0], ctx)?;
1133 match val {
1134 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1135 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1136 }
1137 }
1138 "TRIM" => {
1139 Self::require_args(name, func, 1)?;
1140 let val = self.evaluate_expr(&func.args[0], ctx)?;
1141 match val {
1142 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1143 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1144 }
1145 }
1146 "SUBSTR" | "SUBSTRING" => self.eval_substr(func, ctx),
1147 "STARTSWITH" => {
1148 Self::require_args(name, func, 2)?;
1149 let val = self.evaluate_expr(&func.args[0], ctx)?;
1150 let prefix = self.evaluate_expr(&func.args[1], ctx)?;
1151 match (val, prefix) {
1152 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.starts_with(&p))),
1153 _ => Err(QueryError::Type(
1154 "STARTSWITH expects two strings".to_string(),
1155 )),
1156 }
1157 }
1158 "ENDSWITH" => {
1159 Self::require_args(name, func, 2)?;
1160 let val = self.evaluate_expr(&func.args[0], ctx)?;
1161 let suffix = self.evaluate_expr(&func.args[1], ctx)?;
1162 match (val, suffix) {
1163 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.ends_with(&p))),
1164 _ => Err(QueryError::Type("ENDSWITH expects two strings".to_string())),
1165 }
1166 }
1167 _ => unreachable!(),
1168 }
1169 }
1170
1171 fn eval_substr(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1173 if func.args.len() < 2 || func.args.len() > 3 {
1174 return Err(QueryError::InvalidArguments(
1175 "SUBSTR".to_string(),
1176 "expected 2 or 3 arguments".to_string(),
1177 ));
1178 }
1179
1180 let val = self.evaluate_expr(&func.args[0], ctx)?;
1181 let start = self.evaluate_expr(&func.args[1], ctx)?;
1182 let len = if func.args.len() == 3 {
1183 Some(self.evaluate_expr(&func.args[2], ctx)?)
1184 } else {
1185 None
1186 };
1187
1188 match (val, start, len) {
1189 (Value::String(s), Value::Integer(start), None) => {
1190 let start = start.max(0) as usize;
1191 if start >= s.len() {
1192 Ok(Value::String(String::new()))
1193 } else {
1194 Ok(Value::String(s[start..].to_string()))
1195 }
1196 }
1197 (Value::String(s), Value::Integer(start), Some(Value::Integer(len))) => {
1198 let start = start.max(0) as usize;
1199 let len = len.max(0) as usize;
1200 if start >= s.len() {
1201 Ok(Value::String(String::new()))
1202 } else {
1203 let end = (start + len).min(s.len());
1204 Ok(Value::String(s[start..end].to_string()))
1205 }
1206 }
1207 _ => Err(QueryError::Type(
1208 "SUBSTR expects (string, int, [int])".to_string(),
1209 )),
1210 }
1211 }
1212
1213 fn eval_account_function(
1215 &self,
1216 name: &str,
1217 func: &FunctionCall,
1218 ctx: &PostingContext,
1219 ) -> Result<Value, QueryError> {
1220 match name {
1221 "PARENT" => {
1222 Self::require_args(name, func, 1)?;
1223 let val = self.evaluate_expr(&func.args[0], ctx)?;
1224 match val {
1225 Value::String(s) => {
1226 if let Some(idx) = s.rfind(':') {
1227 Ok(Value::String(s[..idx].to_string()))
1228 } else {
1229 Ok(Value::Null)
1230 }
1231 }
1232 _ => Err(QueryError::Type(
1233 "PARENT expects an account string".to_string(),
1234 )),
1235 }
1236 }
1237 "LEAF" => {
1238 Self::require_args(name, func, 1)?;
1239 let val = self.evaluate_expr(&func.args[0], ctx)?;
1240 match val {
1241 Value::String(s) => {
1242 if let Some(idx) = s.rfind(':') {
1243 Ok(Value::String(s[idx + 1..].to_string()))
1244 } else {
1245 Ok(Value::String(s))
1246 }
1247 }
1248 _ => Err(QueryError::Type(
1249 "LEAF expects an account string".to_string(),
1250 )),
1251 }
1252 }
1253 "ROOT" => self.eval_root(func, ctx),
1254 "ACCOUNT_DEPTH" => {
1255 Self::require_args(name, func, 1)?;
1256 let val = self.evaluate_expr(&func.args[0], ctx)?;
1257 match val {
1258 Value::String(s) => {
1259 let depth = s.chars().filter(|c| *c == ':').count() + 1;
1260 Ok(Value::Integer(depth as i64))
1261 }
1262 _ => Err(QueryError::Type(
1263 "ACCOUNT_DEPTH expects an account string".to_string(),
1264 )),
1265 }
1266 }
1267 "ACCOUNT_SORTKEY" => {
1268 Self::require_args(name, func, 1)?;
1269 let val = self.evaluate_expr(&func.args[0], ctx)?;
1270 match val {
1271 Value::String(s) => Ok(Value::String(s)),
1272 _ => Err(QueryError::Type(
1273 "ACCOUNT_SORTKEY expects an account string".to_string(),
1274 )),
1275 }
1276 }
1277 _ => unreachable!(),
1278 }
1279 }
1280
1281 fn eval_root(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1283 if func.args.is_empty() || func.args.len() > 2 {
1284 return Err(QueryError::InvalidArguments(
1285 "ROOT".to_string(),
1286 "expected 1 or 2 arguments".to_string(),
1287 ));
1288 }
1289
1290 let val = self.evaluate_expr(&func.args[0], ctx)?;
1291 let n = if func.args.len() == 2 {
1292 match self.evaluate_expr(&func.args[1], ctx)? {
1293 Value::Integer(i) => i as usize,
1294 _ => {
1295 return Err(QueryError::Type(
1296 "ROOT second arg must be integer".to_string(),
1297 ));
1298 }
1299 }
1300 } else {
1301 1
1302 };
1303
1304 match val {
1305 Value::String(s) => {
1306 let parts: Vec<&str> = s.split(':').collect();
1307 if n >= parts.len() {
1308 Ok(Value::String(s))
1309 } else {
1310 Ok(Value::String(parts[..n].join(":")))
1311 }
1312 }
1313 _ => Err(QueryError::Type(
1314 "ROOT expects an account string".to_string(),
1315 )),
1316 }
1317 }
1318
1319 fn eval_math_function(
1321 &self,
1322 name: &str,
1323 func: &FunctionCall,
1324 ctx: &PostingContext,
1325 ) -> Result<Value, QueryError> {
1326 match name {
1327 "ABS" => {
1328 Self::require_args(name, func, 1)?;
1329 let val = self.evaluate_expr(&func.args[0], ctx)?;
1330 match val {
1331 Value::Number(n) => Ok(Value::Number(n.abs())),
1332 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1333 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1334 }
1335 }
1336 "NEG" => {
1337 Self::require_args(name, func, 1)?;
1338 let val = self.evaluate_expr(&func.args[0], ctx)?;
1339 match val {
1340 Value::Number(n) => Ok(Value::Number(-n)),
1341 Value::Integer(i) => Ok(Value::Integer(-i)),
1342 _ => Err(QueryError::Type("NEG expects a number".to_string())),
1343 }
1344 }
1345 "ROUND" => self.eval_round(func, ctx),
1346 "SAFEDIV" => self.eval_safediv(func, ctx),
1347 _ => unreachable!(),
1348 }
1349 }
1350
1351 fn eval_round(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1353 if func.args.is_empty() || func.args.len() > 2 {
1354 return Err(QueryError::InvalidArguments(
1355 "ROUND".to_string(),
1356 "expected 1 or 2 arguments".to_string(),
1357 ));
1358 }
1359
1360 let val = self.evaluate_expr(&func.args[0], ctx)?;
1361 let decimals = if func.args.len() == 2 {
1362 match self.evaluate_expr(&func.args[1], ctx)? {
1363 Value::Integer(i) => i as u32,
1364 _ => {
1365 return Err(QueryError::Type(
1366 "ROUND second arg must be integer".to_string(),
1367 ));
1368 }
1369 }
1370 } else {
1371 0
1372 };
1373
1374 match val {
1375 Value::Number(n) => Ok(Value::Number(n.round_dp(decimals))),
1376 Value::Integer(i) => Ok(Value::Integer(i)),
1377 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1378 }
1379 }
1380
1381 fn eval_safediv(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1383 Self::require_args("SAFEDIV", func, 2)?;
1384 let num = self.evaluate_expr(&func.args[0], ctx)?;
1385 let den = self.evaluate_expr(&func.args[1], ctx)?;
1386
1387 match (num, den) {
1388 (Value::Number(n), Value::Number(d)) => {
1389 if d.is_zero() {
1390 Ok(Value::Number(Decimal::ZERO))
1391 } else {
1392 Ok(Value::Number(n / d))
1393 }
1394 }
1395 (Value::Integer(n), Value::Integer(d)) => {
1396 if d == 0 {
1397 Ok(Value::Integer(0))
1398 } else {
1399 Ok(Value::Integer(n / d))
1400 }
1401 }
1402 _ => Err(QueryError::Type("SAFEDIV expects two numbers".to_string())),
1403 }
1404 }
1405
1406 fn eval_position_function(
1408 &self,
1409 name: &str,
1410 func: &FunctionCall,
1411 ctx: &PostingContext,
1412 ) -> Result<Value, QueryError> {
1413 match name {
1414 "NUMBER" => {
1415 Self::require_args(name, func, 1)?;
1416 let val = self.evaluate_expr(&func.args[0], ctx)?;
1417 match val {
1418 Value::Amount(a) => Ok(Value::Number(a.number)),
1419 Value::Position(p) => Ok(Value::Number(p.units.number)),
1420 Value::Number(n) => Ok(Value::Number(n)),
1421 Value::Integer(i) => Ok(Value::Number(Decimal::from(i))),
1422 _ => Err(QueryError::Type(
1423 "NUMBER expects an amount or position".to_string(),
1424 )),
1425 }
1426 }
1427 "CURRENCY" => {
1428 Self::require_args(name, func, 1)?;
1429 let val = self.evaluate_expr(&func.args[0], ctx)?;
1430 match val {
1431 Value::Amount(a) => Ok(Value::String(a.currency.to_string())),
1432 Value::Position(p) => Ok(Value::String(p.units.currency.to_string())),
1433 _ => Err(QueryError::Type(
1434 "CURRENCY expects an amount or position".to_string(),
1435 )),
1436 }
1437 }
1438 "GETITEM" | "GET" => self.eval_getitem(func, ctx),
1439 "UNITS" => self.eval_units(func, ctx),
1440 "COST" => self.eval_cost(func, ctx),
1441 "WEIGHT" => self.eval_weight(func, ctx),
1442 "VALUE" => self.eval_value(func, ctx),
1443 _ => unreachable!(),
1444 }
1445 }
1446
1447 fn eval_getitem(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1449 Self::require_args("GETITEM", func, 2)?;
1450 let val = self.evaluate_expr(&func.args[0], ctx)?;
1451 let key = self.evaluate_expr(&func.args[1], ctx)?;
1452
1453 match (val, key) {
1454 (Value::Inventory(inv), Value::String(currency)) => {
1455 let total = inv.units(¤cy);
1456 if total.is_zero() {
1457 Ok(Value::Null)
1458 } else {
1459 Ok(Value::Amount(Amount::new(total, currency)))
1460 }
1461 }
1462 _ => Err(QueryError::Type(
1463 "GETITEM expects (inventory, string)".to_string(),
1464 )),
1465 }
1466 }
1467
1468 fn eval_units(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1470 Self::require_args("UNITS", func, 1)?;
1471 let val = self.evaluate_expr(&func.args[0], ctx)?;
1472
1473 match val {
1474 Value::Position(p) => Ok(Value::Amount(p.units)),
1475 Value::Amount(a) => Ok(Value::Amount(a)),
1476 Value::Inventory(inv) => {
1477 let positions: Vec<String> = inv
1478 .positions()
1479 .iter()
1480 .map(|p| format!("{} {}", p.units.number, p.units.currency))
1481 .collect();
1482 Ok(Value::String(positions.join(", ")))
1483 }
1484 _ => Err(QueryError::Type(
1485 "UNITS expects a position or inventory".to_string(),
1486 )),
1487 }
1488 }
1489
1490 fn eval_cost(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1492 Self::require_args("COST", func, 1)?;
1493 let val = self.evaluate_expr(&func.args[0], ctx)?;
1494
1495 match val {
1496 Value::Position(p) => {
1497 if let Some(cost) = &p.cost {
1498 let total = p.units.number.abs() * cost.number;
1499 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1500 } else {
1501 Ok(Value::Null)
1502 }
1503 }
1504 Value::Amount(a) => Ok(Value::Amount(a)),
1505 Value::Inventory(inv) => {
1506 let mut total = Decimal::ZERO;
1507 let mut currency: Option<InternedStr> = None;
1508 for pos in inv.positions() {
1509 if let Some(cost) = &pos.cost {
1510 total += pos.units.number.abs() * cost.number;
1511 if currency.is_none() {
1512 currency = Some(cost.currency.clone());
1513 }
1514 }
1515 }
1516 if let Some(curr) = currency {
1517 Ok(Value::Amount(Amount::new(total, curr)))
1518 } else {
1519 Ok(Value::Null)
1520 }
1521 }
1522 _ => Err(QueryError::Type(
1523 "COST expects a position or inventory".to_string(),
1524 )),
1525 }
1526 }
1527
1528 fn eval_weight(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1530 Self::require_args("WEIGHT", func, 1)?;
1531 let val = self.evaluate_expr(&func.args[0], ctx)?;
1532
1533 match val {
1534 Value::Position(p) => {
1535 if let Some(cost) = &p.cost {
1536 let total = p.units.number * cost.number;
1537 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1538 } else {
1539 Ok(Value::Amount(p.units))
1540 }
1541 }
1542 Value::Amount(a) => Ok(Value::Amount(a)),
1543 _ => Err(QueryError::Type(
1544 "WEIGHT expects a position or amount".to_string(),
1545 )),
1546 }
1547 }
1548
1549 fn eval_value(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1551 if func.args.is_empty() || func.args.len() > 2 {
1552 return Err(QueryError::InvalidArguments(
1553 "VALUE".to_string(),
1554 "expected 1-2 arguments".to_string(),
1555 ));
1556 }
1557
1558 let target_currency = if func.args.len() == 2 {
1559 match self.evaluate_expr(&func.args[1], ctx)? {
1560 Value::String(s) => s,
1561 _ => {
1562 return Err(QueryError::Type(
1563 "VALUE second argument must be a currency string".to_string(),
1564 ));
1565 }
1566 }
1567 } else {
1568 self.target_currency.clone().ok_or_else(|| {
1569 QueryError::InvalidArguments(
1570 "VALUE".to_string(),
1571 "no target currency set; either call set_target_currency() on the executor \
1572 or pass the currency as VALUE(amount, 'USD')"
1573 .to_string(),
1574 )
1575 })?
1576 };
1577
1578 let val = self.evaluate_expr(&func.args[0], ctx)?;
1579 let date = ctx.transaction.date;
1580
1581 match val {
1582 Value::Position(p) => {
1583 if p.units.currency == target_currency {
1584 Ok(Value::Amount(p.units))
1585 } else if let Some(converted) =
1586 self.price_db.convert(&p.units, &target_currency, date)
1587 {
1588 Ok(Value::Amount(converted))
1589 } else {
1590 Ok(Value::Amount(p.units))
1591 }
1592 }
1593 Value::Amount(a) => {
1594 if a.currency == target_currency {
1595 Ok(Value::Amount(a))
1596 } else if let Some(converted) = self.price_db.convert(&a, &target_currency, date) {
1597 Ok(Value::Amount(converted))
1598 } else {
1599 Ok(Value::Amount(a))
1600 }
1601 }
1602 Value::Inventory(inv) => {
1603 let mut total = Decimal::ZERO;
1604 for pos in inv.positions() {
1605 if pos.units.currency == target_currency {
1606 total += pos.units.number;
1607 } else if let Some(converted) =
1608 self.price_db.convert(&pos.units, &target_currency, date)
1609 {
1610 total += converted.number;
1611 }
1612 }
1613 Ok(Value::Amount(Amount::new(total, &target_currency)))
1614 }
1615 _ => Err(QueryError::Type(
1616 "VALUE expects a position or inventory".to_string(),
1617 )),
1618 }
1619 }
1620
1621 fn eval_coalesce(
1623 &self,
1624 func: &FunctionCall,
1625 ctx: &PostingContext,
1626 ) -> Result<Value, QueryError> {
1627 for arg in &func.args {
1628 let val = self.evaluate_expr(arg, ctx)?;
1629 if !matches!(val, Value::Null) {
1630 return Ok(val);
1631 }
1632 }
1633 Ok(Value::Null)
1634 }
1635
1636 fn evaluate_function_on_values(&self, name: &str, args: &[Value]) -> Result<Value, QueryError> {
1638 let name_upper = name.to_uppercase();
1639 match name_upper.as_str() {
1640 "TODAY" => Ok(Value::Date(chrono::Local::now().date_naive())),
1642 "YEAR" => {
1643 Self::require_args_count(&name_upper, args, 1)?;
1644 match &args[0] {
1645 Value::Date(d) => Ok(Value::Integer(d.year().into())),
1646 _ => Err(QueryError::Type("YEAR expects a date".to_string())),
1647 }
1648 }
1649 "MONTH" => {
1650 Self::require_args_count(&name_upper, args, 1)?;
1651 match &args[0] {
1652 Value::Date(d) => Ok(Value::Integer(d.month().into())),
1653 _ => Err(QueryError::Type("MONTH expects a date".to_string())),
1654 }
1655 }
1656 "DAY" => {
1657 Self::require_args_count(&name_upper, args, 1)?;
1658 match &args[0] {
1659 Value::Date(d) => Ok(Value::Integer(d.day().into())),
1660 _ => Err(QueryError::Type("DAY expects a date".to_string())),
1661 }
1662 }
1663 "LENGTH" => {
1665 Self::require_args_count(&name_upper, args, 1)?;
1666 match &args[0] {
1667 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1668 _ => Err(QueryError::Type("LENGTH expects a string".to_string())),
1669 }
1670 }
1671 "UPPER" => {
1672 Self::require_args_count(&name_upper, args, 1)?;
1673 match &args[0] {
1674 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1675 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1676 }
1677 }
1678 "LOWER" => {
1679 Self::require_args_count(&name_upper, args, 1)?;
1680 match &args[0] {
1681 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1682 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1683 }
1684 }
1685 "TRIM" => {
1686 Self::require_args_count(&name_upper, args, 1)?;
1687 match &args[0] {
1688 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1689 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1690 }
1691 }
1692 "ABS" => {
1694 Self::require_args_count(&name_upper, args, 1)?;
1695 match &args[0] {
1696 Value::Number(n) => Ok(Value::Number(n.abs())),
1697 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1698 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1699 }
1700 }
1701 "ROUND" => {
1702 if args.is_empty() || args.len() > 2 {
1703 return Err(QueryError::InvalidArguments(
1704 "ROUND".to_string(),
1705 "expected 1 or 2 arguments".to_string(),
1706 ));
1707 }
1708 match &args[0] {
1709 Value::Number(n) => {
1710 let scale = if args.len() == 2 {
1711 match &args[1] {
1712 Value::Integer(i) => *i as u32,
1713 _ => 0,
1714 }
1715 } else {
1716 0
1717 };
1718 Ok(Value::Number(n.round_dp(scale)))
1719 }
1720 Value::Integer(i) => Ok(Value::Integer(*i)),
1721 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1722 }
1723 }
1724 "COALESCE" => {
1726 for arg in args {
1727 if !matches!(arg, Value::Null) {
1728 return Ok(arg.clone());
1729 }
1730 }
1731 Ok(Value::Null)
1732 }
1733 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1735 _ => Err(QueryError::UnknownFunction(name.to_string())),
1736 }
1737 }
1738
1739 fn require_args_count(name: &str, args: &[Value], expected: usize) -> Result<(), QueryError> {
1741 if args.len() != expected {
1742 return Err(QueryError::InvalidArguments(
1743 name.to_string(),
1744 format!("expected {} argument(s), got {}", expected, args.len()),
1745 ));
1746 }
1747 Ok(())
1748 }
1749
1750 fn require_args(name: &str, func: &FunctionCall, expected: usize) -> Result<(), QueryError> {
1752 if func.args.len() != expected {
1753 return Err(QueryError::InvalidArguments(
1754 name.to_string(),
1755 format!("expected {expected} argument(s)"),
1756 ));
1757 }
1758 Ok(())
1759 }
1760
1761 fn evaluate_binary_op(&self, op: &BinaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1763 let left = self.evaluate_expr(&op.left, ctx)?;
1764 let right = self.evaluate_expr(&op.right, ctx)?;
1765
1766 match op.op {
1767 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(&left, &right))),
1768 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(&left, &right))),
1769 BinaryOperator::Lt => self.compare_values(&left, &right, std::cmp::Ordering::is_lt),
1770 BinaryOperator::Le => self.compare_values(&left, &right, std::cmp::Ordering::is_le),
1771 BinaryOperator::Gt => self.compare_values(&left, &right, std::cmp::Ordering::is_gt),
1772 BinaryOperator::Ge => self.compare_values(&left, &right, std::cmp::Ordering::is_ge),
1773 BinaryOperator::And => {
1774 let l = self.to_bool(&left)?;
1775 let r = self.to_bool(&right)?;
1776 Ok(Value::Boolean(l && r))
1777 }
1778 BinaryOperator::Or => {
1779 let l = self.to_bool(&left)?;
1780 let r = self.to_bool(&right)?;
1781 Ok(Value::Boolean(l || r))
1782 }
1783 BinaryOperator::Regex => {
1784 let s = match left {
1786 Value::String(s) => s,
1787 _ => {
1788 return Err(QueryError::Type(
1789 "regex requires string left operand".to_string(),
1790 ));
1791 }
1792 };
1793 let pattern = match right {
1794 Value::String(p) => p,
1795 _ => {
1796 return Err(QueryError::Type(
1797 "regex requires string pattern".to_string(),
1798 ));
1799 }
1800 };
1801 Ok(Value::Boolean(s.contains(&pattern)))
1803 }
1804 BinaryOperator::In => {
1805 match right {
1807 Value::StringSet(set) => {
1808 let needle = match left {
1809 Value::String(s) => s,
1810 _ => {
1811 return Err(QueryError::Type(
1812 "IN requires string left operand".to_string(),
1813 ));
1814 }
1815 };
1816 Ok(Value::Boolean(set.contains(&needle)))
1817 }
1818 _ => Err(QueryError::Type(
1819 "IN requires set right operand".to_string(),
1820 )),
1821 }
1822 }
1823 BinaryOperator::Add => self.arithmetic_op(&left, &right, |a, b| a + b),
1824 BinaryOperator::Sub => self.arithmetic_op(&left, &right, |a, b| a - b),
1825 BinaryOperator::Mul => self.arithmetic_op(&left, &right, |a, b| a * b),
1826 BinaryOperator::Div => self.arithmetic_op(&left, &right, |a, b| a / b),
1827 }
1828 }
1829
1830 fn evaluate_unary_op(&self, op: &UnaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1832 let val = self.evaluate_expr(&op.operand, ctx)?;
1833 self.unary_op_on_value(op.op, &val)
1834 }
1835
1836 fn unary_op_on_value(&self, op: UnaryOperator, val: &Value) -> Result<Value, QueryError> {
1838 match op {
1839 UnaryOperator::Not => {
1840 let b = self.to_bool(val)?;
1841 Ok(Value::Boolean(!b))
1842 }
1843 UnaryOperator::Neg => match val {
1844 Value::Number(n) => Ok(Value::Number(-*n)),
1845 Value::Integer(i) => Ok(Value::Integer(-*i)),
1846 _ => Err(QueryError::Type(
1847 "negation requires numeric value".to_string(),
1848 )),
1849 },
1850 }
1851 }
1852
1853 fn values_equal(&self, left: &Value, right: &Value) -> bool {
1855 match (left, right) {
1857 (Value::Null, Value::Null) => true,
1858 (Value::String(a), Value::String(b)) => a == b,
1859 (Value::Number(a), Value::Number(b)) => a == b,
1860 (Value::Integer(a), Value::Integer(b)) => a == b,
1861 (Value::Number(a), Value::Integer(b)) => *a == Decimal::from(*b),
1862 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a) == *b,
1863 (Value::Date(a), Value::Date(b)) => a == b,
1864 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1865 _ => false,
1866 }
1867 }
1868
1869 fn compare_values<F>(&self, left: &Value, right: &Value, pred: F) -> Result<Value, QueryError>
1871 where
1872 F: FnOnce(std::cmp::Ordering) -> bool,
1873 {
1874 let ord = match (left, right) {
1875 (Value::Number(a), Value::Number(b)) => a.cmp(b),
1876 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
1877 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
1878 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
1879 (Value::String(a), Value::String(b)) => a.cmp(b),
1880 (Value::Date(a), Value::Date(b)) => a.cmp(b),
1881 _ => return Err(QueryError::Type("cannot compare values".to_string())),
1882 };
1883 Ok(Value::Boolean(pred(ord)))
1884 }
1885
1886 fn value_less_than(&self, left: &Value, right: &Value) -> Result<bool, QueryError> {
1888 let ord = match (left, right) {
1889 (Value::Number(a), Value::Number(b)) => a.cmp(b),
1890 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
1891 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
1892 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
1893 (Value::String(a), Value::String(b)) => a.cmp(b),
1894 (Value::Date(a), Value::Date(b)) => a.cmp(b),
1895 _ => return Err(QueryError::Type("cannot compare values".to_string())),
1896 };
1897 Ok(ord.is_lt())
1898 }
1899
1900 fn arithmetic_op<F>(&self, left: &Value, right: &Value, op: F) -> Result<Value, QueryError>
1902 where
1903 F: FnOnce(Decimal, Decimal) -> Decimal,
1904 {
1905 let (a, b) = match (left, right) {
1906 (Value::Number(a), Value::Number(b)) => (*a, *b),
1907 (Value::Integer(a), Value::Integer(b)) => (Decimal::from(*a), Decimal::from(*b)),
1908 (Value::Number(a), Value::Integer(b)) => (*a, Decimal::from(*b)),
1909 (Value::Integer(a), Value::Number(b)) => (Decimal::from(*a), *b),
1910 _ => {
1911 return Err(QueryError::Type(
1912 "arithmetic requires numeric values".to_string(),
1913 ));
1914 }
1915 };
1916 Ok(Value::Number(op(a, b)))
1917 }
1918
1919 fn to_bool(&self, val: &Value) -> Result<bool, QueryError> {
1921 match val {
1922 Value::Boolean(b) => Ok(*b),
1923 Value::Null => Ok(false),
1924 _ => Err(QueryError::Type("expected boolean".to_string())),
1925 }
1926 }
1927
1928 fn is_aggregate_expr(expr: &Expr) -> bool {
1930 match expr {
1931 Expr::Function(func) => {
1932 matches!(
1933 func.name.to_uppercase().as_str(),
1934 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG"
1935 )
1936 }
1937 Expr::BinaryOp(op) => {
1938 Self::is_aggregate_expr(&op.left) || Self::is_aggregate_expr(&op.right)
1939 }
1940 Expr::UnaryOp(op) => Self::is_aggregate_expr(&op.operand),
1941 Expr::Paren(inner) => Self::is_aggregate_expr(inner),
1942 _ => false,
1943 }
1944 }
1945
1946 const fn is_window_expr(expr: &Expr) -> bool {
1948 matches!(expr, Expr::Window(_))
1949 }
1950
1951 fn has_window_functions(targets: &[Target]) -> bool {
1953 targets.iter().any(|t| Self::is_window_expr(&t.expr))
1954 }
1955
1956 fn resolve_column_names(&self, targets: &[Target]) -> Result<Vec<String>, QueryError> {
1958 let mut names = Vec::new();
1959 for (i, target) in targets.iter().enumerate() {
1960 if let Some(alias) = &target.alias {
1961 names.push(alias.clone());
1962 } else {
1963 names.push(self.expr_to_name(&target.expr, i));
1964 }
1965 }
1966 Ok(names)
1967 }
1968
1969 fn expr_to_name(&self, expr: &Expr, index: usize) -> String {
1971 match expr {
1972 Expr::Wildcard => "*".to_string(),
1973 Expr::Column(name) => name.clone(),
1974 Expr::Function(func) => func.name.clone(),
1975 Expr::Window(wf) => wf.name.clone(),
1976 _ => format!("col{index}"),
1977 }
1978 }
1979
1980 fn evaluate_row(&self, targets: &[Target], ctx: &PostingContext) -> Result<Row, QueryError> {
1982 self.evaluate_row_with_window(targets, ctx, None)
1983 }
1984
1985 fn evaluate_row_with_window(
1987 &self,
1988 targets: &[Target],
1989 ctx: &PostingContext,
1990 window_ctx: Option<&WindowContext>,
1991 ) -> Result<Row, QueryError> {
1992 let mut row = Vec::new();
1993 for target in targets {
1994 if matches!(target.expr, Expr::Wildcard) {
1995 row.push(Value::Date(ctx.transaction.date));
1997 row.push(Value::String(ctx.transaction.flag.to_string()));
1998 row.push(
1999 ctx.transaction
2000 .payee
2001 .as_ref()
2002 .map_or(Value::Null, |p| Value::String(p.to_string())),
2003 );
2004 row.push(Value::String(ctx.transaction.narration.to_string()));
2005 let posting = &ctx.transaction.postings[ctx.posting_index];
2006 row.push(Value::String(posting.account.to_string()));
2007 row.push(
2008 posting
2009 .amount()
2010 .map_or(Value::Null, |u| Value::Amount(u.clone())),
2011 );
2012 } else if let Expr::Window(wf) = &target.expr {
2013 row.push(self.evaluate_window_function(wf, window_ctx)?);
2015 } else {
2016 row.push(self.evaluate_expr(&target.expr, ctx)?);
2017 }
2018 }
2019 Ok(row)
2020 }
2021
2022 fn evaluate_window_function(
2024 &self,
2025 wf: &WindowFunction,
2026 window_ctx: Option<&WindowContext>,
2027 ) -> Result<Value, QueryError> {
2028 let ctx = window_ctx.ok_or_else(|| {
2029 QueryError::Evaluation("Window function requires window context".to_string())
2030 })?;
2031
2032 match wf.name.to_uppercase().as_str() {
2033 "ROW_NUMBER" => Ok(Value::Integer(ctx.row_number as i64)),
2034 "RANK" => Ok(Value::Integer(ctx.rank as i64)),
2035 "DENSE_RANK" => Ok(Value::Integer(ctx.dense_rank as i64)),
2036 _ => Err(QueryError::Evaluation(format!(
2037 "Window function '{}' not yet implemented",
2038 wf.name
2039 ))),
2040 }
2041 }
2042
2043 fn compute_window_contexts(
2045 &self,
2046 postings: &[PostingContext],
2047 wf: &WindowFunction,
2048 ) -> Result<Vec<WindowContext>, QueryError> {
2049 let spec = &wf.over;
2050
2051 let mut partition_keys: Vec<String> = Vec::with_capacity(postings.len());
2053 for ctx in postings {
2054 if let Some(partition_exprs) = &spec.partition_by {
2055 let mut key_values = Vec::new();
2056 for expr in partition_exprs {
2057 key_values.push(self.evaluate_expr(expr, ctx)?);
2058 }
2059 partition_keys.push(Self::make_group_key(&key_values));
2060 } else {
2061 partition_keys.push(String::new());
2063 }
2064 }
2065
2066 let mut partitions: HashMap<String, Vec<usize>> = HashMap::new();
2068 for (idx, key) in partition_keys.iter().enumerate() {
2069 partitions.entry(key.clone()).or_default().push(idx);
2070 }
2071
2072 let mut order_values: Vec<Vec<Value>> = Vec::with_capacity(postings.len());
2074 for ctx in postings {
2075 if let Some(order_specs) = &spec.order_by {
2076 let mut values = Vec::new();
2077 for order_spec in order_specs {
2078 values.push(self.evaluate_expr(&order_spec.expr, ctx)?);
2079 }
2080 order_values.push(values);
2081 } else {
2082 order_values.push(Vec::new());
2083 }
2084 }
2085
2086 let mut window_contexts: Vec<WindowContext> = vec![
2088 WindowContext {
2089 row_number: 0,
2090 rank: 0,
2091 dense_rank: 0,
2092 };
2093 postings.len()
2094 ];
2095
2096 for indices in partitions.values() {
2098 let mut sorted_indices: Vec<usize> = indices.clone();
2100 if spec.order_by.is_some() {
2101 let order_specs = spec.order_by.as_ref().unwrap();
2102 sorted_indices.sort_by(|&a, &b| {
2103 let vals_a = &order_values[a];
2104 let vals_b = &order_values[b];
2105 for (i, (va, vb)) in vals_a.iter().zip(vals_b.iter()).enumerate() {
2106 let cmp = self.compare_values_for_sort(va, vb);
2107 if cmp != std::cmp::Ordering::Equal {
2108 return if order_specs
2109 .get(i)
2110 .is_some_and(|s| s.direction == SortDirection::Desc)
2111 {
2112 cmp.reverse()
2113 } else {
2114 cmp
2115 };
2116 }
2117 }
2118 std::cmp::Ordering::Equal
2119 });
2120 }
2121
2122 let mut row_num = 1;
2124 let mut rank = 1;
2125 let mut dense_rank = 1;
2126 let mut prev_values: Option<&Vec<Value>> = None;
2127
2128 for (position, &original_idx) in sorted_indices.iter().enumerate() {
2129 let current_values = &order_values[original_idx];
2130
2131 let is_tie = if let Some(prev) = prev_values {
2133 current_values == prev
2134 } else {
2135 false
2136 };
2137
2138 if !is_tie && position > 0 {
2139 rank = position + 1;
2141 dense_rank += 1;
2142 }
2143 window_contexts[original_idx] = WindowContext {
2144 row_number: row_num,
2145 rank,
2146 dense_rank,
2147 };
2148
2149 row_num += 1;
2150 prev_values = Some(current_values);
2151 }
2152 }
2153
2154 Ok(window_contexts)
2155 }
2156
2157 fn find_window_function(targets: &[Target]) -> Option<&WindowFunction> {
2159 for target in targets {
2160 if let Expr::Window(wf) = &target.expr {
2161 return Some(wf);
2162 }
2163 }
2164 None
2165 }
2166
2167 fn make_group_key(values: &[Value]) -> String {
2170 use std::fmt::Write;
2171 let mut key = String::new();
2172 for (i, v) in values.iter().enumerate() {
2173 if i > 0 {
2174 key.push('\x00'); }
2176 match v {
2177 Value::String(s) => {
2178 key.push('S');
2179 key.push_str(s);
2180 }
2181 Value::Number(n) => {
2182 key.push('N');
2183 let _ = write!(key, "{n}");
2184 }
2185 Value::Integer(n) => {
2186 key.push('I');
2187 let _ = write!(key, "{n}");
2188 }
2189 Value::Date(d) => {
2190 key.push('D');
2191 let _ = write!(key, "{d}");
2192 }
2193 Value::Boolean(b) => {
2194 key.push(if *b { 'T' } else { 'F' });
2195 }
2196 Value::Amount(a) => {
2197 key.push('A');
2198 let _ = write!(key, "{} {}", a.number, a.currency);
2199 }
2200 Value::Position(p) => {
2201 key.push('P');
2202 let _ = write!(key, "{} {}", p.units.number, p.units.currency);
2203 }
2204 Value::Inventory(_) => {
2205 key.push('V');
2208 }
2209 Value::StringSet(ss) => {
2210 key.push('Z');
2211 for s in ss {
2212 key.push_str(s);
2213 key.push(',');
2214 }
2215 }
2216 Value::Null => {
2217 key.push('0');
2218 }
2219 }
2220 }
2221 key
2222 }
2223
2224 fn group_postings<'b>(
2227 &self,
2228 postings: &'b [PostingContext<'a>],
2229 group_by: Option<&Vec<Expr>>,
2230 ) -> Result<Vec<(Vec<Value>, Vec<&'b PostingContext<'a>>)>, QueryError> {
2231 if let Some(group_exprs) = group_by {
2232 let mut group_map: HashMap<String, (Vec<Value>, Vec<&PostingContext<'a>>)> =
2234 HashMap::new();
2235
2236 for ctx in postings {
2237 let mut key_values = Vec::with_capacity(group_exprs.len());
2238 for expr in group_exprs {
2239 key_values.push(self.evaluate_expr(expr, ctx)?);
2240 }
2241 let hash_key = Self::make_group_key(&key_values);
2242
2243 group_map
2244 .entry(hash_key)
2245 .or_insert_with(|| (key_values, Vec::new()))
2246 .1
2247 .push(ctx);
2248 }
2249
2250 Ok(group_map.into_values().collect())
2251 } else {
2252 Ok(vec![(Vec::new(), postings.iter().collect())])
2254 }
2255 }
2256
2257 fn evaluate_aggregate_row(
2259 &self,
2260 targets: &[Target],
2261 group: &[&PostingContext],
2262 ) -> Result<Row, QueryError> {
2263 let mut row = Vec::new();
2264 for target in targets {
2265 row.push(self.evaluate_aggregate_expr(&target.expr, group)?);
2266 }
2267 Ok(row)
2268 }
2269
2270 fn evaluate_aggregate_expr(
2272 &self,
2273 expr: &Expr,
2274 group: &[&PostingContext],
2275 ) -> Result<Value, QueryError> {
2276 match expr {
2277 Expr::Function(func) => {
2278 match func.name.to_uppercase().as_str() {
2279 "COUNT" => {
2280 Ok(Value::Integer(group.len() as i64))
2282 }
2283 "SUM" => {
2284 if func.args.len() != 1 {
2285 return Err(QueryError::InvalidArguments(
2286 "SUM".to_string(),
2287 "expected 1 argument".to_string(),
2288 ));
2289 }
2290 let mut total = Inventory::new();
2291 for ctx in group {
2292 let val = self.evaluate_expr(&func.args[0], ctx)?;
2293 match val {
2294 Value::Amount(amt) => {
2295 let pos = Position::simple(amt);
2296 total.add(pos);
2297 }
2298 Value::Position(pos) => {
2299 total.add(pos);
2300 }
2301 Value::Number(n) => {
2302 let pos =
2304 Position::simple(Amount::new(n, "__NUMBER__".to_string()));
2305 total.add(pos);
2306 }
2307 Value::Null => {}
2308 _ => {
2309 return Err(QueryError::Type(
2310 "SUM requires numeric or position value".to_string(),
2311 ));
2312 }
2313 }
2314 }
2315 Ok(Value::Inventory(total))
2316 }
2317 "FIRST" => {
2318 if func.args.len() != 1 {
2319 return Err(QueryError::InvalidArguments(
2320 "FIRST".to_string(),
2321 "expected 1 argument".to_string(),
2322 ));
2323 }
2324 if let Some(ctx) = group.first() {
2325 self.evaluate_expr(&func.args[0], ctx)
2326 } else {
2327 Ok(Value::Null)
2328 }
2329 }
2330 "LAST" => {
2331 if func.args.len() != 1 {
2332 return Err(QueryError::InvalidArguments(
2333 "LAST".to_string(),
2334 "expected 1 argument".to_string(),
2335 ));
2336 }
2337 if let Some(ctx) = group.last() {
2338 self.evaluate_expr(&func.args[0], ctx)
2339 } else {
2340 Ok(Value::Null)
2341 }
2342 }
2343 "MIN" => {
2344 if func.args.len() != 1 {
2345 return Err(QueryError::InvalidArguments(
2346 "MIN".to_string(),
2347 "expected 1 argument".to_string(),
2348 ));
2349 }
2350 let mut min_val: Option<Value> = None;
2351 for ctx in group {
2352 let val = self.evaluate_expr(&func.args[0], ctx)?;
2353 if matches!(val, Value::Null) {
2354 continue;
2355 }
2356 min_val = Some(match min_val {
2357 None => val,
2358 Some(current) => {
2359 if self.value_less_than(&val, ¤t)? {
2360 val
2361 } else {
2362 current
2363 }
2364 }
2365 });
2366 }
2367 Ok(min_val.unwrap_or(Value::Null))
2368 }
2369 "MAX" => {
2370 if func.args.len() != 1 {
2371 return Err(QueryError::InvalidArguments(
2372 "MAX".to_string(),
2373 "expected 1 argument".to_string(),
2374 ));
2375 }
2376 let mut max_val: Option<Value> = None;
2377 for ctx in group {
2378 let val = self.evaluate_expr(&func.args[0], ctx)?;
2379 if matches!(val, Value::Null) {
2380 continue;
2381 }
2382 max_val = Some(match max_val {
2383 None => val,
2384 Some(current) => {
2385 if self.value_less_than(¤t, &val)? {
2386 val
2387 } else {
2388 current
2389 }
2390 }
2391 });
2392 }
2393 Ok(max_val.unwrap_or(Value::Null))
2394 }
2395 "AVG" => {
2396 if func.args.len() != 1 {
2397 return Err(QueryError::InvalidArguments(
2398 "AVG".to_string(),
2399 "expected 1 argument".to_string(),
2400 ));
2401 }
2402 let mut sum = Decimal::ZERO;
2403 let mut count = 0i64;
2404 for ctx in group {
2405 let val = self.evaluate_expr(&func.args[0], ctx)?;
2406 match val {
2407 Value::Number(n) => {
2408 sum += n;
2409 count += 1;
2410 }
2411 Value::Integer(i) => {
2412 sum += Decimal::from(i);
2413 count += 1;
2414 }
2415 Value::Null => {}
2416 _ => {
2417 return Err(QueryError::Type(
2418 "AVG expects numeric values".to_string(),
2419 ));
2420 }
2421 }
2422 }
2423 if count == 0 {
2424 Ok(Value::Null)
2425 } else {
2426 Ok(Value::Number(sum / Decimal::from(count)))
2427 }
2428 }
2429 _ => {
2430 if let Some(ctx) = group.first() {
2432 self.evaluate_function(func, ctx)
2433 } else {
2434 Ok(Value::Null)
2435 }
2436 }
2437 }
2438 }
2439 Expr::Column(_) => {
2440 if let Some(ctx) = group.first() {
2442 self.evaluate_expr(expr, ctx)
2443 } else {
2444 Ok(Value::Null)
2445 }
2446 }
2447 Expr::BinaryOp(op) => {
2448 let left = self.evaluate_aggregate_expr(&op.left, group)?;
2449 let right = self.evaluate_aggregate_expr(&op.right, group)?;
2450 self.binary_op_on_values(op.op, &left, &right)
2452 }
2453 _ => {
2454 if let Some(ctx) = group.first() {
2456 self.evaluate_expr(expr, ctx)
2457 } else {
2458 Ok(Value::Null)
2459 }
2460 }
2461 }
2462 }
2463
2464 fn binary_op_on_values(
2466 &self,
2467 op: BinaryOperator,
2468 left: &Value,
2469 right: &Value,
2470 ) -> Result<Value, QueryError> {
2471 match op {
2472 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(left, right))),
2473 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(left, right))),
2474 BinaryOperator::Lt => self.compare_values(left, right, std::cmp::Ordering::is_lt),
2475 BinaryOperator::Le => self.compare_values(left, right, std::cmp::Ordering::is_le),
2476 BinaryOperator::Gt => self.compare_values(left, right, std::cmp::Ordering::is_gt),
2477 BinaryOperator::Ge => self.compare_values(left, right, std::cmp::Ordering::is_ge),
2478 BinaryOperator::And => {
2479 let l = self.to_bool(left)?;
2480 let r = self.to_bool(right)?;
2481 Ok(Value::Boolean(l && r))
2482 }
2483 BinaryOperator::Or => {
2484 let l = self.to_bool(left)?;
2485 let r = self.to_bool(right)?;
2486 Ok(Value::Boolean(l || r))
2487 }
2488 BinaryOperator::Regex => {
2489 let s = match left {
2491 Value::String(s) => s,
2492 _ => {
2493 return Err(QueryError::Type(
2494 "regex requires string left operand".to_string(),
2495 ));
2496 }
2497 };
2498 let pattern = match right {
2499 Value::String(p) => p,
2500 _ => {
2501 return Err(QueryError::Type(
2502 "regex requires string pattern".to_string(),
2503 ));
2504 }
2505 };
2506 let regex_result = self.get_or_compile_regex(pattern);
2508 let matches = if let Some(regex) = regex_result {
2509 regex.is_match(s)
2510 } else {
2511 s.contains(pattern)
2512 };
2513 Ok(Value::Boolean(matches))
2514 }
2515 BinaryOperator::In => {
2516 match right {
2518 Value::StringSet(set) => {
2519 let needle = match left {
2520 Value::String(s) => s,
2521 _ => {
2522 return Err(QueryError::Type(
2523 "IN requires string left operand".to_string(),
2524 ));
2525 }
2526 };
2527 Ok(Value::Boolean(set.contains(needle)))
2528 }
2529 _ => Err(QueryError::Type(
2530 "IN requires set right operand".to_string(),
2531 )),
2532 }
2533 }
2534 BinaryOperator::Add => self.arithmetic_op(left, right, |a, b| a + b),
2535 BinaryOperator::Sub => self.arithmetic_op(left, right, |a, b| a - b),
2536 BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
2537 BinaryOperator::Div => self.arithmetic_op(left, right, |a, b| a / b),
2538 }
2539 }
2540
2541 fn sort_results(
2543 &self,
2544 result: &mut QueryResult,
2545 order_by: &[OrderSpec],
2546 ) -> Result<(), QueryError> {
2547 if order_by.is_empty() {
2548 return Ok(());
2549 }
2550
2551 let column_indices: std::collections::HashMap<&str, usize> = result
2553 .columns
2554 .iter()
2555 .enumerate()
2556 .map(|(i, name)| (name.as_str(), i))
2557 .collect();
2558
2559 let mut sort_specs: Vec<(usize, bool)> = Vec::new();
2561 for spec in order_by {
2562 let idx = match &spec.expr {
2564 Expr::Column(name) => column_indices
2565 .get(name.as_str())
2566 .copied()
2567 .ok_or_else(|| QueryError::UnknownColumn(name.clone()))?,
2568 Expr::Function(func) => {
2569 column_indices
2571 .get(func.name.as_str())
2572 .copied()
2573 .ok_or_else(|| {
2574 QueryError::Evaluation(format!(
2575 "ORDER BY expression not found in SELECT: {}",
2576 func.name
2577 ))
2578 })?
2579 }
2580 _ => {
2581 return Err(QueryError::Evaluation(
2582 "ORDER BY expression must reference a selected column".to_string(),
2583 ));
2584 }
2585 };
2586 let ascending = spec.direction != SortDirection::Desc;
2587 sort_specs.push((idx, ascending));
2588 }
2589
2590 result.rows.sort_by(|a, b| {
2592 for (idx, ascending) in &sort_specs {
2593 if *idx >= a.len() || *idx >= b.len() {
2594 continue;
2595 }
2596 let ord = self.compare_values_for_sort(&a[*idx], &b[*idx]);
2597 if ord != std::cmp::Ordering::Equal {
2598 return if *ascending { ord } else { ord.reverse() };
2599 }
2600 }
2601 std::cmp::Ordering::Equal
2602 });
2603
2604 Ok(())
2605 }
2606
2607 fn compare_values_for_sort(&self, left: &Value, right: &Value) -> std::cmp::Ordering {
2609 match (left, right) {
2610 (Value::Null, Value::Null) => std::cmp::Ordering::Equal,
2611 (Value::Null, _) => std::cmp::Ordering::Greater, (_, Value::Null) => std::cmp::Ordering::Less,
2613 (Value::Number(a), Value::Number(b)) => a.cmp(b),
2614 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
2615 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
2616 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
2617 (Value::String(a), Value::String(b)) => a.cmp(b),
2618 (Value::Date(a), Value::Date(b)) => a.cmp(b),
2619 (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
2620 (Value::Amount(a), Value::Amount(b)) => a.number.cmp(&b.number),
2622 (Value::Position(a), Value::Position(b)) => a.units.number.cmp(&b.units.number),
2624 (Value::Inventory(a), Value::Inventory(b)) => {
2626 let a_val = a.positions().first().map(|p| &p.units.number);
2627 let b_val = b.positions().first().map(|p| &p.units.number);
2628 match (a_val, b_val) {
2629 (Some(av), Some(bv)) => av.cmp(bv),
2630 (Some(_), None) => std::cmp::Ordering::Less,
2631 (None, Some(_)) => std::cmp::Ordering::Greater,
2632 (None, None) => std::cmp::Ordering::Equal,
2633 }
2634 }
2635 _ => std::cmp::Ordering::Equal, }
2637 }
2638
2639 fn evaluate_having_filter(
2645 &self,
2646 having_expr: &Expr,
2647 row: &[Value],
2648 column_names: &[String],
2649 targets: &[Target],
2650 group: &[&PostingContext],
2651 ) -> Result<bool, QueryError> {
2652 let col_map: HashMap<String, usize> = column_names
2654 .iter()
2655 .enumerate()
2656 .map(|(i, name)| (name.to_uppercase(), i))
2657 .collect();
2658
2659 let alias_map: HashMap<String, usize> = targets
2661 .iter()
2662 .enumerate()
2663 .filter_map(|(i, t)| t.alias.as_ref().map(|a| (a.to_uppercase(), i)))
2664 .collect();
2665
2666 let val = self.evaluate_having_expr(having_expr, row, &col_map, &alias_map, group)?;
2667
2668 match val {
2669 Value::Boolean(b) => Ok(b),
2670 Value::Null => Ok(false), _ => Err(QueryError::Type(
2672 "HAVING clause must evaluate to boolean".to_string(),
2673 )),
2674 }
2675 }
2676
2677 fn evaluate_having_expr(
2679 &self,
2680 expr: &Expr,
2681 row: &[Value],
2682 col_map: &HashMap<String, usize>,
2683 alias_map: &HashMap<String, usize>,
2684 group: &[&PostingContext],
2685 ) -> Result<Value, QueryError> {
2686 match expr {
2687 Expr::Column(name) => {
2688 let upper_name = name.to_uppercase();
2689 if let Some(&idx) = alias_map.get(&upper_name) {
2691 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2692 } else if let Some(&idx) = col_map.get(&upper_name) {
2693 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2694 } else {
2695 Err(QueryError::Evaluation(format!(
2696 "Column '{name}' not found in SELECT clause for HAVING"
2697 )))
2698 }
2699 }
2700 Expr::Literal(lit) => self.evaluate_literal(lit),
2701 Expr::Function(_) => {
2702 self.evaluate_aggregate_expr(expr, group)
2704 }
2705 Expr::BinaryOp(op) => {
2706 let left = self.evaluate_having_expr(&op.left, row, col_map, alias_map, group)?;
2707 let right = self.evaluate_having_expr(&op.right, row, col_map, alias_map, group)?;
2708 self.binary_op_on_values(op.op, &left, &right)
2709 }
2710 Expr::UnaryOp(op) => {
2711 let val = self.evaluate_having_expr(&op.operand, row, col_map, alias_map, group)?;
2712 match op.op {
2713 UnaryOperator::Not => {
2714 let b = self.to_bool(&val)?;
2715 Ok(Value::Boolean(!b))
2716 }
2717 UnaryOperator::Neg => match val {
2718 Value::Number(n) => Ok(Value::Number(-n)),
2719 Value::Integer(i) => Ok(Value::Integer(-i)),
2720 _ => Err(QueryError::Type(
2721 "Cannot negate non-numeric value".to_string(),
2722 )),
2723 },
2724 }
2725 }
2726 Expr::Paren(inner) => self.evaluate_having_expr(inner, row, col_map, alias_map, group),
2727 Expr::Wildcard => Err(QueryError::Evaluation(
2728 "Wildcard not allowed in HAVING clause".to_string(),
2729 )),
2730 Expr::Window(_) => Err(QueryError::Evaluation(
2731 "Window functions not allowed in HAVING clause".to_string(),
2732 )),
2733 }
2734 }
2735
2736 fn apply_pivot(
2742 &self,
2743 result: &QueryResult,
2744 pivot_exprs: &[Expr],
2745 _targets: &[Target],
2746 ) -> Result<QueryResult, QueryError> {
2747 if pivot_exprs.is_empty() {
2748 return Ok(result.clone());
2749 }
2750
2751 let pivot_expr = &pivot_exprs[0];
2754
2755 let pivot_col_idx = self.find_pivot_column(result, pivot_expr)?;
2757
2758 let mut pivot_values: Vec<Value> = result
2760 .rows
2761 .iter()
2762 .map(|row| row.get(pivot_col_idx).cloned().unwrap_or(Value::Null))
2763 .collect();
2764 pivot_values.sort_by(|a, b| self.compare_values_for_sort(a, b));
2765 pivot_values.dedup();
2766
2767 let mut new_columns: Vec<String> = result
2769 .columns
2770 .iter()
2771 .enumerate()
2772 .filter(|(i, _)| *i != pivot_col_idx)
2773 .map(|(_, c)| c.clone())
2774 .collect();
2775
2776 let value_col_idx = result.columns.len() - 1;
2778
2779 for pv in &pivot_values {
2781 new_columns.push(self.value_to_string(pv));
2782 }
2783
2784 let mut new_result = QueryResult::new(new_columns);
2785
2786 let group_cols: Vec<usize> = (0..result.columns.len())
2788 .filter(|i| *i != pivot_col_idx && *i != value_col_idx)
2789 .collect();
2790
2791 let mut groups: HashMap<String, Vec<&Row>> = HashMap::new();
2792 for row in &result.rows {
2793 let key: String = group_cols
2794 .iter()
2795 .map(|&i| self.value_to_string(&row[i]))
2796 .collect::<Vec<_>>()
2797 .join("|");
2798 groups.entry(key).or_default().push(row);
2799 }
2800
2801 for (_key, group_rows) in groups {
2803 let mut new_row: Vec<Value> = group_cols
2804 .iter()
2805 .map(|&i| group_rows[0][i].clone())
2806 .collect();
2807
2808 for pv in &pivot_values {
2810 let matching_row = group_rows
2811 .iter()
2812 .find(|row| row.get(pivot_col_idx).is_some_and(|v| v == pv));
2813 if let Some(row) = matching_row {
2814 new_row.push(row.get(value_col_idx).cloned().unwrap_or(Value::Null));
2815 } else {
2816 new_row.push(Value::Null);
2817 }
2818 }
2819
2820 new_result.add_row(new_row);
2821 }
2822
2823 Ok(new_result)
2824 }
2825
2826 fn find_pivot_column(
2828 &self,
2829 result: &QueryResult,
2830 pivot_expr: &Expr,
2831 ) -> Result<usize, QueryError> {
2832 match pivot_expr {
2833 Expr::Column(name) => {
2834 let upper_name = name.to_uppercase();
2835 result
2836 .columns
2837 .iter()
2838 .position(|c| c.to_uppercase() == upper_name)
2839 .ok_or_else(|| {
2840 QueryError::Evaluation(format!(
2841 "PIVOT BY column '{name}' not found in SELECT"
2842 ))
2843 })
2844 }
2845 Expr::Literal(Literal::Integer(n)) => {
2846 let idx = (*n as usize).saturating_sub(1);
2847 if idx < result.columns.len() {
2848 Ok(idx)
2849 } else {
2850 Err(QueryError::Evaluation(format!(
2851 "PIVOT BY column index {n} out of range"
2852 )))
2853 }
2854 }
2855 Expr::Literal(Literal::Number(n)) => {
2856 use rust_decimal::prelude::ToPrimitive;
2858 let idx = n.to_usize().unwrap_or(0).saturating_sub(1);
2859 if idx < result.columns.len() {
2860 Ok(idx)
2861 } else {
2862 Err(QueryError::Evaluation(format!(
2863 "PIVOT BY column index {n} out of range"
2864 )))
2865 }
2866 }
2867 _ => {
2868 Err(QueryError::Evaluation(
2871 "PIVOT BY must reference a column name or index".to_string(),
2872 ))
2873 }
2874 }
2875 }
2876
2877 fn value_to_string(&self, val: &Value) -> String {
2879 match val {
2880 Value::String(s) => s.clone(),
2881 Value::Number(n) => n.to_string(),
2882 Value::Integer(i) => i.to_string(),
2883 Value::Date(d) => d.to_string(),
2884 Value::Boolean(b) => b.to_string(),
2885 Value::Amount(a) => format!("{} {}", a.number, a.currency),
2886 Value::Position(p) => format!("{}", p.units),
2887 Value::Inventory(inv) => inv
2888 .positions()
2889 .iter()
2890 .map(|p| format!("{}", p.units))
2891 .collect::<Vec<_>>()
2892 .join(", "),
2893 Value::StringSet(ss) => ss.join(", "),
2894 Value::Null => "NULL".to_string(),
2895 }
2896 }
2897}
2898
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;
    use rust_decimal_macros::dec;
    use rustledger_core::Posting;

    /// Shorthand for a calendar date; panics if the date is invalid.
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Two balanced January 2024 purchases paid from `Assets:Bank:Checking`:
    /// four postings across three accounts.
    fn sample_directives() -> Vec<Directive> {
        let coffee = Transaction::new(date(2024, 1, 15), "Coffee")
            .with_flag('*')
            .with_payee("Coffee Shop")
            .with_posting(Posting::new(
                "Expenses:Food:Coffee",
                Amount::new(dec!(5.00), "USD"),
            ))
            .with_posting(Posting::new(
                "Assets:Bank:Checking",
                Amount::new(dec!(-5.00), "USD"),
            ));
        let groceries = Transaction::new(date(2024, 1, 16), "Groceries")
            .with_flag('*')
            .with_payee("Supermarket")
            .with_posting(Posting::new(
                "Expenses:Food:Groceries",
                Amount::new(dec!(50.00), "USD"),
            ))
            .with_posting(Posting::new(
                "Assets:Bank:Checking",
                Amount::new(dec!(-50.00), "USD"),
            ));
        vec![
            Directive::Transaction(coffee),
            Directive::Transaction(groceries),
        ]
    }

    /// Parses `bql` and runs it on `executor`, unwrapping both steps.
    fn run(executor: &mut Executor<'_>, bql: &str) -> QueryResult {
        let query = parse(bql).unwrap();
        executor.execute(&query).unwrap()
    }

    #[test]
    fn test_simple_select() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let result = run(&mut executor, "SELECT date, account");

        assert_eq!(result.columns, vec!["date", "account"]);
        // Two transactions with two postings each.
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_where_clause() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let result = run(&mut executor, "SELECT account WHERE account ~ \"Expenses:\"");

        // Only the two expense postings match.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_balances() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let result = run(&mut executor, "BALANCES");

        assert_eq!(result.columns, vec!["account", "balance"]);
        // The sample data touches three distinct accounts.
        assert!(result.len() >= 3);
    }

    #[test]
    fn test_account_functions() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let leaves = run(
            &mut executor,
            "SELECT DISTINCT LEAF(account) WHERE account ~ \"Expenses:\"",
        );
        assert_eq!(leaves.len(), 2);

        let roots = run(&mut executor, "SELECT DISTINCT ROOT(account)");
        assert_eq!(roots.len(), 2);

        let parents = run(
            &mut executor,
            "SELECT DISTINCT PARENT(account) WHERE account ~ \"Expenses:\"",
        );
        assert!(!parents.is_empty());
    }

    #[test]
    fn test_min_max_aggregate() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let earliest = run(&mut executor, "SELECT MIN(date)");
        assert_eq!(earliest.len(), 1);
        assert_eq!(earliest.rows[0][0], Value::Date(date(2024, 1, 15)));

        let latest = run(&mut executor, "SELECT MAX(date)");
        assert_eq!(latest.len(), 1);
        assert_eq!(latest.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_order_by() {
        let ledger = sample_directives();
        let mut executor = Executor::new(&ledger);

        let result = run(&mut executor, "SELECT date, account ORDER BY date DESC");

        assert_eq!(result.len(), 4);
        // Descending order puts the later (Groceries) transaction first.
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 16)));
    }
}