1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, AggregateFn, GroupByKey, JoinConstraintKind, JoinType, OrderByClause, Projection,
25 ProjectionItem, ProjectionKind, SelectQuery, parse_aggregate_call,
26};
27
/// Row-level column resolver used while evaluating expressions.
///
/// Implemented in this file by `SingleTableScope` (one table, one rowid) and
/// `JoinedScope` (several tables, one optional rowid per table).
pub(crate) trait RowScope {
    /// Resolves column `col`, optionally qualified by a table name or alias
    /// (`qualifier`), to its value for the current row.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;

    /// Returns the backing table and rowid when the scope wraps exactly one
    /// table. The joined implementation in this file returns `None`.
    fn single_table_view(&self) -> Option<(&Table, i64)>;

    /// Returns the single name this scope is addressed by, if it has one.
    /// The joined implementation in this file returns `None`.
    fn scope_name(&self) -> Option<&str>;
}
75
/// `RowScope` over a single row of a single table.
pub(crate) struct SingleTableScope<'a> {
    /// Table the row belongs to.
    table: &'a Table,
    /// Rowid of the row being evaluated.
    rowid: i64,
    /// Name accepted as the qualifier in `<qualifier>.<column>` references.
    scope_name: &'a str,
}
85
86impl<'a> SingleTableScope<'a> {
87 pub(crate) fn new(table: &'a Table, rowid: i64, scope_name: &'a str) -> Self {
88 Self {
89 table,
90 rowid,
91 scope_name,
92 }
93 }
94}
95
96impl RowScope for SingleTableScope<'_> {
97 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
98 check_single_scope_qualifier(qualifier, self.scope_name, col)?;
103 if !self.table.contains_column(col.to_string()) {
107 return Err(SQLRiteError::Internal(format!(
108 "Column '{col}' does not exist on table '{}'",
109 self.table.tb_name
110 )));
111 }
112 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
113 }
114
115 fn single_table_view(&self) -> Option<(&Table, i64)> {
116 Some((self.table, self.rowid))
117 }
118
119 fn scope_name(&self) -> Option<&str> {
120 Some(self.scope_name)
121 }
122}
123
/// One table taking part in a joined SELECT, paired with the name it is
/// addressed by in qualified column references.
pub(crate) struct JoinedTableRef<'a> {
    /// The joined table.
    pub table: &'a Table,
    /// Name matched (ASCII case-insensitively) against column qualifiers.
    pub scope_name: String,
}
131
/// `RowScope` over one combined row of a join.
///
/// `tables` and `rowids` are indexed in parallel: `rowids[i]` is the rowid
/// contributed by `tables[i]`, or `None` when that side has no row (its
/// columns then read as `Value::Null`).
pub(crate) struct JoinedScope<'a> {
    /// Tables visible to the expression being evaluated.
    pub tables: &'a [JoinedTableRef<'a>],
    /// One optional rowid per entry of `tables`.
    pub rowids: &'a [Option<i64>],
}
139
140impl RowScope for JoinedScope<'_> {
141 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
142 if let Some(q) = qualifier {
143 let pos = self
146 .tables
147 .iter()
148 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
149 .ok_or_else(|| {
150 SQLRiteError::Internal(format!(
151 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
152 ))
153 })?;
154 if !self.tables[pos].table.contains_column(col.to_string()) {
155 return Err(SQLRiteError::Internal(format!(
156 "column '{col}' does not exist on '{}'",
157 self.tables[pos].scope_name
158 )));
159 }
160 return Ok(match self.rowids[pos] {
161 None => Value::Null,
162 Some(r) => self.tables[pos]
163 .table
164 .get_value(col, r)
165 .unwrap_or(Value::Null),
166 });
167 }
168 let mut hit: Option<usize> = None;
172 for (i, t) in self.tables.iter().enumerate() {
173 if t.table.contains_column(col.to_string()) {
174 if hit.is_some() {
175 return Err(SQLRiteError::Internal(format!(
176 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
177 )));
178 }
179 hit = Some(i);
180 }
181 }
182 let i = hit.ok_or_else(|| {
183 SQLRiteError::Internal(format!(
184 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
185 ))
186 })?;
187 Ok(match self.rowids[i] {
188 None => Value::Null,
189 Some(r) => self.tables[i]
190 .table
191 .get_value(col, r)
192 .unwrap_or(Value::Null),
193 })
194 }
195
196 fn single_table_view(&self) -> Option<(&Table, i64)> {
197 None
198 }
199
200 fn scope_name(&self) -> Option<&str> {
201 None
204 }
205}
206
207fn check_single_scope_qualifier(
214 qualifier: Option<&str>,
215 scope_name: &str,
216 col: &str,
217) -> Result<()> {
218 if let Some(q) = qualifier
219 && !q.eq_ignore_ascii_case(scope_name)
220 {
221 return Err(SQLRiteError::Internal(format!(
222 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
223 )));
224 }
225 Ok(())
226}
227
/// Result of a SELECT as raw values rather than rendered text.
pub struct SelectResult {
    /// Output column names, in projection order.
    pub columns: Vec<String>,
    /// One entry per returned row; each row's values follow `columns` order.
    pub rows: Vec<Vec<Value>>,
}
240
241pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
245 if !query.joins.is_empty() {
250 return execute_select_rows_joined(query, db);
251 }
252
253 let master_snapshot;
262 let table: &Table = if query.table_name == crate::sql::pager::MASTER_TABLE_NAME {
263 master_snapshot = crate::sql::pager::build_master_table_snapshot(db)?;
264 &master_snapshot
265 } else {
266 db.get_table(query.table_name.clone()).map_err(|_| {
267 SQLRiteError::Internal(format!("Table '{}' not found", query.table_name))
268 })?
269 };
270
271 let scope_name: &str = query.table_alias.as_deref().unwrap_or(&query.table_name);
275
276 let proj_items: Vec<ProjectionItem> = match &query.projection {
281 Projection::All => table
282 .column_names()
283 .into_iter()
284 .map(|c| ProjectionItem {
285 kind: ProjectionKind::Column {
286 qualifier: None,
287 name: c,
288 },
289 alias: None,
290 })
291 .collect(),
292 Projection::Items(items) => items.clone(),
293 };
294 let has_aggregates = proj_items
295 .iter()
296 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
297 for item in &proj_items {
300 if let ProjectionKind::Column { qualifier, name: c } = &item.kind {
301 check_single_scope_qualifier(qualifier.as_deref(), scope_name, c)?;
302 if !table.contains_column(c.clone()) {
303 return Err(SQLRiteError::Internal(format!(
304 "Column '{c}' does not exist on table '{}'",
305 query.table_name
306 )));
307 }
308 }
309 }
310 for g in &query.group_by {
311 check_single_scope_qualifier(g.qualifier.as_deref(), scope_name, &g.name)?;
312 if !table.contains_column(g.name.clone()) {
313 return Err(SQLRiteError::Internal(format!(
314 "GROUP BY references unknown column '{}' on table '{}'",
315 g.name, query.table_name
316 )));
317 }
318 }
319 let matching = match select_rowids(table, query.selection.as_ref(), scope_name)? {
323 RowidSource::IndexProbe(rowids) => rowids,
324 RowidSource::FullScan => {
325 let mut out = Vec::new();
326 for rowid in table.rowids() {
327 if let Some(expr) = &query.selection
328 && !eval_predicate(expr, table, rowid, scope_name)?
329 {
330 continue;
331 }
332 out.push(rowid);
333 }
334 out
335 }
336 };
337 let mut matching = matching;
338
339 let aggregating = has_aggregates || !query.group_by.is_empty();
340
341 if aggregating {
347 let (all_items, having_expr) = lower_having_into_hidden_slots(&query, &proj_items)?;
348
349 for item in &all_items {
351 if let ProjectionKind::Aggregate(call) = &item.kind
352 && let AggregateArg::Column { qualifier, name: c } = &call.arg
353 {
354 check_single_scope_qualifier(qualifier.as_deref(), scope_name, c)?;
355 if !table.contains_column(c.clone()) {
356 return Err(SQLRiteError::Internal(format!(
357 "{}({}) references unknown column '{c}' on table '{}'",
358 call.func.as_str(),
359 c,
360 query.table_name
361 )));
362 }
363 }
364 }
365
366 let scopes = matching
367 .iter()
368 .map(|&r| SingleTableScope::new(table, r, scope_name));
369 return run_aggregation_pipeline(scopes, &query, &proj_items, &all_items, &having_expr);
370 }
371
372 let defer_limit_for_distinct = query.distinct;
410 match (&query.order_by, query.limit) {
411 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
412 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
413 }
414 (Some(order), Some(k))
415 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
416 {
417 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
418 }
419 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
420 matching = select_topk(&matching, table, order, k, scope_name)?;
421 }
422 (Some(order), _) => {
423 sort_rowids(&mut matching, table, order, scope_name)?;
424 if let Some(k) = query.limit
425 && !defer_limit_for_distinct
426 {
427 matching.truncate(k);
428 }
429 }
430 (None, Some(k)) if !defer_limit_for_distinct => {
431 matching.truncate(k);
432 }
433 _ => {}
434 }
435
436 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
437 let projected_cols: Vec<String> = proj_items
438 .iter()
439 .map(|i| match &i.kind {
440 ProjectionKind::Column { name, .. } => name.clone(),
441 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
442 })
443 .collect();
444
445 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
449 for rowid in &matching {
450 let row: Vec<Value> = projected_cols
451 .iter()
452 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
453 .collect();
454 rows.push(row);
455 }
456
457 if query.distinct {
458 rows = dedupe_rows(rows);
459 if let Some(k) = query.limit {
460 rows.truncate(k);
461 }
462 }
463
464 Ok(SelectResult { columns, rows })
465}
466
/// A join constraint lowered to a single ON predicate.
struct ResolvedJoin {
    /// Predicate evaluated for each candidate (left row, right row) pair.
    on: Expr,
    /// Columns joined via USING / NATURAL; empty for a plain ON constraint.
    /// `SELECT *` skips these columns on the right-hand table.
    using_columns: Vec<String>,
}
475
476fn resolve_join_constraint(
491 constraint: &JoinConstraintKind,
492 tables: &[JoinedTableRef<'_>],
493 right_pos: usize,
494) -> Result<ResolvedJoin> {
495 match constraint {
496 JoinConstraintKind::On(expr) => Ok(ResolvedJoin {
497 on: (**expr).clone(),
498 using_columns: Vec::new(),
499 }),
500 JoinConstraintKind::Using(cols) => build_using_join(cols, tables, right_pos),
501 JoinConstraintKind::Natural => {
502 let shared: Vec<String> = tables[right_pos]
506 .table
507 .column_names()
508 .into_iter()
509 .filter(|c| {
510 tables[..right_pos]
511 .iter()
512 .any(|t| t.table.contains_column(c.clone()))
513 })
514 .collect();
515 build_using_join(&shared, tables, right_pos)
516 }
517 }
518}
519
520fn build_using_join(
525 cols: &[String],
526 tables: &[JoinedTableRef<'_>],
527 right_pos: usize,
528) -> Result<ResolvedJoin> {
529 let right = &tables[right_pos];
530 let mut predicate: Option<Expr> = None;
531 for col in cols {
532 if !right.table.contains_column(col.clone()) {
534 return Err(SQLRiteError::Internal(format!(
535 "cannot join USING column '{col}' — it is not present on table '{}'",
536 right.scope_name
537 )));
538 }
539 let left = tables[..right_pos]
542 .iter()
543 .find(|t| t.table.contains_column(col.clone()))
544 .ok_or_else(|| {
545 SQLRiteError::Internal(format!(
546 "cannot join USING column '{col}' — it is not present on any left-side table"
547 ))
548 })?;
549 let eq = col_eq(&left.scope_name, &right.scope_name, col);
550 predicate = Some(match predicate {
551 None => eq,
552 Some(prev) => Expr::BinaryOp {
553 left: Box::new(prev),
554 op: BinaryOperator::And,
555 right: Box::new(eq),
556 },
557 });
558 }
559 Ok(ResolvedJoin {
560 on: predicate
561 .unwrap_or_else(|| Expr::Value(sqlparser::ast::Value::Boolean(true).with_empty_span())),
562 using_columns: cols.to_vec(),
563 })
564}
565
566fn col_eq(left_scope: &str, right_scope: &str, col: &str) -> Expr {
569 let col_ref = |scope: &str| {
570 Expr::CompoundIdentifier(vec![
571 Ident::new(scope.to_string()),
572 Ident::new(col.to_string()),
573 ])
574 };
575 Expr::BinaryOp {
576 left: Box::new(col_ref(left_scope)),
577 op: BinaryOperator::Eq,
578 right: Box::new(col_ref(right_scope)),
579 }
580}
581
582fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
607 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
614
615 let primary = db
616 .get_table(query.table_name.clone())
617 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
618 joined_tables.push(JoinedTableRef {
619 table: primary,
620 scope_name: query
621 .table_alias
622 .clone()
623 .unwrap_or_else(|| query.table_name.clone()),
624 });
625 for j in &query.joins {
626 let t = db
627 .get_table(j.right_table.clone())
628 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
629 joined_tables.push(JoinedTableRef {
630 table: t,
631 scope_name: j
632 .right_alias
633 .clone()
634 .unwrap_or_else(|| j.right_table.clone()),
635 });
636 }
637
638 {
643 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
644 for t in &joined_tables {
645 let key = t.scope_name.to_ascii_lowercase();
646 if !seen.insert(key) {
647 return Err(SQLRiteError::Internal(format!(
648 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
649 t.scope_name
650 )));
651 }
652 }
653 }
654
655 let resolved: Vec<ResolvedJoin> = query
663 .joins
664 .iter()
665 .enumerate()
666 .map(|(j_idx, join)| resolve_join_constraint(&join.constraint, &joined_tables, j_idx + 1))
667 .collect::<Result<Vec<_>>>()?;
668
669 let proj_items: Vec<ProjectionItem> = match &query.projection {
675 Projection::All => {
676 let mut all = Vec::new();
692 for (t_idx, t) in joined_tables.iter().enumerate() {
693 let dedup: &[String] = t_idx
696 .checked_sub(1)
697 .map(|r| resolved[r].using_columns.as_slice())
698 .unwrap_or(&[]);
699 for col in t.table.column_names() {
700 if dedup.contains(&col) {
701 continue;
702 }
703 all.push(ProjectionItem {
704 kind: ProjectionKind::Column {
705 qualifier: Some(t.scope_name.clone()),
710 name: col,
711 },
712 alias: None,
713 });
714 }
715 }
716 all
717 }
718 Projection::Items(items) => items.clone(),
719 };
720
721 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
722
723 let mut acc: Vec<Vec<Option<i64>>> = primary
728 .rowids()
729 .into_iter()
730 .map(|r| {
731 let mut row = Vec::with_capacity(joined_tables.len());
732 row.push(Some(r));
733 row
734 })
735 .collect();
736
737 for (j_idx, join) in query.joins.iter().enumerate() {
742 let right_pos = j_idx + 1;
743 let right_table = joined_tables[right_pos].table;
744 let right_rowids: Vec<i64> = right_table.rowids();
745
746 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
750
751 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
752
753 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
761
762 for left_row in acc.into_iter() {
763 let mut left_match_count = 0usize;
767 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
768 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
769 on_rowids.push(Some(rrid));
770 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
771 let scope = JoinedScope {
772 tables: on_scope_tables,
773 rowids: &on_rowids,
774 };
775 if eval_predicate_scope(&resolved[j_idx].on, &scope)? {
782 left_match_count += 1;
783 right_matched[r_idx] = true;
784 next_acc.push(on_rowids);
789 }
790 }
791
792 if left_match_count == 0
793 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
794 {
795 let mut padded = left_row;
798 padded.push(None);
799 next_acc.push(padded);
800 }
801 }
802
803 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
807 for (r_idx, matched) in right_matched.iter().enumerate() {
808 if *matched {
809 continue;
810 }
811 let mut row: Vec<Option<i64>> = vec![None; right_pos];
812 row.push(Some(right_rowids[r_idx]));
813 next_acc.push(row);
814 }
815 }
816
817 acc = next_acc;
818 }
819
820 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
825 let mut out = Vec::with_capacity(acc.len());
826 for row in acc {
827 let scope = JoinedScope {
828 tables: &joined_tables,
829 rowids: &row,
830 };
831 if eval_predicate_scope(where_expr, &scope)? {
832 out.push(row);
833 }
834 }
835 out
836 } else {
837 acc
838 };
839
840 let has_aggregates = proj_items
848 .iter()
849 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
850 if has_aggregates || !query.group_by.is_empty() {
851 let (all_items, having_expr) = lower_having_into_hidden_slots(&query, &proj_items)?;
852
853 for g in &query.group_by {
860 resolve_scope_column(&joined_tables, g.qualifier.as_deref(), &g.name)?;
861 }
862 for item in &all_items {
863 match &item.kind {
864 ProjectionKind::Aggregate(call) => {
865 if let AggregateArg::Column { qualifier, name } = &call.arg {
866 resolve_scope_column(&joined_tables, qualifier.as_deref(), name)?;
867 }
868 }
869 ProjectionKind::Column { qualifier, name } => {
870 let pos = resolve_scope_column(&joined_tables, qualifier.as_deref(), name)?;
871 let in_group_by = query.group_by.iter().any(|g| {
872 g.name == *name
873 && resolve_scope_column(&joined_tables, g.qualifier.as_deref(), &g.name)
874 == Ok(pos)
875 });
876 if !in_group_by {
877 return Err(SQLRiteError::Internal(format!(
878 "column '{name}' must appear in GROUP BY or be used in an \
879 aggregate function"
880 )));
881 }
882 }
883 }
884 }
885
886 let scopes = filtered.iter().map(|row| JoinedScope {
887 tables: &joined_tables,
888 rowids: row,
889 });
890 return run_aggregation_pipeline(scopes, &query, &proj_items, &all_items, &having_expr);
891 }
892
893 if let Some(order) = &query.order_by {
897 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
900 for (i, row) in filtered.iter().enumerate() {
901 let scope = JoinedScope {
902 tables: &joined_tables,
903 rowids: row,
904 };
905 let v = eval_expr_scope(&order.expr, &scope)?;
906 keys.push((i, v));
907 }
908 keys.sort_by(|(_, a), (_, b)| {
909 let ord = compare_values(Some(a), Some(b));
910 if order.ascending { ord } else { ord.reverse() }
911 });
912 let mut sorted = Vec::with_capacity(filtered.len());
913 for (i, _) in keys {
914 sorted.push(filtered[i].clone());
915 }
916 filtered = sorted;
917 }
918
919 if let Some(k) = query.limit
924 && !query.distinct
925 {
926 filtered.truncate(k);
927 }
928
929 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
932 for row in &filtered {
933 let scope = JoinedScope {
934 tables: &joined_tables,
935 rowids: row,
936 };
937 let mut out_row = Vec::with_capacity(proj_items.len());
938 for item in &proj_items {
939 let v = match &item.kind {
940 ProjectionKind::Column { qualifier, name } => {
941 scope.lookup(qualifier.as_deref(), name)?
942 }
943 ProjectionKind::Aggregate(_) => {
944 return Err(SQLRiteError::Internal(
948 "aggregate projection reached the non-aggregating join path".to_string(),
949 ));
950 }
951 };
952 out_row.push(v);
953 }
954 rows.push(out_row);
955 }
956
957 if query.distinct {
960 rows = dedupe_rows(rows);
961 if let Some(k) = query.limit {
962 rows.truncate(k);
963 }
964 }
965
966 Ok(SelectResult { columns, rows })
967}
968
969fn resolve_scope_column(
976 tables: &[JoinedTableRef<'_>],
977 qualifier: Option<&str>,
978 name: &str,
979) -> Result<usize> {
980 if let Some(q) = qualifier {
981 let pos = tables
982 .iter()
983 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
984 .ok_or_else(|| {
985 SQLRiteError::Internal(format!(
986 "unknown table qualifier '{q}' in column reference '{q}.{name}'"
987 ))
988 })?;
989 if !tables[pos].table.contains_column(name.to_string()) {
990 return Err(SQLRiteError::Internal(format!(
991 "column '{name}' does not exist on '{}'",
992 tables[pos].scope_name
993 )));
994 }
995 return Ok(pos);
996 }
997 let mut hit: Option<usize> = None;
998 for (i, t) in tables.iter().enumerate() {
999 if t.table.contains_column(name.to_string()) {
1000 if hit.is_some() {
1001 return Err(SQLRiteError::Internal(format!(
1002 "column reference '{name}' is ambiguous — qualify it as <table>.{name}"
1003 )));
1004 }
1005 hit = Some(i);
1006 }
1007 }
1008 hit.ok_or_else(|| {
1009 SQLRiteError::Internal(format!(
1010 "unknown column '{name}' in joined SELECT (no in-scope table has it)"
1011 ))
1012 })
1013}
1014
1015pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
1020 let result = execute_select_rows(query, db)?;
1021 let row_count = result.rows.len();
1022
1023 let mut print_table = PrintTable::new();
1024 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
1025 print_table.add_row(PrintRow::new(header_cells));
1026
1027 for row in &result.rows {
1028 let cells: Vec<PrintCell> = row
1029 .iter()
1030 .map(|v| PrintCell::new(&v.to_display_string()))
1031 .collect();
1032 print_table.add_row(PrintRow::new(cells));
1033 }
1034
1035 Ok((print_table.to_string(), row_count))
1036}
1037
1038pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
1040 let Statement::Delete(Delete {
1041 from, selection, ..
1042 }) = stmt
1043 else {
1044 return Err(SQLRiteError::Internal(
1045 "execute_delete called on a non-DELETE statement".to_string(),
1046 ));
1047 };
1048
1049 let tables = match from {
1050 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
1051 };
1052 let (table_name, table_alias) = extract_single_table_name(tables)?;
1053 let scope_name = table_alias.as_deref().unwrap_or(&table_name);
1056
1057 let matching: Vec<i64> = {
1059 let table = db
1060 .get_table(table_name.clone())
1061 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
1062 match select_rowids(table, selection.as_ref(), scope_name)? {
1063 RowidSource::IndexProbe(rowids) => rowids,
1064 RowidSource::FullScan => {
1065 let mut out = Vec::new();
1066 for rowid in table.rowids() {
1067 if let Some(expr) = selection {
1068 if !eval_predicate(expr, table, rowid, scope_name)? {
1069 continue;
1070 }
1071 }
1072 out.push(rowid);
1073 }
1074 out
1075 }
1076 }
1077 };
1078
1079 let table = db.get_table_mut(table_name)?;
1080 for rowid in &matching {
1081 table.delete_row(*rowid);
1082 }
1083 if !matching.is_empty() {
1092 for entry in &mut table.hnsw_indexes {
1093 entry.needs_rebuild = true;
1094 }
1095 for entry in &mut table.fts_indexes {
1096 entry.needs_rebuild = true;
1097 }
1098 }
1099 Ok(matching.len())
1100}
1101
1102pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
1104 let Statement::Update(Update {
1105 table,
1106 assignments,
1107 from,
1108 selection,
1109 ..
1110 }) = stmt
1111 else {
1112 return Err(SQLRiteError::Internal(
1113 "execute_update called on a non-UPDATE statement".to_string(),
1114 ));
1115 };
1116
1117 if from.is_some() {
1118 return Err(SQLRiteError::NotImplemented(
1119 "UPDATE ... FROM is not supported yet".to_string(),
1120 ));
1121 }
1122
1123 let (table_name, table_alias) = extract_table_name(table)?;
1124 let scope_name = table_alias.as_deref().unwrap_or(&table_name);
1128
1129 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
1131 {
1132 let tbl = db
1133 .get_table(table_name.clone())
1134 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
1135 for a in assignments {
1136 let col = match &a.target {
1137 AssignmentTarget::ColumnName(name) => name
1138 .0
1139 .last()
1140 .map(|p| p.to_string())
1141 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
1142 AssignmentTarget::Tuple(_) => {
1143 return Err(SQLRiteError::NotImplemented(
1144 "tuple assignment targets are not supported".to_string(),
1145 ));
1146 }
1147 };
1148 if !tbl.contains_column(col.clone()) {
1149 return Err(SQLRiteError::Internal(format!(
1150 "UPDATE references unknown column '{col}'"
1151 )));
1152 }
1153 parsed_assignments.push((col, a.value.clone()));
1154 }
1155 }
1156
1157 let work: Vec<(i64, Vec<(String, Value)>)> = {
1161 let tbl = db.get_table(table_name.clone())?;
1162 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref(), scope_name)? {
1163 RowidSource::IndexProbe(rowids) => rowids,
1164 RowidSource::FullScan => {
1165 let mut out = Vec::new();
1166 for rowid in tbl.rowids() {
1167 if let Some(expr) = selection {
1168 if !eval_predicate(expr, tbl, rowid, scope_name)? {
1169 continue;
1170 }
1171 }
1172 out.push(rowid);
1173 }
1174 out
1175 }
1176 };
1177 let mut rows_to_update = Vec::new();
1178 for rowid in matched_rowids {
1179 let mut values = Vec::with_capacity(parsed_assignments.len());
1180 for (col, expr) in &parsed_assignments {
1181 let v = eval_expr(expr, tbl, rowid, scope_name)?;
1184 values.push((col.clone(), v));
1185 }
1186 rows_to_update.push((rowid, values));
1187 }
1188 rows_to_update
1189 };
1190
1191 let tbl = db.get_table_mut(table_name)?;
1192 for (rowid, values) in &work {
1193 for (col, v) in values {
1194 tbl.set_value(col, *rowid, v.clone())?;
1195 }
1196 }
1197
1198 if !work.is_empty() {
1207 let updated_columns: std::collections::HashSet<&str> = work
1208 .iter()
1209 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
1210 .collect();
1211 for entry in &mut tbl.hnsw_indexes {
1212 if updated_columns.contains(entry.column_name.as_str()) {
1213 entry.needs_rebuild = true;
1214 }
1215 }
1216 for entry in &mut tbl.fts_indexes {
1217 if updated_columns.contains(entry.column_name.as_str()) {
1218 entry.needs_rebuild = true;
1219 }
1220 }
1221 }
1222 Ok(work.len())
1223}
1224
1225pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
1237 let Statement::CreateIndex(CreateIndex {
1238 name,
1239 table_name,
1240 columns,
1241 using,
1242 unique,
1243 if_not_exists,
1244 predicate,
1245 with,
1246 ..
1247 }) = stmt
1248 else {
1249 return Err(SQLRiteError::Internal(
1250 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
1251 ));
1252 };
1253
1254 if predicate.is_some() {
1255 return Err(SQLRiteError::NotImplemented(
1256 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
1257 ));
1258 }
1259
1260 if columns.len() != 1 {
1261 return Err(SQLRiteError::NotImplemented(format!(
1262 "multi-column indexes are not supported yet ({} columns given)",
1263 columns.len()
1264 )));
1265 }
1266
1267 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
1268 SQLRiteError::NotImplemented(
1269 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
1270 )
1271 })?;
1272
1273 let method = match using {
1279 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
1280 IndexMethod::Hnsw
1281 }
1282 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
1283 IndexMethod::Fts
1284 }
1285 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
1286 IndexMethod::Btree
1287 }
1288 Some(other) => {
1289 return Err(SQLRiteError::NotImplemented(format!(
1290 "CREATE INDEX … USING {other:?} is not supported \
1291 (try `hnsw`, `fts`, or no USING clause)"
1292 )));
1293 }
1294 None => IndexMethod::Btree,
1295 };
1296
1297 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
1303
1304 let table_name_str = table_name.to_string();
1305 let column_name = match &columns[0].column.expr {
1306 Expr::Identifier(ident) => ident.value.clone(),
1307 Expr::CompoundIdentifier(parts) => parts
1308 .last()
1309 .map(|p| p.value.clone())
1310 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1311 other => {
1312 return Err(SQLRiteError::NotImplemented(format!(
1313 "CREATE INDEX only supports simple column references, got {other:?}"
1314 )));
1315 }
1316 };
1317
1318 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
1323 let table = db.get_table(table_name_str.clone()).map_err(|_| {
1324 SQLRiteError::General(format!(
1325 "CREATE INDEX references unknown table '{table_name_str}'"
1326 ))
1327 })?;
1328 if !table.contains_column(column_name.clone()) {
1329 return Err(SQLRiteError::General(format!(
1330 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
1331 )));
1332 }
1333 let col = table
1334 .columns
1335 .iter()
1336 .find(|c| c.column_name == column_name)
1337 .expect("we just verified the column exists");
1338
1339 if table.index_by_name(&index_name).is_some()
1342 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1343 || table.fts_indexes.iter().any(|i| i.name == index_name)
1344 {
1345 if *if_not_exists {
1346 return Ok(index_name);
1347 }
1348 return Err(SQLRiteError::General(format!(
1349 "index '{index_name}' already exists"
1350 )));
1351 }
1352 let datatype = clone_datatype(&col.datatype);
1353
1354 let mut pairs = Vec::new();
1355 for rowid in table.rowids() {
1356 if let Some(v) = table.get_value(&column_name, rowid) {
1357 pairs.push((rowid, v));
1358 }
1359 }
1360 (datatype, pairs)
1361 };
1362
1363 match method {
1364 IndexMethod::Btree => create_btree_index(
1365 db,
1366 &table_name_str,
1367 &index_name,
1368 &column_name,
1369 &datatype,
1370 *unique,
1371 &existing_rowids_and_values,
1372 ),
1373 IndexMethod::Hnsw => create_hnsw_index(
1374 db,
1375 &table_name_str,
1376 &index_name,
1377 &column_name,
1378 &datatype,
1379 *unique,
1380 hnsw_metric.unwrap_or(DistanceMetric::L2),
1381 &existing_rowids_and_values,
1382 ),
1383 IndexMethod::Fts => create_fts_index(
1384 db,
1385 &table_name_str,
1386 &index_name,
1387 &column_name,
1388 &datatype,
1389 *unique,
1390 &existing_rowids_and_values,
1391 ),
1392 }
1393}
1394
1395pub fn execute_drop_table(
1406 names: &[ObjectName],
1407 if_exists: bool,
1408 db: &mut Database,
1409) -> Result<usize> {
1410 if names.len() != 1 {
1411 return Err(SQLRiteError::NotImplemented(
1412 "DROP TABLE supports a single table per statement".to_string(),
1413 ));
1414 }
1415 let name = names[0].to_string();
1416
1417 if name == crate::sql::pager::MASTER_TABLE_NAME {
1418 return Err(SQLRiteError::General(format!(
1419 "'{}' is a reserved name used by the internal schema catalog",
1420 crate::sql::pager::MASTER_TABLE_NAME
1421 )));
1422 }
1423
1424 if !db.contains_table(name.clone()) {
1425 return if if_exists {
1426 Ok(0)
1427 } else {
1428 Err(SQLRiteError::General(format!(
1429 "Table '{name}' does not exist"
1430 )))
1431 };
1432 }
1433
1434 db.tables.remove(&name);
1435 Ok(1)
1436}
1437
1438pub fn execute_drop_index(
1447 names: &[ObjectName],
1448 if_exists: bool,
1449 db: &mut Database,
1450) -> Result<usize> {
1451 if names.len() != 1 {
1452 return Err(SQLRiteError::NotImplemented(
1453 "DROP INDEX supports a single index per statement".to_string(),
1454 ));
1455 }
1456 let name = names[0].to_string();
1457
1458 for table in db.tables.values_mut() {
1459 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1460 if secondary.origin == IndexOrigin::Auto {
1461 return Err(SQLRiteError::General(format!(
1462 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1463 )));
1464 }
1465 table.secondary_indexes.retain(|i| i.name != name);
1466 return Ok(1);
1467 }
1468 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1469 table.hnsw_indexes.retain(|i| i.name != name);
1470 return Ok(1);
1471 }
1472 if table.fts_indexes.iter().any(|i| i.name == name) {
1473 table.fts_indexes.retain(|i| i.name != name);
1474 return Ok(1);
1475 }
1476 }
1477
1478 if if_exists {
1479 Ok(0)
1480 } else {
1481 Err(SQLRiteError::General(format!(
1482 "Index '{name}' does not exist"
1483 )))
1484 }
1485}
1486
1487pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1499 let table_name = alter.name.to_string();
1500
1501 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1502 return Err(SQLRiteError::General(format!(
1503 "'{}' is a reserved name used by the internal schema catalog",
1504 crate::sql::pager::MASTER_TABLE_NAME
1505 )));
1506 }
1507
1508 if !db.contains_table(table_name.clone()) {
1509 return if alter.if_exists {
1510 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1511 } else {
1512 Err(SQLRiteError::General(format!(
1513 "Table '{table_name}' does not exist"
1514 )))
1515 };
1516 }
1517
1518 if alter.operations.len() != 1 {
1519 return Err(SQLRiteError::NotImplemented(
1520 "ALTER TABLE supports one operation per statement".to_string(),
1521 ));
1522 }
1523
1524 match &alter.operations[0] {
1525 AlterTableOperation::RenameTable { table_name: kind } => {
1526 let new_name = match kind {
1527 RenameTableNameKind::To(name) => name.to_string(),
1528 RenameTableNameKind::As(_) => {
1529 return Err(SQLRiteError::NotImplemented(
1530 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1531 .to_string(),
1532 ));
1533 }
1534 };
1535 alter_rename_table(db, &table_name, &new_name)?;
1536 Ok(format!(
1537 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1538 ))
1539 }
1540 AlterTableOperation::RenameColumn {
1541 old_column_name,
1542 new_column_name,
1543 } => {
1544 let old = old_column_name.value.clone();
1545 let new = new_column_name.value.clone();
1546 db.get_table_mut(table_name.clone())?
1547 .rename_column(&old, &new)?;
1548 Ok(format!(
1549 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1550 ))
1551 }
1552 AlterTableOperation::AddColumn {
1553 column_def,
1554 if_not_exists,
1555 ..
1556 } => {
1557 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1558 let table = db.get_table_mut(table_name.clone())?;
1559 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1560 return Ok(format!(
1561 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1562 parsed.name
1563 ));
1564 }
1565 let col_name = parsed.name.clone();
1566 table.add_column(parsed)?;
1567 Ok(format!(
1568 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1569 ))
1570 }
1571 AlterTableOperation::DropColumn {
1572 column_names,
1573 if_exists,
1574 ..
1575 } => {
1576 if column_names.len() != 1 {
1577 return Err(SQLRiteError::NotImplemented(
1578 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1579 ));
1580 }
1581 let col_name = column_names[0].value.clone();
1582 let table = db.get_table_mut(table_name.clone())?;
1583 if *if_exists && !table.contains_column(col_name.clone()) {
1584 return Ok(format!(
1585 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1586 ));
1587 }
1588 table.drop_column(&col_name)?;
1589 Ok(format!(
1590 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1591 ))
1592 }
1593 other => Err(SQLRiteError::NotImplemented(format!(
1594 "ALTER TABLE operation {other:?} is not supported"
1595 ))),
1596 }
1597}
1598
1599pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1609 if db.in_transaction() {
1610 return Err(SQLRiteError::General(
1611 "VACUUM cannot run inside a transaction".to_string(),
1612 ));
1613 }
1614 let path = match db.source_path.clone() {
1615 Some(p) => p,
1616 None => {
1617 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1618 }
1619 };
1620 if let Some(pager) = db.pager.as_mut() {
1626 let _ = pager.checkpoint();
1627 }
1628 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1629 let pages_before = db
1630 .pager
1631 .as_ref()
1632 .map(|p| p.header().page_count)
1633 .unwrap_or(0);
1634 crate::sql::pager::vacuum_database(db, &path)?;
1635 if let Some(pager) = db.pager.as_mut() {
1638 let _ = pager.checkpoint();
1639 }
1640 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1641 let pages_after = db
1642 .pager
1643 .as_ref()
1644 .map(|p| p.header().page_count)
1645 .unwrap_or(0);
1646 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1647 let bytes_reclaimed = size_before.saturating_sub(size_after);
1648 Ok(format!(
1649 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1650 ))
1651}
1652
1653fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1659 if new == crate::sql::pager::MASTER_TABLE_NAME {
1660 return Err(SQLRiteError::General(format!(
1661 "'{}' is a reserved name used by the internal schema catalog",
1662 crate::sql::pager::MASTER_TABLE_NAME
1663 )));
1664 }
1665 if old == new {
1666 return Ok(());
1667 }
1668 if db.contains_table(new.to_string()) {
1669 return Err(SQLRiteError::General(format!(
1670 "target table '{new}' already exists"
1671 )));
1672 }
1673
1674 let mut table = db
1675 .tables
1676 .remove(old)
1677 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1678 table.tb_name = new.to_string();
1679 for idx in table.secondary_indexes.iter_mut() {
1680 idx.table_name = new.to_string();
1681 if idx.origin == IndexOrigin::Auto
1682 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1683 {
1684 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1685 }
1686 }
1687 db.tables.insert(new.to_string(), table);
1688 Ok(())
1689}
1690
/// Index access method selected for a `CREATE INDEX` statement.
///
/// NOTE(review): the variants presumably correspond to the `USING <method>`
/// clause (`hnsw` and `fts` are named in this file's error messages; `Btree`
/// looks like the default) — confirm against the CREATE INDEX executor.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary index over a column's values.
    Btree,
    /// HNSW index over a `VECTOR` column.
    Hnsw,
    /// Full-text index over a `TEXT` column.
    Fts,
}
1701
1702fn create_btree_index(
1704 db: &mut Database,
1705 table_name: &str,
1706 index_name: &str,
1707 column_name: &str,
1708 datatype: &DataType,
1709 unique: bool,
1710 existing: &[(i64, Value)],
1711) -> Result<String> {
1712 let mut idx = SecondaryIndex::new(
1713 index_name.to_string(),
1714 table_name.to_string(),
1715 column_name.to_string(),
1716 datatype,
1717 unique,
1718 IndexOrigin::Explicit,
1719 )?;
1720
1721 for (rowid, v) in existing {
1725 if unique && idx.would_violate_unique(v) {
1726 return Err(SQLRiteError::General(format!(
1727 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1728 already contains the duplicate value {}",
1729 v.to_display_string()
1730 )));
1731 }
1732 idx.insert(v, *rowid)?;
1733 }
1734
1735 let table_mut = db.get_table_mut(table_name.to_string())?;
1736 table_mut.secondary_indexes.push(idx);
1737 Ok(index_name.to_string())
1738}
1739
1740fn create_hnsw_index(
1742 db: &mut Database,
1743 table_name: &str,
1744 index_name: &str,
1745 column_name: &str,
1746 datatype: &DataType,
1747 unique: bool,
1748 metric: DistanceMetric,
1749 existing: &[(i64, Value)],
1750) -> Result<String> {
1751 let dim = match datatype {
1754 DataType::Vector(d) => *d,
1755 other => {
1756 return Err(SQLRiteError::General(format!(
1757 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1758 )));
1759 }
1760 };
1761
1762 if unique {
1763 return Err(SQLRiteError::General(
1764 "UNIQUE has no meaning for HNSW indexes".to_string(),
1765 ));
1766 }
1767
1768 let seed = hash_str_to_seed(index_name);
1779 let mut idx = HnswIndex::new(metric, seed);
1780
1781 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1785 std::collections::HashMap::with_capacity(existing.len());
1786 for (rowid, v) in existing {
1787 match v {
1788 Value::Vector(vec) => {
1789 if vec.len() != dim {
1790 return Err(SQLRiteError::Internal(format!(
1791 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1792 declared as VECTOR({dim}) — schema invariant violated",
1793 vec.len()
1794 )));
1795 }
1796 vec_map.insert(*rowid, vec.clone());
1797 }
1798 _ => continue,
1802 }
1803 }
1804
1805 for (rowid, _) in existing {
1806 if let Some(v) = vec_map.get(rowid) {
1807 let v_clone = v.clone();
1808 idx.insert(*rowid, &v_clone, |id| {
1809 vec_map.get(&id).cloned().unwrap_or_default()
1810 })?;
1811 }
1812 }
1813
1814 let table_mut = db.get_table_mut(table_name.to_string())?;
1815 table_mut.hnsw_indexes.push(HnswIndexEntry {
1816 name: index_name.to_string(),
1817 column_name: column_name.to_string(),
1818 metric,
1819 index: idx,
1820 needs_rebuild: false,
1822 });
1823 Ok(index_name.to_string())
1824}
1825
1826fn parse_hnsw_with_options(
1837 with: &[Expr],
1838 index_name: &str,
1839 method: IndexMethod,
1840) -> Result<Option<DistanceMetric>> {
1841 if with.is_empty() {
1842 return Ok(None);
1843 }
1844 if !matches!(method, IndexMethod::Hnsw) {
1845 return Err(SQLRiteError::General(format!(
1846 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1847 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1848 )));
1849 }
1850
1851 let mut metric: Option<DistanceMetric> = None;
1852 for opt in with {
1853 let Expr::BinaryOp { left, op, right } = opt else {
1854 return Err(SQLRiteError::General(format!(
1855 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1856 (expected `key = 'value'`)"
1857 )));
1858 };
1859 if !matches!(op, BinaryOperator::Eq) {
1860 return Err(SQLRiteError::General(format!(
1861 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1862 )));
1863 }
1864 let key = match left.as_ref() {
1865 Expr::Identifier(ident) => ident.value.clone(),
1866 other => {
1867 return Err(SQLRiteError::General(format!(
1868 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1869 got {other:?}"
1870 )));
1871 }
1872 };
1873 let value = match right.as_ref() {
1874 Expr::Value(v) => match &v.value {
1875 AstValue::SingleQuotedString(s) => s.clone(),
1876 AstValue::DoubleQuotedString(s) => s.clone(),
1877 other => {
1878 return Err(SQLRiteError::General(format!(
1879 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1880 a quoted string, got {other:?}"
1881 )));
1882 }
1883 },
1884 Expr::Identifier(ident) => ident.value.clone(),
1885 other => {
1886 return Err(SQLRiteError::General(format!(
1887 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1888 quoted string, got {other:?}"
1889 )));
1890 }
1891 };
1892
1893 if key.eq_ignore_ascii_case("metric") {
1894 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1895 SQLRiteError::General(format!(
1896 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1897 (try 'l2', 'cosine', or 'dot')"
1898 ))
1899 })?;
1900 if metric.is_some() {
1901 return Err(SQLRiteError::General(format!(
1902 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1903 )));
1904 }
1905 metric = Some(parsed);
1906 } else {
1907 return Err(SQLRiteError::General(format!(
1908 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1909 (only 'metric' is recognised on HNSW indexes)"
1910 )));
1911 }
1912 }
1913
1914 Ok(metric)
1915}
1916
1917fn create_fts_index(
1922 db: &mut Database,
1923 table_name: &str,
1924 index_name: &str,
1925 column_name: &str,
1926 datatype: &DataType,
1927 unique: bool,
1928 existing: &[(i64, Value)],
1929) -> Result<String> {
1930 match datatype {
1935 DataType::Text => {}
1936 other => {
1937 return Err(SQLRiteError::General(format!(
1938 "USING fts requires a TEXT column; '{column_name}' is {other}"
1939 )));
1940 }
1941 }
1942
1943 if unique {
1944 return Err(SQLRiteError::General(
1945 "UNIQUE has no meaning for FTS indexes".to_string(),
1946 ));
1947 }
1948
1949 let mut idx = PostingList::new();
1950 for (rowid, v) in existing {
1951 if let Value::Text(text) = v {
1952 idx.insert(*rowid, text);
1953 }
1954 }
1957
1958 let table_mut = db.get_table_mut(table_name.to_string())?;
1959 table_mut.fts_indexes.push(FtsIndexEntry {
1960 name: index_name.to_string(),
1961 column_name: column_name.to_string(),
1962 index: idx,
1963 needs_rebuild: false,
1964 });
1965 Ok(index_name.to_string())
1966}
1967
/// Hashes `s` to a 64-bit seed with FNV-1a: start from the 64-bit offset
/// basis, then for every byte XOR it in and multiply by the 64-bit FNV prime
/// (wrapping on overflow).
///
/// The result depends only on the bytes of `s`, so the same string always
/// yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1979
1980fn clone_datatype(dt: &DataType) -> DataType {
1983 match dt {
1984 DataType::Integer => DataType::Integer,
1985 DataType::Text => DataType::Text,
1986 DataType::Real => DataType::Real,
1987 DataType::Bool => DataType::Bool,
1988 DataType::Vector(dim) => DataType::Vector(*dim),
1989 DataType::Json => DataType::Json,
1990 DataType::None => DataType::None,
1991 DataType::Invalid => DataType::Invalid,
1992 }
1993}
1994
1995fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<(String, Option<String>)> {
1996 if tables.len() != 1 {
1997 return Err(SQLRiteError::NotImplemented(
1998 "multi-table DELETE is not supported yet".to_string(),
1999 ));
2000 }
2001 extract_table_name(&tables[0])
2002}
2003
2004fn extract_table_name(twj: &TableWithJoins) -> Result<(String, Option<String>)> {
2008 if !twj.joins.is_empty() {
2009 return Err(SQLRiteError::NotImplemented(
2010 "JOIN is not supported yet".to_string(),
2011 ));
2012 }
2013 match &twj.relation {
2014 TableFactor::Table { name, alias, .. } => Ok((
2015 name.to_string(),
2016 alias.as_ref().map(|a| a.name.value.clone()),
2017 )),
2018 _ => Err(SQLRiteError::NotImplemented(
2019 "only plain table references are supported".to_string(),
2020 )),
2021 }
2022}
2023
/// How candidate rowids for a statement are obtained, as decided by
/// `select_rowids`.
enum RowidSource {
    /// Rowids returned by a secondary-index lookup for a `column = literal`
    /// predicate, sorted ascending.
    IndexProbe(Vec<i64>),
    /// No usable index probe was found (no WHERE clause, not a simple
    /// equality, no index on the column, or the literal could not be
    /// converted).
    /// NOTE(review): presumably the caller then visits every row and
    /// evaluates the predicate itself — confirm at the call sites.
    FullScan,
}
2034
2035fn select_rowids(table: &Table, selection: Option<&Expr>, scope_name: &str) -> Result<RowidSource> {
2040 let Some(expr) = selection else {
2041 return Ok(RowidSource::FullScan);
2042 };
2043 let Some((qualifier, col, literal)) = try_extract_equality(expr) else {
2044 return Ok(RowidSource::FullScan);
2045 };
2046 check_single_scope_qualifier(qualifier.as_deref(), scope_name, &col)?;
2050 let Some(idx) = table.index_for_column(&col) else {
2051 return Ok(RowidSource::FullScan);
2052 };
2053
2054 let literal_value = match convert_literal(&literal) {
2058 Ok(v) => v,
2059 Err(_) => return Ok(RowidSource::FullScan),
2060 };
2061
2062 let mut rowids = idx.lookup(&literal_value);
2066 rowids.sort_unstable();
2067 Ok(RowidSource::IndexProbe(rowids))
2068}
2069
2070fn try_extract_equality(expr: &Expr) -> Option<(Option<String>, String, sqlparser::ast::Value)> {
2076 let peeled = match expr {
2078 Expr::Nested(inner) => inner.as_ref(),
2079 other => other,
2080 };
2081 let Expr::BinaryOp { left, op, right } = peeled else {
2082 return None;
2083 };
2084 if !matches!(op, BinaryOperator::Eq) {
2085 return None;
2086 }
2087 let col_from = |e: &Expr| -> Option<(Option<String>, String)> {
2088 match e {
2089 Expr::Identifier(ident) => Some((None, ident.value.clone())),
2090 Expr::CompoundIdentifier(parts) => match parts.as_slice() {
2091 [only] => Some((None, only.value.clone())),
2092 [q, c] => Some((Some(q.value.clone()), c.value.clone())),
2093 _ => None,
2094 },
2095 _ => None,
2096 }
2097 };
2098 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
2099 if let Expr::Value(v) = e {
2100 Some(v.value.clone())
2101 } else {
2102 None
2103 }
2104 };
2105 if let (Some((q, c)), Some(l)) = (col_from(left), literal_from(right)) {
2106 return Some((q, c, l));
2107 }
2108 if let (Some(l), Some((q, c))) = (literal_from(left), col_from(right)) {
2109 return Some((q, c, l));
2110 }
2111 None
2112}
2113
2114fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
2139 if k == 0 {
2140 return None;
2141 }
2142
2143 let func = match order_expr {
2146 Expr::Function(f) => f,
2147 _ => return None,
2148 };
2149 let fname = match func.name.0.as_slice() {
2150 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2151 _ => return None,
2152 };
2153 let query_metric = match fname.as_str() {
2154 "vec_distance_l2" => DistanceMetric::L2,
2155 "vec_distance_cosine" => DistanceMetric::Cosine,
2156 "vec_distance_dot" => DistanceMetric::Dot,
2157 _ => return None,
2158 };
2159
2160 let arg_list = match &func.args {
2162 FunctionArguments::List(l) => &l.args,
2163 _ => return None,
2164 };
2165 if arg_list.len() != 2 {
2166 return None;
2167 }
2168 let exprs: Vec<&Expr> = arg_list
2169 .iter()
2170 .filter_map(|a| match a {
2171 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
2172 _ => None,
2173 })
2174 .collect();
2175 if exprs.len() != 2 {
2176 return None;
2177 }
2178
2179 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
2184 Some(v) => v,
2185 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
2186 Some(v) => v,
2187 None => return None,
2188 },
2189 };
2190
2191 let entry = table
2196 .hnsw_indexes
2197 .iter()
2198 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
2199
2200 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
2206 Some(c) => match &c.datatype {
2207 DataType::Vector(d) => *d,
2208 _ => return None,
2209 },
2210 None => return None,
2211 };
2212 if query_vec.len() != declared_dim {
2213 return None;
2214 }
2215
2216 let column_for_closure = col_name.clone();
2220 let table_ref = table;
2221 let result = entry
2222 .index
2223 .search(&query_vec, k, |id| {
2224 match table_ref.get_value(&column_for_closure, id) {
2225 Some(Value::Vector(v)) => v,
2226 _ => Vec::new(),
2227 }
2228 })
2229 .ok()?;
2230 Some(result)
2231}
2232
2233fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
2249 if k == 0 || ascending {
2250 return None;
2254 }
2255
2256 let func = match order_expr {
2257 Expr::Function(f) => f,
2258 _ => return None,
2259 };
2260 let fname = match func.name.0.as_slice() {
2261 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2262 _ => return None,
2263 };
2264 if fname != "bm25_score" {
2265 return None;
2266 }
2267
2268 let arg_list = match &func.args {
2269 FunctionArguments::List(l) => &l.args,
2270 _ => return None,
2271 };
2272 if arg_list.len() != 2 {
2273 return None;
2274 }
2275 let exprs: Vec<&Expr> = arg_list
2276 .iter()
2277 .filter_map(|a| match a {
2278 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
2279 _ => None,
2280 })
2281 .collect();
2282 if exprs.len() != 2 {
2283 return None;
2284 }
2285
2286 let col_name = match exprs[0] {
2288 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2289 _ => return None,
2290 };
2291
2292 let query = match exprs[1] {
2296 Expr::Value(v) => match &v.value {
2297 AstValue::SingleQuotedString(s) => s.clone(),
2298 _ => return None,
2299 },
2300 _ => return None,
2301 };
2302
2303 let entry = table
2304 .fts_indexes
2305 .iter()
2306 .find(|e| e.column_name == col_name)?;
2307
2308 let scored = entry.index.query(&query, &Bm25Params::default());
2309 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
2310 if out.len() > k {
2311 out.truncate(k);
2312 }
2313 Some(out)
2314}
2315
2316fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
2321 let col_name = match a {
2322 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2323 _ => return None,
2324 };
2325 let lit_str = match b {
2326 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
2327 format!("[{}]", ident.value)
2328 }
2329 _ => return None,
2330 };
2331 let v = parse_vector_literal(&lit_str).ok()?;
2332 Some((col_name, v))
2333}
2334
2335struct HeapEntry {
2348 key: Value,
2349 rowid: i64,
2350 asc: bool,
2351}
2352
2353impl PartialEq for HeapEntry {
2354 fn eq(&self, other: &Self) -> bool {
2355 self.cmp(other) == Ordering::Equal
2356 }
2357}
2358
2359impl Eq for HeapEntry {}
2360
2361impl PartialOrd for HeapEntry {
2362 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2363 Some(self.cmp(other))
2364 }
2365}
2366
2367impl Ord for HeapEntry {
2368 fn cmp(&self, other: &Self) -> Ordering {
2369 let raw = compare_values(Some(&self.key), Some(&other.key));
2370 if self.asc { raw } else { raw.reverse() }
2371 }
2372}
2373
2374fn select_topk(
2383 matching: &[i64],
2384 table: &Table,
2385 order: &OrderByClause,
2386 k: usize,
2387 scope_name: &str,
2388) -> Result<Vec<i64>> {
2389 use std::collections::BinaryHeap;
2390
2391 if k == 0 || matching.is_empty() {
2392 return Ok(Vec::new());
2393 }
2394
2395 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2396
2397 for &rowid in matching {
2398 let key = eval_expr(&order.expr, table, rowid, scope_name)?;
2399 let entry = HeapEntry {
2400 key,
2401 rowid,
2402 asc: order.ascending,
2403 };
2404
2405 if heap.len() < k {
2406 heap.push(entry);
2407 } else {
2408 if entry < *heap.peek().unwrap() {
2412 heap.pop();
2413 heap.push(entry);
2414 }
2415 }
2416 }
2417
2418 Ok(heap
2423 .into_sorted_vec()
2424 .into_iter()
2425 .map(|e| e.rowid)
2426 .collect())
2427}
2428
2429fn sort_rowids(
2430 rowids: &mut [i64],
2431 table: &Table,
2432 order: &OrderByClause,
2433 scope_name: &str,
2434) -> Result<()> {
2435 let mut keys: Vec<(i64, Result<Value>)> = rowids
2443 .iter()
2444 .map(|r| (*r, eval_expr(&order.expr, table, *r, scope_name)))
2445 .collect();
2446
2447 for (_, k) in &keys {
2451 if let Err(e) = k {
2452 return Err(SQLRiteError::General(format!(
2453 "ORDER BY expression failed: {e}"
2454 )));
2455 }
2456 }
2457
2458 keys.sort_by(|(_, ka), (_, kb)| {
2459 let va = ka.as_ref().unwrap();
2462 let vb = kb.as_ref().unwrap();
2463 let ord = compare_values(Some(va), Some(vb));
2464 if order.ascending { ord } else { ord.reverse() }
2465 });
2466
2467 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2469 rowids[i] = rowid;
2470 }
2471 Ok(())
2472}
2473
2474fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2475 match (a, b) {
2476 (None, None) => Ordering::Equal,
2477 (None, _) => Ordering::Less,
2478 (_, None) => Ordering::Greater,
2479 (Some(a), Some(b)) => match (a, b) {
2480 (Value::Null, Value::Null) => Ordering::Equal,
2481 (Value::Null, _) => Ordering::Less,
2482 (_, Value::Null) => Ordering::Greater,
2483 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2484 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2485 (Value::Integer(x), Value::Real(y)) => {
2486 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2487 }
2488 (Value::Real(x), Value::Integer(y)) => {
2489 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2490 }
2491 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2492 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2493 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2495 },
2496 }
2497}
2498
2499pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64, scope_name: &str) -> Result<bool> {
2504 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid, scope_name))
2505}
2506
2507pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2511 let v = eval_expr_scope(expr, scope)?;
2512 match v {
2513 Value::Bool(b) => Ok(b),
2514 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2516 other => Err(SQLRiteError::Internal(format!(
2517 "WHERE clause must evaluate to boolean, got {}",
2518 other.to_display_string()
2519 ))),
2520 }
2521}
2522
2523fn eval_expr(expr: &Expr, table: &Table, rowid: i64, scope_name: &str) -> Result<Value> {
2525 eval_expr_scope(expr, &SingleTableScope::new(table, rowid, scope_name))
2526}
2527
2528fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2529 match expr {
2530 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2531
2532 Expr::Identifier(ident) => {
2533 if ident.quote_style == Some('[') {
2543 let raw = format!("[{}]", ident.value);
2544 let v = parse_vector_literal(&raw)?;
2545 return Ok(Value::Vector(v));
2546 }
2547 scope.lookup(None, &ident.value)
2548 }
2549
2550 Expr::CompoundIdentifier(parts) => {
2551 match parts.as_slice() {
2558 [only] => scope.lookup(None, &only.value),
2559 [q, c] => scope.lookup(Some(&q.value), &c.value),
2560 _ => Err(SQLRiteError::NotImplemented(format!(
2561 "compound identifier with {} parts is not supported",
2562 parts.len()
2563 ))),
2564 }
2565 }
2566
2567 Expr::Value(v) => convert_literal(&v.value),
2568
2569 Expr::UnaryOp { op, expr } => {
2570 let inner = eval_expr_scope(expr, scope)?;
2571 match op {
2572 UnaryOperator::Not => match inner {
2573 Value::Bool(b) => Ok(Value::Bool(!b)),
2574 Value::Null => Ok(Value::Null),
2575 other => Err(SQLRiteError::Internal(format!(
2576 "NOT applied to non-boolean value: {}",
2577 other.to_display_string()
2578 ))),
2579 },
2580 UnaryOperator::Minus => match inner {
2581 Value::Integer(i) => Ok(Value::Integer(-i)),
2582 Value::Real(f) => Ok(Value::Real(-f)),
2583 Value::Null => Ok(Value::Null),
2584 other => Err(SQLRiteError::Internal(format!(
2585 "unary minus on non-numeric value: {}",
2586 other.to_display_string()
2587 ))),
2588 },
2589 UnaryOperator::Plus => Ok(inner),
2590 other => Err(SQLRiteError::NotImplemented(format!(
2591 "unary operator {other:?} is not supported"
2592 ))),
2593 }
2594 }
2595
2596 Expr::BinaryOp { left, op, right } => match op {
2597 BinaryOperator::And => {
2598 let l = eval_expr_scope(left, scope)?;
2599 let r = eval_expr_scope(right, scope)?;
2600 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2601 }
2602 BinaryOperator::Or => {
2603 let l = eval_expr_scope(left, scope)?;
2604 let r = eval_expr_scope(right, scope)?;
2605 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2606 }
2607 cmp @ (BinaryOperator::Eq
2608 | BinaryOperator::NotEq
2609 | BinaryOperator::Lt
2610 | BinaryOperator::LtEq
2611 | BinaryOperator::Gt
2612 | BinaryOperator::GtEq) => {
2613 let l = eval_expr_scope(left, scope)?;
2614 let r = eval_expr_scope(right, scope)?;
2615 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2617 return Ok(Value::Bool(false));
2618 }
2619 let ord = compare_values(Some(&l), Some(&r));
2620 let result = match cmp {
2621 BinaryOperator::Eq => ord == Ordering::Equal,
2622 BinaryOperator::NotEq => ord != Ordering::Equal,
2623 BinaryOperator::Lt => ord == Ordering::Less,
2624 BinaryOperator::LtEq => ord != Ordering::Greater,
2625 BinaryOperator::Gt => ord == Ordering::Greater,
2626 BinaryOperator::GtEq => ord != Ordering::Less,
2627 _ => unreachable!(),
2628 };
2629 Ok(Value::Bool(result))
2630 }
2631 arith @ (BinaryOperator::Plus
2632 | BinaryOperator::Minus
2633 | BinaryOperator::Multiply
2634 | BinaryOperator::Divide
2635 | BinaryOperator::Modulo) => {
2636 let l = eval_expr_scope(left, scope)?;
2637 let r = eval_expr_scope(right, scope)?;
2638 eval_arith(arith, &l, &r)
2639 }
2640 BinaryOperator::StringConcat => {
2641 let l = eval_expr_scope(left, scope)?;
2642 let r = eval_expr_scope(right, scope)?;
2643 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2644 return Ok(Value::Null);
2645 }
2646 Ok(Value::Text(format!(
2647 "{}{}",
2648 l.to_display_string(),
2649 r.to_display_string()
2650 )))
2651 }
2652 other => Err(SQLRiteError::NotImplemented(format!(
2653 "binary operator {other:?} is not supported yet"
2654 ))),
2655 },
2656
2657 Expr::IsNull(inner) => {
2665 let v = eval_expr_scope(inner, scope)?;
2666 Ok(Value::Bool(matches!(v, Value::Null)))
2667 }
2668 Expr::IsNotNull(inner) => {
2669 let v = eval_expr_scope(inner, scope)?;
2670 Ok(Value::Bool(!matches!(v, Value::Null)))
2671 }
2672
2673 Expr::Like {
2680 negated,
2681 any,
2682 expr: lhs,
2683 pattern,
2684 escape_char,
2685 } => eval_like(
2686 scope,
2687 *negated,
2688 *any,
2689 lhs,
2690 pattern,
2691 escape_char.as_ref(),
2692 true,
2693 ),
2694 Expr::ILike {
2695 negated,
2696 any,
2697 expr: lhs,
2698 pattern,
2699 escape_char,
2700 } => eval_like(
2701 scope,
2702 *negated,
2703 *any,
2704 lhs,
2705 pattern,
2706 escape_char.as_ref(),
2707 true,
2708 ),
2709
2710 Expr::InList {
2716 expr: lhs,
2717 list,
2718 negated,
2719 } => eval_in_list(scope, lhs, list, *negated),
2720 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2721 "IN (subquery) is not supported (only literal lists are)".to_string(),
2722 )),
2723
2724 Expr::Function(func) => eval_function(func, scope),
2735
2736 other => Err(SQLRiteError::NotImplemented(format!(
2737 "unsupported expression in WHERE/projection: {other:?}"
2738 ))),
2739 }
2740}
2741
2742fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2747 let name = match func.name.0.as_slice() {
2750 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2751 _ => {
2752 return Err(SQLRiteError::NotImplemented(format!(
2753 "qualified function names not supported: {:?}",
2754 func.name
2755 )));
2756 }
2757 };
2758
2759 match name.as_str() {
2760 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2761 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2762 let dist = match name.as_str() {
2763 "vec_distance_l2" => vec_distance_l2(&a, &b),
2764 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2765 "vec_distance_dot" => vec_distance_dot(&a, &b),
2766 _ => unreachable!(),
2767 };
2768 Ok(Value::Real(dist as f64))
2774 }
2775 "json_extract" => json_fn_extract(&name, &func.args, scope),
2780 "json_type" => json_fn_type(&name, &func.args, scope),
2781 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2782 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2783 "fts_match" | "bm25_score" => {
2794 let Some((table, rowid)) = scope.single_table_view() else {
2795 return Err(SQLRiteError::NotImplemented(format!(
2796 "{name}() is not yet supported inside a JOIN query — \
2797 use it on a single-table SELECT or move the FTS lookup into a subquery"
2798 )));
2799 };
2800 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2801 Ok(match name.as_str() {
2802 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2803 "bm25_score" => {
2804 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2805 }
2806 _ => unreachable!(),
2807 })
2808 }
2809 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2813 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2814 use it as a top-level projection item or in HAVING"
2815 ))),
2816 other => Err(SQLRiteError::NotImplemented(format!(
2817 "unknown function: {other}(...)"
2818 ))),
2819 }
2820}
2821
2822fn resolve_fts_args<'t>(
2827 fn_name: &str,
2828 args: &FunctionArguments,
2829 table: &'t Table,
2830 scope: &dyn RowScope,
2831) -> Result<(&'t FtsIndexEntry, String)> {
2832 let arg_list = match args {
2833 FunctionArguments::List(l) => &l.args,
2834 _ => {
2835 return Err(SQLRiteError::General(format!(
2836 "{fn_name}() expects exactly two arguments: (column, query_text)"
2837 )));
2838 }
2839 };
2840 if arg_list.len() != 2 {
2841 return Err(SQLRiteError::General(format!(
2842 "{fn_name}() expects exactly 2 arguments, got {}",
2843 arg_list.len()
2844 )));
2845 }
2846
2847 let col_expr = match &arg_list[0] {
2851 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2852 other => {
2853 return Err(SQLRiteError::NotImplemented(format!(
2854 "{fn_name}() argument 0 must be a column name, got {other:?}"
2855 )));
2856 }
2857 };
2858 let col_name = match col_expr {
2859 Expr::Identifier(ident) => ident.value.clone(),
2860 Expr::CompoundIdentifier(parts) => match parts.as_slice() {
2865 [only] => only.value.clone(),
2866 [q, c] => {
2867 if let Some(scope_name) = scope.scope_name() {
2868 check_single_scope_qualifier(Some(&q.value), scope_name, &c.value)?;
2869 }
2870 c.value.clone()
2871 }
2872 _ => {
2873 return Err(SQLRiteError::NotImplemented(format!(
2874 "compound identifier with {} parts is not supported",
2875 parts.len()
2876 )));
2877 }
2878 },
2879 other => {
2880 return Err(SQLRiteError::General(format!(
2881 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2882 )));
2883 }
2884 };
2885
2886 let q_expr = match &arg_list[1] {
2890 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2891 other => {
2892 return Err(SQLRiteError::NotImplemented(format!(
2893 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2894 )));
2895 }
2896 };
2897 let query = match eval_expr_scope(q_expr, scope)? {
2898 Value::Text(s) => s,
2899 other => {
2900 return Err(SQLRiteError::General(format!(
2901 "{fn_name}() argument 1 must be TEXT, got {}",
2902 other.to_display_string()
2903 )));
2904 }
2905 };
2906
2907 let entry = table
2908 .fts_indexes
2909 .iter()
2910 .find(|e| e.column_name == col_name)
2911 .ok_or_else(|| {
2912 SQLRiteError::General(format!(
2913 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2914 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2915 ))
2916 })?;
2917 Ok((entry, query))
2918}
2919
2920fn extract_json_and_path(
2934 fn_name: &str,
2935 args: &FunctionArguments,
2936 scope: &dyn RowScope,
2937) -> Result<(String, String)> {
2938 let arg_list = match args {
2939 FunctionArguments::List(l) => &l.args,
2940 _ => {
2941 return Err(SQLRiteError::General(format!(
2942 "{fn_name}() expects 1 or 2 arguments"
2943 )));
2944 }
2945 };
2946 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2947 return Err(SQLRiteError::General(format!(
2948 "{fn_name}() expects 1 or 2 arguments, got {}",
2949 arg_list.len()
2950 )));
2951 }
2952 let first_expr = match &arg_list[0] {
2954 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2955 other => {
2956 return Err(SQLRiteError::NotImplemented(format!(
2957 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2958 )));
2959 }
2960 };
2961 let json_text = match eval_expr_scope(first_expr, scope)? {
2962 Value::Text(s) => s,
2963 Value::Null => {
2964 return Err(SQLRiteError::General(format!(
2965 "{fn_name}() called on NULL — JSON column has no value for this row"
2966 )));
2967 }
2968 other => {
2969 return Err(SQLRiteError::General(format!(
2970 "{fn_name}() argument 0 is not JSON-typed: got {}",
2971 other.to_display_string()
2972 )));
2973 }
2974 };
2975
2976 let path = if arg_list.len() == 2 {
2978 let path_expr = match &arg_list[1] {
2979 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2980 other => {
2981 return Err(SQLRiteError::NotImplemented(format!(
2982 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2983 )));
2984 }
2985 };
2986 match eval_expr_scope(path_expr, scope)? {
2987 Value::Text(s) => s,
2988 other => {
2989 return Err(SQLRiteError::General(format!(
2990 "{fn_name}() path argument must be a string literal, got {}",
2991 other.to_display_string()
2992 )));
2993 }
2994 }
2995 } else {
2996 "$".to_string()
2997 };
2998
2999 Ok((json_text, path))
3000}
3001
3002fn walk_json_path<'a>(
3012 value: &'a serde_json::Value,
3013 path: &str,
3014) -> Result<Option<&'a serde_json::Value>> {
3015 let mut chars = path.chars().peekable();
3016 if chars.next() != Some('$') {
3017 return Err(SQLRiteError::General(format!(
3018 "JSON path must start with '$', got `{path}`"
3019 )));
3020 }
3021 let mut current = value;
3022 while let Some(&c) = chars.peek() {
3023 match c {
3024 '.' => {
3025 chars.next();
3026 let mut key = String::new();
3027 while let Some(&c) = chars.peek() {
3028 if c == '.' || c == '[' {
3029 break;
3030 }
3031 key.push(c);
3032 chars.next();
3033 }
3034 if key.is_empty() {
3035 return Err(SQLRiteError::General(format!(
3036 "JSON path has empty key after '.' in `{path}`"
3037 )));
3038 }
3039 match current.get(&key) {
3040 Some(v) => current = v,
3041 None => return Ok(None),
3042 }
3043 }
3044 '[' => {
3045 chars.next();
3046 let mut idx_str = String::new();
3047 while let Some(&c) = chars.peek() {
3048 if c == ']' {
3049 break;
3050 }
3051 idx_str.push(c);
3052 chars.next();
3053 }
3054 if chars.next() != Some(']') {
3055 return Err(SQLRiteError::General(format!(
3056 "JSON path has unclosed `[` in `{path}`"
3057 )));
3058 }
3059 let idx: usize = idx_str.trim().parse().map_err(|_| {
3060 SQLRiteError::General(format!(
3061 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
3062 ))
3063 })?;
3064 match current.get(idx) {
3065 Some(v) => current = v,
3066 None => return Ok(None),
3067 }
3068 }
3069 other => {
3070 return Err(SQLRiteError::General(format!(
3071 "JSON path has unexpected character `{other}` in `{path}` \
3072 (expected `.`, `[`, or end-of-path)"
3073 )));
3074 }
3075 }
3076 }
3077 Ok(Some(current))
3078}
3079
3080fn json_value_to_sql(v: &serde_json::Value) -> Value {
3084 match v {
3085 serde_json::Value::Null => Value::Null,
3086 serde_json::Value::Bool(b) => Value::Bool(*b),
3087 serde_json::Value::Number(n) => {
3088 if let Some(i) = n.as_i64() {
3090 Value::Integer(i)
3091 } else if let Some(f) = n.as_f64() {
3092 Value::Real(f)
3093 } else {
3094 Value::Null
3095 }
3096 }
3097 serde_json::Value::String(s) => Value::Text(s.clone()),
3098 composite => Value::Text(composite.to_string()),
3102 }
3103}
3104
3105fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
3106 let (json_text, path) = extract_json_and_path(name, args, scope)?;
3107 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
3108 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
3109 })?;
3110 match walk_json_path(&parsed, &path)? {
3111 Some(v) => Ok(json_value_to_sql(v)),
3112 None => Ok(Value::Null),
3113 }
3114}
3115
3116fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
3117 let (json_text, path) = extract_json_and_path(name, args, scope)?;
3118 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
3119 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
3120 })?;
3121 let resolved = match walk_json_path(&parsed, &path)? {
3122 Some(v) => v,
3123 None => return Ok(Value::Null),
3124 };
3125 let ty = match resolved {
3126 serde_json::Value::Null => "null",
3127 serde_json::Value::Bool(true) => "true",
3128 serde_json::Value::Bool(false) => "false",
3129 serde_json::Value::Number(n) => {
3130 if n.is_i64() || n.is_u64() {
3131 "integer"
3132 } else {
3133 "real"
3134 }
3135 }
3136 serde_json::Value::String(_) => "text",
3137 serde_json::Value::Array(_) => "array",
3138 serde_json::Value::Object(_) => "object",
3139 };
3140 Ok(Value::Text(ty.to_string()))
3141}
3142
3143fn json_fn_array_length(
3144 name: &str,
3145 args: &FunctionArguments,
3146 scope: &dyn RowScope,
3147) -> Result<Value> {
3148 let (json_text, path) = extract_json_and_path(name, args, scope)?;
3149 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
3150 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
3151 })?;
3152 let resolved = match walk_json_path(&parsed, &path)? {
3153 Some(v) => v,
3154 None => return Ok(Value::Null),
3155 };
3156 match resolved.as_array() {
3157 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
3158 None => Err(SQLRiteError::General(format!(
3159 "{name}() resolved to a non-array value at path `{path}`"
3160 ))),
3161 }
3162}
3163
3164fn json_fn_object_keys(
3165 name: &str,
3166 args: &FunctionArguments,
3167 scope: &dyn RowScope,
3168) -> Result<Value> {
3169 let (json_text, path) = extract_json_and_path(name, args, scope)?;
3170 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
3171 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
3172 })?;
3173 let resolved = match walk_json_path(&parsed, &path)? {
3174 Some(v) => v,
3175 None => return Ok(Value::Null),
3176 };
3177 let obj = resolved.as_object().ok_or_else(|| {
3178 SQLRiteError::General(format!(
3179 "{name}() resolved to a non-object value at path `{path}`"
3180 ))
3181 })?;
3182 let keys: Vec<serde_json::Value> = obj
3189 .keys()
3190 .map(|k| serde_json::Value::String(k.clone()))
3191 .collect();
3192 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
3193}
3194
3195fn extract_two_vector_args(
3199 fn_name: &str,
3200 args: &FunctionArguments,
3201 scope: &dyn RowScope,
3202) -> Result<(Vec<f32>, Vec<f32>)> {
3203 let arg_list = match args {
3204 FunctionArguments::List(l) => &l.args,
3205 _ => {
3206 return Err(SQLRiteError::General(format!(
3207 "{fn_name}() expects exactly two vector arguments"
3208 )));
3209 }
3210 };
3211 if arg_list.len() != 2 {
3212 return Err(SQLRiteError::General(format!(
3213 "{fn_name}() expects exactly 2 arguments, got {}",
3214 arg_list.len()
3215 )));
3216 }
3217 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
3218 for (i, arg) in arg_list.iter().enumerate() {
3219 let expr = match arg {
3220 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
3221 other => {
3222 return Err(SQLRiteError::NotImplemented(format!(
3223 "{fn_name}() argument {i} has unsupported shape: {other:?}"
3224 )));
3225 }
3226 };
3227 let val = eval_expr_scope(expr, scope)?;
3228 match val {
3229 Value::Vector(v) => out.push(v),
3230 other => {
3231 return Err(SQLRiteError::General(format!(
3232 "{fn_name}() argument {i} is not a vector: got {}",
3233 other.to_display_string()
3234 )));
3235 }
3236 }
3237 }
3238 let b = out.pop().unwrap();
3239 let a = out.pop().unwrap();
3240 if a.len() != b.len() {
3241 return Err(SQLRiteError::General(format!(
3242 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
3243 a.len(),
3244 b.len()
3245 )));
3246 }
3247 Ok((a, b))
3248}
3249
/// Euclidean (L2) distance between two vectors of equal length.
///
/// The lengths are checked with a debug assertion only; callers are expected
/// to have validated the dimensions.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` components of `b` take part.
    let b = &b[..a.len()];
    let squared_sum = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let delta = x - y;
        acc + delta * delta
    });
    squared_sum.sqrt()
}
3261
3262pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
3272 debug_assert_eq!(a.len(), b.len());
3273 let mut dot = 0.0f32;
3274 let mut norm_a_sq = 0.0f32;
3275 let mut norm_b_sq = 0.0f32;
3276 for i in 0..a.len() {
3277 dot += a[i] * b[i];
3278 norm_a_sq += a[i] * a[i];
3279 norm_b_sq += b[i] * b[i];
3280 }
3281 let denom = (norm_a_sq * norm_b_sq).sqrt();
3282 if denom == 0.0 {
3283 return Err(SQLRiteError::General(
3284 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
3285 ));
3286 }
3287 Ok(1.0 - dot / denom)
3288}
3289
/// Negated dot product of two vectors of equal length.
///
/// The sign is flipped so that a larger dot product gives a smaller
/// "distance" value.
///
/// The lengths are checked with a debug assertion only; callers are expected
/// to have validated the dimensions.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` components of `b` take part.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
3301
3302fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
3305 if matches!(l, Value::Null) || matches!(r, Value::Null) {
3306 return Ok(Value::Null);
3307 }
3308 match (l, r) {
3309 (Value::Integer(a), Value::Integer(b)) => match op {
3310 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
3311 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
3312 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
3313 BinaryOperator::Divide => {
3314 if *b == 0 {
3315 Err(SQLRiteError::General("division by zero".to_string()))
3316 } else {
3317 Ok(Value::Integer(a / b))
3318 }
3319 }
3320 BinaryOperator::Modulo => {
3321 if *b == 0 {
3322 Err(SQLRiteError::General("modulo by zero".to_string()))
3323 } else {
3324 Ok(Value::Integer(a % b))
3325 }
3326 }
3327 _ => unreachable!(),
3328 },
3329 (a, b) => {
3331 let af = as_number(a)?;
3332 let bf = as_number(b)?;
3333 match op {
3334 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
3335 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
3336 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
3337 BinaryOperator::Divide => {
3338 if bf == 0.0 {
3339 Err(SQLRiteError::General("division by zero".to_string()))
3340 } else {
3341 Ok(Value::Real(af / bf))
3342 }
3343 }
3344 BinaryOperator::Modulo => {
3345 if bf == 0.0 {
3346 Err(SQLRiteError::General("modulo by zero".to_string()))
3347 } else {
3348 Ok(Value::Real(af % bf))
3349 }
3350 }
3351 _ => unreachable!(),
3352 }
3353 }
3354 }
3355}
3356
3357fn as_number(v: &Value) -> Result<f64> {
3358 match v {
3359 Value::Integer(i) => Ok(*i as f64),
3360 Value::Real(f) => Ok(*f),
3361 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
3362 other => Err(SQLRiteError::General(format!(
3363 "arithmetic on non-numeric value '{}'",
3364 other.to_display_string()
3365 ))),
3366 }
3367}
3368
3369fn as_bool(v: &Value) -> Result<bool> {
3370 match v {
3371 Value::Bool(b) => Ok(*b),
3372 Value::Null => Ok(false),
3373 Value::Integer(i) => Ok(*i != 0),
3374 other => Err(SQLRiteError::Internal(format!(
3375 "expected boolean, got {}",
3376 other.to_display_string()
3377 ))),
3378 }
3379}
3380
3381#[allow(clippy::too_many_arguments)]
3386fn eval_like(
3387 scope: &dyn RowScope,
3388 negated: bool,
3389 any: bool,
3390 lhs: &Expr,
3391 pattern: &Expr,
3392 escape_char: Option<&AstValue>,
3393 case_insensitive: bool,
3394) -> Result<Value> {
3395 if any {
3396 return Err(SQLRiteError::NotImplemented(
3397 "LIKE ANY (...) is not supported".to_string(),
3398 ));
3399 }
3400 if escape_char.is_some() {
3401 return Err(SQLRiteError::NotImplemented(
3402 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3403 ));
3404 }
3405
3406 let l = eval_expr_scope(lhs, scope)?;
3407 let p = eval_expr_scope(pattern, scope)?;
3408 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3409 return Ok(Value::Null);
3410 }
3411 let text = match l {
3412 Value::Text(s) => s,
3413 other => other.to_display_string(),
3414 };
3415 let pat = match p {
3416 Value::Text(s) => s,
3417 other => other.to_display_string(),
3418 };
3419 let m = like_match(&text, &pat, case_insensitive);
3420 Ok(Value::Bool(if negated { !m } else { m }))
3421}
3422
3423fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3424 let l = eval_expr_scope(lhs, scope)?;
3425 if matches!(l, Value::Null) {
3426 return Ok(Value::Null);
3427 }
3428 let mut saw_null = false;
3429 for item in list {
3430 let r = eval_expr_scope(item, scope)?;
3431 if matches!(r, Value::Null) {
3432 saw_null = true;
3433 continue;
3434 }
3435 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3436 return Ok(Value::Bool(!negated));
3437 }
3438 }
3439 if saw_null {
3440 Ok(Value::Null)
3443 } else {
3444 Ok(Value::Bool(negated))
3445 }
3446}
3447
3448fn lower_having_into_hidden_slots(
3465 query: &SelectQuery,
3466 proj_items: &[ProjectionItem],
3467) -> Result<(Vec<ProjectionItem>, Option<Expr>)> {
3468 let mut all_items = proj_items.to_vec();
3469 let having_expr = match &query.having {
3470 Some(h) => {
3471 for g in &query.group_by {
3472 if !all_items
3473 .iter()
3474 .any(|i| i.output_name().eq_ignore_ascii_case(&g.name))
3475 {
3476 all_items.push(ProjectionItem {
3477 kind: ProjectionKind::Column {
3478 qualifier: g.qualifier.clone(),
3479 name: g.name.clone(),
3480 },
3481 alias: None,
3482 });
3483 }
3484 }
3485 Some(lower_having_expr(h, &mut all_items)?)
3486 }
3487 None => None,
3488 };
3489 Ok((all_items, having_expr))
3490}
3491
3492fn run_aggregation_pipeline<S: RowScope>(
3498 scopes: impl IntoIterator<Item = S>,
3499 query: &SelectQuery,
3500 proj_items: &[ProjectionItem],
3501 all_items: &[ProjectionItem],
3502 having_expr: &Option<Expr>,
3503) -> Result<SelectResult> {
3504 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
3505 let mut rows = aggregate_rows(scopes, &query.group_by, all_items)?;
3506
3507 if let Some(h) = having_expr {
3508 let all_columns: Vec<String> = all_items.iter().map(|i| i.output_name()).collect();
3509 rows = filter_groups_by_having(rows, h, &all_columns)?;
3510 }
3511 if all_items.len() > proj_items.len() {
3513 for row in &mut rows {
3514 row.truncate(proj_items.len());
3515 }
3516 }
3517
3518 if query.distinct {
3519 rows = dedupe_rows(rows);
3520 }
3521
3522 if let Some(order) = &query.order_by {
3523 sort_output_rows(&mut rows, &columns, proj_items, order)?;
3524 }
3525 if let Some(k) = query.limit {
3526 rows.truncate(k);
3527 }
3528
3529 Ok(SelectResult { columns, rows })
3530}
3531
3532fn aggregate_rows<S: RowScope>(
3546 scopes: impl IntoIterator<Item = S>,
3547 group_by: &[GroupByKey],
3548 proj_items: &[ProjectionItem],
3549) -> Result<Vec<Vec<Value>>> {
3550 let template: Vec<Option<AggState>> = proj_items
3554 .iter()
3555 .map(|i| match &i.kind {
3556 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3557 ProjectionKind::Column { .. } => None,
3558 })
3559 .collect();
3560
3561 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3567 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3568 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3569
3570 for scope in scopes {
3571 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3572 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3573 for g in group_by {
3574 let v = scope.lookup(g.qualifier.as_deref(), &g.name)?;
3575 key.push(DistinctKey::from_value(&v));
3576 key_values.push(v);
3577 }
3578 let idx = match keys.iter().position(|k| k == &key) {
3579 Some(i) => i,
3580 None => {
3581 keys.push(key);
3582 group_states.push(template.clone());
3583 group_key_values.push(key_values);
3584 keys.len() - 1
3585 }
3586 };
3587
3588 for (slot, item) in proj_items.iter().enumerate() {
3589 if let ProjectionKind::Aggregate(call) = &item.kind {
3590 let v = match &call.arg {
3591 AggregateArg::Star => Value::Null,
3592 AggregateArg::Column { qualifier, name } => {
3593 scope.lookup(qualifier.as_deref(), name)?
3594 }
3595 };
3596 if let Some(state) = group_states[idx][slot].as_mut() {
3597 state.update(&v)?;
3598 }
3599 }
3600 }
3601 }
3602
3603 if keys.is_empty() && group_by.is_empty() {
3609 keys.push(Vec::new());
3612 group_states.push(template.clone());
3613 group_key_values.push(Vec::new());
3614 }
3615
3616 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3618 for (group_idx, _) in keys.iter().enumerate() {
3619 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3620 for (slot, item) in proj_items.iter().enumerate() {
3621 match &item.kind {
3622 ProjectionKind::Column { qualifier, name: c } => {
3623 let pos = group_by
3628 .iter()
3629 .position(|g| g.matches_column(qualifier.as_deref(), c))
3630 .ok_or_else(|| {
3631 SQLRiteError::Internal(format!(
3632 "column '{c}' must appear in GROUP BY or be used in an \
3633 aggregate function"
3634 ))
3635 })?;
3636 row.push(group_key_values[group_idx][pos].clone());
3637 }
3638 ProjectionKind::Aggregate(_) => {
3639 let state = group_states[group_idx][slot]
3640 .as_ref()
3641 .expect("aggregate slot has state");
3642 row.push(state.finalize());
3643 }
3644 }
3645 }
3646 rows.push(row);
3647 }
3648 Ok(rows)
3649}
3650
3651struct GroupRowScope<'a> {
3661 columns: &'a [String],
3662 values: &'a [Value],
3663}
3664
3665impl RowScope for GroupRowScope<'_> {
3666 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
3667 let _ = qualifier;
3670 self.columns
3671 .iter()
3672 .position(|c| c.eq_ignore_ascii_case(col))
3673 .map(|i| self.values[i].clone())
3674 .ok_or_else(|| {
3675 SQLRiteError::Internal(format!(
3676 "HAVING references '{col}', which is neither a GROUP BY column nor an \
3677 aggregate in scope"
3678 ))
3679 })
3680 }
3681
3682 fn single_table_view(&self) -> Option<(&Table, i64)> {
3683 None
3684 }
3685
3686 fn scope_name(&self) -> Option<&str> {
3687 None
3690 }
3691}
3692
3693fn lower_having_expr(expr: &Expr, items: &mut Vec<ProjectionItem>) -> Result<Expr> {
3701 Ok(match expr {
3702 Expr::Function(func) => {
3703 let is_aggregate = matches!(
3704 func.name.0.as_slice(),
3705 [ObjectNamePart::Identifier(ident)] if AggregateFn::from_name(&ident.value).is_some()
3706 );
3707 if !is_aggregate {
3708 return Ok(expr.clone());
3709 }
3710 let call = parse_aggregate_call(func)?;
3711 let display = call.display_name();
3712 let already_known = items
3718 .iter()
3719 .any(|i| i.output_name().eq_ignore_ascii_case(&display));
3720 if !already_known {
3721 items.push(ProjectionItem {
3722 kind: ProjectionKind::Aggregate(call),
3723 alias: None,
3724 });
3725 }
3726 Expr::Identifier(Ident::new(display))
3727 }
3728 Expr::Nested(inner) => Expr::Nested(Box::new(lower_having_expr(inner, items)?)),
3729 Expr::UnaryOp { op, expr: inner } => Expr::UnaryOp {
3730 op: *op,
3731 expr: Box::new(lower_having_expr(inner, items)?),
3732 },
3733 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
3734 left: Box::new(lower_having_expr(left, items)?),
3735 op: op.clone(),
3736 right: Box::new(lower_having_expr(right, items)?),
3737 },
3738 Expr::IsNull(inner) => Expr::IsNull(Box::new(lower_having_expr(inner, items)?)),
3739 Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(lower_having_expr(inner, items)?)),
3740 Expr::InList {
3741 expr: lhs,
3742 list,
3743 negated,
3744 } => Expr::InList {
3745 expr: Box::new(lower_having_expr(lhs, items)?),
3746 list: list
3747 .iter()
3748 .map(|e| lower_having_expr(e, items))
3749 .collect::<Result<Vec<_>>>()?,
3750 negated: *negated,
3751 },
3752 Expr::Like {
3753 negated,
3754 any,
3755 expr: lhs,
3756 pattern,
3757 escape_char,
3758 } => Expr::Like {
3759 negated: *negated,
3760 any: *any,
3761 expr: Box::new(lower_having_expr(lhs, items)?),
3762 pattern: Box::new(lower_having_expr(pattern, items)?),
3763 escape_char: escape_char.clone(),
3764 },
3765 Expr::ILike {
3766 negated,
3767 any,
3768 expr: lhs,
3769 pattern,
3770 escape_char,
3771 } => Expr::ILike {
3772 negated: *negated,
3773 any: *any,
3774 expr: Box::new(lower_having_expr(lhs, items)?),
3775 pattern: Box::new(lower_having_expr(pattern, items)?),
3776 escape_char: escape_char.clone(),
3777 },
3778 other => other.clone(),
3781 })
3782}
3783
3784fn filter_groups_by_having(
3788 rows: Vec<Vec<Value>>,
3789 having: &Expr,
3790 columns: &[String],
3791) -> Result<Vec<Vec<Value>>> {
3792 let mut out = Vec::with_capacity(rows.len());
3793 for row in rows {
3794 let scope = GroupRowScope {
3795 columns,
3796 values: &row,
3797 };
3798 let keep = match eval_expr_scope(having, &scope)? {
3799 Value::Bool(b) => b,
3800 Value::Null => false,
3801 Value::Integer(i) => i != 0,
3802 other => {
3803 return Err(SQLRiteError::Internal(format!(
3804 "HAVING clause must evaluate to boolean, got {}",
3805 other.to_display_string()
3806 )));
3807 }
3808 };
3809 if keep {
3810 out.push(row);
3811 }
3812 }
3813 Ok(out)
3814}
3815
3816fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3820 use std::collections::HashSet;
3821 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3822 let mut out = Vec::with_capacity(rows.len());
3823 for row in rows {
3824 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3825 if seen.insert(key) {
3826 out.push(row);
3827 }
3828 }
3829 out
3830}
3831
3832fn sort_output_rows(
3836 rows: &mut [Vec<Value>],
3837 columns: &[String],
3838 proj_items: &[ProjectionItem],
3839 order: &OrderByClause,
3840) -> Result<()> {
3841 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3842 rows.sort_by(|a, b| {
3843 let va = &a[target_idx];
3844 let vb = &b[target_idx];
3845 let ord = compare_values(Some(va), Some(vb));
3846 if order.ascending { ord } else { ord.reverse() }
3847 });
3848 Ok(())
3849}
3850
3851fn resolve_order_by_index(
3854 expr: &Expr,
3855 columns: &[String],
3856 proj_items: &[ProjectionItem],
3857) -> Result<usize> {
3858 let target_name: Option<String> = match expr {
3860 Expr::Identifier(ident) => Some(ident.value.clone()),
3861 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3862 Expr::Function(_) => None,
3863 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3864 other => {
3865 return Err(SQLRiteError::NotImplemented(format!(
3866 "ORDER BY expression not supported on aggregating queries: {other:?}"
3867 )));
3868 }
3869 };
3870 if let Some(name) = target_name {
3871 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3872 return Ok(i);
3873 }
3874 return Err(SQLRiteError::Internal(format!(
3875 "ORDER BY references unknown column '{name}' in the SELECT output"
3876 )));
3877 }
3878 if let Expr::Function(func) = expr {
3886 let user_disp = format_function_display(func, true);
3887 for (i, item) in proj_items.iter().enumerate() {
3888 if let ProjectionKind::Aggregate(call) = &item.kind
3889 && call.display_name().eq_ignore_ascii_case(&user_disp)
3890 {
3891 return Ok(i);
3892 }
3893 }
3894 let user_disp_unqualified = format_function_display(func, false);
3895 for (i, item) in proj_items.iter().enumerate() {
3896 if let ProjectionKind::Aggregate(call) = &item.kind
3897 && call
3898 .display_name_unqualified()
3899 .eq_ignore_ascii_case(&user_disp_unqualified)
3900 {
3901 return Ok(i);
3902 }
3903 }
3904 return Err(SQLRiteError::Internal(format!(
3905 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3906 )));
3907 }
3908 Err(SQLRiteError::Internal(
3909 "ORDER BY expression could not be resolved against the output columns".to_string(),
3910 ))
3911}
3912
3913fn format_function_display(func: &sqlparser::ast::Function, qualified: bool) -> String {
3919 let name = match func.name.0.as_slice() {
3920 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3921 _ => format!("{:?}", func.name).to_uppercase(),
3922 };
3923 let inner = match &func.args {
3924 FunctionArguments::List(l) => {
3925 let distinct = matches!(
3926 l.duplicate_treatment,
3927 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3928 );
3929 let arg = l.args.first().map(|a| match a {
3930 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3931 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3932 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3933 if qualified {
3934 parts
3935 .iter()
3936 .map(|p| p.value.clone())
3937 .collect::<Vec<_>>()
3938 .join(".")
3939 } else {
3940 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3941 }
3942 }
3943 _ => String::new(),
3944 });
3945 match (distinct, arg) {
3946 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3947 (_, Some(a)) => a,
3948 _ => String::new(),
3949 }
3950 }
3951 _ => String::new(),
3952 };
3953 format!("{name}({inner})")
3954}
3955
3956fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3957 use sqlparser::ast::Value as AstValue;
3958 match v {
3959 AstValue::Number(n, _) => {
3960 if let Ok(i) = n.parse::<i64>() {
3961 Ok(Value::Integer(i))
3962 } else if let Ok(f) = n.parse::<f64>() {
3963 Ok(Value::Real(f))
3964 } else {
3965 Err(SQLRiteError::Internal(format!(
3966 "could not parse numeric literal '{n}'"
3967 )))
3968 }
3969 }
3970 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3971 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3972 AstValue::Null => Ok(Value::Null),
3973 other => Err(SQLRiteError::NotImplemented(format!(
3974 "unsupported literal value: {other:?}"
3975 ))),
3976 }
3977}
3978
3979#[cfg(test)]
3980mod tests {
3981 use super::*;
3982
3983 fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
3990 (a - b).abs() < eps
3991 }
3992
3993 #[test]
3994 fn vec_distance_l2_identical_is_zero() {
3995 let v = vec![0.1, 0.2, 0.3];
3996 assert_eq!(vec_distance_l2(&v, &v), 0.0);
3997 }
3998
3999 #[test]
4000 fn vec_distance_l2_unit_basis_is_sqrt2() {
4001 let a = vec![1.0, 0.0];
4003 let b = vec![0.0, 1.0];
4004 assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
4005 }
4006
4007 #[test]
4008 fn vec_distance_l2_known_value() {
4009 let a = vec![0.0, 0.0, 0.0];
4011 let b = vec![3.0, 4.0, 0.0];
4012 assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
4013 }
4014
4015 #[test]
4016 fn vec_distance_cosine_identical_is_zero() {
4017 let v = vec![0.1, 0.2, 0.3];
4018 let d = vec_distance_cosine(&v, &v).unwrap();
4019 assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
4020 }
4021
4022 #[test]
4023 fn vec_distance_cosine_orthogonal_is_one() {
4024 let a = vec![1.0, 0.0];
4027 let b = vec![0.0, 1.0];
4028 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
4029 }
4030
4031 #[test]
4032 fn vec_distance_cosine_opposite_is_two() {
4033 let a = vec![1.0, 0.0, 0.0];
4035 let b = vec![-1.0, 0.0, 0.0];
4036 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
4037 }
4038
4039 #[test]
4040 fn vec_distance_cosine_zero_magnitude_errors() {
4041 let a = vec![0.0, 0.0];
4043 let b = vec![1.0, 0.0];
4044 let err = vec_distance_cosine(&a, &b).unwrap_err();
4045 assert!(format!("{err}").contains("zero-magnitude"));
4046 }
4047
4048 #[test]
4049 fn vec_distance_dot_negates() {
4050 let a = vec![1.0, 2.0, 3.0];
4052 let b = vec![4.0, 5.0, 6.0];
4053 assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
4054 }
4055
4056 #[test]
4057 fn vec_distance_dot_orthogonal_is_zero() {
4058 let a = vec![1.0, 0.0];
4060 let b = vec![0.0, 1.0];
4061 assert_eq!(vec_distance_dot(&a, &b), 0.0);
4062 }
4063
4064 #[test]
4065 fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
4066 let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
4072 let cos = vec_distance_cosine(&a, &b).unwrap();
4073 assert!(approx_eq(dot, cos - 1.0, 1e-5));
4074 }
4075
4076 use crate::sql::db::database::Database;
4081 use crate::sql::dialect::SqlriteDialect;
4082 use crate::sql::parser::select::SelectQuery;
4083 use sqlparser::parser::Parser;
4084
4085 fn seed_score_table(n: usize) -> Database {
4098 let mut db = Database::new("tempdb".to_string());
4099 crate::sql::process_command(
4100 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
4101 &mut db,
4102 )
4103 .expect("create");
4104 for i in 0..n {
4105 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
4109 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
4110 crate::sql::process_command(&sql, &mut db).expect("insert");
4111 }
4112 db
4113 }
4114
4115 fn parse_select(sql: &str) -> SelectQuery {
4119 let dialect = SqlriteDialect::new();
4120 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
4121 let stmt = ast.pop().expect("one statement");
4122 SelectQuery::new(&stmt).expect("select-query")
4123 }
4124
4125 #[test]
4126 fn topk_matches_full_sort_asc() {
4127 let db = seed_score_table(200);
4130 let table = db.get_table("docs".to_string()).unwrap();
4131 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
4132 let order = q.order_by.as_ref().unwrap();
4133 let all_rowids = table.rowids();
4134
4135 let mut full = all_rowids.clone();
4137 sort_rowids(&mut full, table, order, "docs").unwrap();
4138 full.truncate(10);
4139
4140 let topk = select_topk(&all_rowids, table, order, 10, "docs").unwrap();
4142
4143 assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
4144 }
4145
4146 #[test]
4147 fn topk_matches_full_sort_desc() {
4148 let db = seed_score_table(200);
4150 let table = db.get_table("docs".to_string()).unwrap();
4151 let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
4152 let order = q.order_by.as_ref().unwrap();
4153 let all_rowids = table.rowids();
4154
4155 let mut full = all_rowids.clone();
4156 sort_rowids(&mut full, table, order, "docs").unwrap();
4157 full.truncate(10);
4158
4159 let topk = select_topk(&all_rowids, table, order, 10, "docs").unwrap();
4160
4161 assert_eq!(
4162 topk, full,
4163 "top-k DESC via heap should match full-sort+truncate"
4164 );
4165 }
4166
4167 #[test]
4168 fn topk_k_larger_than_n_returns_everything_sorted() {
4169 let db = seed_score_table(50);
4174 let table = db.get_table("docs".to_string()).unwrap();
4175 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
4176 let order = q.order_by.as_ref().unwrap();
4177 let topk = select_topk(&table.rowids(), table, order, 1000, "docs").unwrap();
4178 assert_eq!(topk.len(), 50);
4179 let scores: Vec<f64> = topk
4181 .iter()
4182 .filter_map(|r| match table.get_value("score", *r) {
4183 Some(Value::Real(f)) => Some(f),
4184 _ => None,
4185 })
4186 .collect();
4187 assert!(scores.windows(2).all(|w| w[0] <= w[1]));
4188 }
4189
4190 #[test]
4191 fn topk_k_zero_returns_empty() {
4192 let db = seed_score_table(10);
4193 let table = db.get_table("docs".to_string()).unwrap();
4194 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
4195 let order = q.order_by.as_ref().unwrap();
4196 let topk = select_topk(&table.rowids(), table, order, 0, "docs").unwrap();
4197 assert!(topk.is_empty());
4198 }
4199
4200 #[test]
4201 fn topk_empty_input_returns_empty() {
4202 let db = seed_score_table(0);
4203 let table = db.get_table("docs".to_string()).unwrap();
4204 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
4205 let order = q.order_by.as_ref().unwrap();
4206 let topk = select_topk(&[], table, order, 5, "docs").unwrap();
4207 assert!(topk.is_empty());
4208 }
4209
/// End-to-end check: `ORDER BY vec_distance_l2(...) ASC LIMIT 3` issued
/// through `process_command` over five vector rows reports three rows.
#[test]
fn topk_works_through_select_executor_with_distance_function() {
    let mut db = Database::new("tempdb".to_string());
    crate::sql::process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .unwrap();

    let vectors = [
        "[1.0, 0.0]",
        "[2.0, 0.0]",
        "[0.0, 3.0]",
        "[1.0, 4.0]",
        "[10.0, 10.0]",
    ];
    for v in vectors {
        let insert = format!("INSERT INTO docs (e) VALUES ({v});");
        crate::sql::process_command(&insert, &mut db).unwrap();
    }

    let query = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;";
    let resp = crate::sql::process_command(query, &mut db).unwrap();
    assert!(resp.contains("3 rows returned"), "got: {resp}");
}
4246
/// Timing comparison of `select_topk` against `sort_rowids` + `truncate` on a
/// table of `N` rows with `K` results requested.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested
/// explicitly. The final assertion depends on wall-clock timings.
#[test]
#[ignore]
fn topk_benchmark() {
    use std::time::Instant;
    const N: usize = 10_000;
    const K: usize = 10;

    let db = seed_score_table(N);
    let table = db.get_table("docs".to_string()).unwrap();
    let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = q.order_by.as_ref().unwrap();
    let all_rowids = table.rowids();

    // First measurement: the top-k selection on its own.
    let t0 = Instant::now();
    let _topk = select_topk(&all_rowids, table, order, K, "docs").unwrap();
    let heap_dur = t0.elapsed();

    // Second measurement: clone + full sort + truncate. The clone is inside
    // the timed region.
    let t1 = Instant::now();
    let mut full = all_rowids.clone();
    sort_rowids(&mut full, table, order, "docs").unwrap();
    full.truncate(K);
    let sort_dur = t1.elapsed();

    // `.max(1e-9)` keeps the division finite if the first measurement rounds
    // down to zero seconds.
    let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
    println!("\n--- topk_benchmark (N={N}, k={K}) ---");
    println!(" bounded heap: {heap_dur:?}");
    println!(" full sort+trunc: {sort_dur:?}");
    println!(" speedup ratio: {ratio:.2}×");

    // Fails unless the full sort took more than 1.4x as long as `select_topk`.
    assert!(
        ratio > 1.4,
        "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
    );
}
4311
4312 fn run_select(db: &mut Database, sql: &str) -> String {
4320 crate::sql::process_command(sql, db).expect("select")
4321 }
4322
/// `WHERE n IS NULL` reports exactly the two rows inserted with a NULL `n`.
#[test]
fn where_is_null_returns_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
        "INSERT INTO t (id, n) VALUES (4, NULL);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
    let two_rows = response.contains("2 rows returned");
    assert!(two_rows, "IS NULL should return 2 rows, got: {response}");
}
4342
/// `WHERE n IS NOT NULL` reports the two rows that carry a value in `n`.
#[test]
fn where_is_not_null_returns_non_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
    let two_rows = response.contains("2 rows returned");
    assert!(two_rows, "IS NOT NULL should return 2 rows, got: {response}");
}
4361
/// IS NULL / IS NOT NULL on a `UNIQUE` text column: one NULL row and two
/// non-NULL rows must be reported respectively.
#[test]
fn where_is_null_on_indexed_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
        "INSERT INTO t (id, name) VALUES (1, 'alice');",
        "INSERT INTO t (id, name) VALUES (2, NULL);",
        "INSERT INTO t (id, name) VALUES (3, 'bob');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
    let one_null = null_rows.contains("1 row returned");
    assert!(
        one_null,
        "indexed IS NULL should return 1 row, got: {null_rows}"
    );

    let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
    let two_present = not_null_rows.contains("2 rows returned");
    assert!(
        two_present,
        "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
    );
}
4391
/// A column left out of the INSERT column list must be matched by `IS NULL`.
#[test]
fn where_is_null_works_on_omitted_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
        "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
        // Row 2 omits `qty` entirely.
        "INSERT INTO t (id, label) VALUES (2, 'b');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
    let one_row = response.contains("1 row returned");
    assert!(
        one_row,
        "IS NULL should match the omitted-column row, got: {response}"
    );
}
4417
4418 fn seed_sqlr2() -> Database {
4426 let mut db = Database::new("t".to_string());
4427 crate::sql::process_command(
4428 "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT);",
4429 &mut db,
4430 )
4431 .unwrap();
4432 crate::sql::process_command("INSERT INTO t (id, name) VALUES (1, 'alice');", &mut db)
4433 .unwrap();
4434 crate::sql::process_command("INSERT INTO t (id, name) VALUES (2, 'bob');", &mut db)
4435 .unwrap();
4436 db
4437 }
4438
/// A WHERE clause naming a column the table lacks must fail with a
/// "does not exist" error.
#[test]
fn where_unknown_column_errors_single_table() {
    let mut db = seed_sqlr2();
    let err = crate::sql::process_command("SELECT id FROM t WHERE typo IS NULL;", &mut db)
        .expect_err("WHERE on an unknown column must error, not match via NULL");
    let rendered = err.to_string();
    assert!(
        rendered.contains("does not exist"),
        "expected unknown-column error, got: {err}"
    );
}
4449
/// ORDER BY on a column the table lacks must produce an error.
#[test]
fn order_by_unknown_column_errors_single_table() {
    let mut db = seed_sqlr2();
    let outcome = crate::sql::process_command("SELECT id FROM t ORDER BY typo;", &mut db);
    let failed = outcome.is_err();
    assert!(
        failed,
        "ORDER BY on an unknown column must error, not sort by NULL"
    );
}
4459
/// An UPDATE whose WHERE names an unknown column must fail and leave every
/// row's `name` untouched.
#[test]
fn update_with_unknown_column_in_where_errors_and_mutates_nothing() {
    let mut db = seed_sqlr2();
    let statement = "UPDATE t SET name = 'x' WHERE typo IS NULL;";
    let failed = crate::sql::process_command(statement, &mut db).is_err();
    assert!(
        failed,
        "UPDATE with a typo'd WHERE column must error, not update every row"
    );

    // No row may have picked up the new name.
    let rows = run_select(&mut db, "SELECT id FROM t WHERE name = 'x';");
    let untouched = rows.contains("0 rows returned");
    assert!(
        untouched,
        "no row may be updated when the WHERE errors, got: {rows}"
    );
}
4475
/// A DELETE whose WHERE names an unknown column must fail and keep both
/// seeded rows.
#[test]
fn delete_with_unknown_column_in_where_errors_and_deletes_nothing() {
    let mut db = seed_sqlr2();
    let statement = "DELETE FROM t WHERE typo IS NULL;";
    let failed = crate::sql::process_command(statement, &mut db).is_err();
    assert!(
        failed,
        "DELETE with a typo'd WHERE column must error, not delete every row"
    );

    // Both seeded rows must still be there.
    let rows = run_select(&mut db, "SELECT id FROM t;");
    let intact = rows.contains("2 rows returned");
    assert!(
        intact,
        "no row may be deleted when the WHERE errors, got: {rows}"
    );
}
4490
/// `n IS NULL AND id > 1` over rows (1, NULL), (2, NULL), (3, 30) must match
/// exactly one row.
#[test]
fn where_is_null_combines_with_and_or() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, NULL);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
    let one_row = response.contains("1 row returned");
    assert!(
        one_row,
        "IS NULL combined with AND should match exactly row 2, got: {response}"
    );
}
4512
4513 fn assert_unknown_qualifier(db: &mut Database, sql: &str, qualifier: &str) {
4521 let err = crate::sql::process_command(sql, db)
4522 .expect_err("a bogus table qualifier must error, not be ignored");
4523 assert!(
4524 err.to_string()
4525 .contains(&format!("unknown table qualifier '{qualifier}'")),
4526 "expected unknown-qualifier error for `{sql}`, got: {err}"
4527 );
4528 }
4529
/// Columns qualified with the table's own name are accepted in the
/// projection, WHERE, and ORDER BY.
#[test]
fn qualifier_matching_table_name_works() {
    let mut db = seed_sqlr2();
    let sql = "SELECT t.id FROM t WHERE t.id = 1 ORDER BY t.id;";
    let rows = run_select(&mut db, sql);
    assert!(rows.contains("1 row returned"), "got: {rows}");
}
4536
/// Columns qualified with the table's alias are accepted in the projection,
/// WHERE, and ORDER BY.
#[test]
fn qualifier_matching_alias_works() {
    let mut db = seed_sqlr2();
    let sql = "SELECT a.id FROM t AS a WHERE a.id = 1 ORDER BY a.id;";
    let rows = run_select(&mut db, sql);
    assert!(rows.contains("1 row returned"), "got: {rows}");
}
4546
/// An upper-case qualifier `T` resolves against the table declared as `t`.
#[test]
fn qualifier_match_is_case_insensitive() {
    let mut db = seed_sqlr2();
    let sql = "SELECT T.id FROM t WHERE T.id = 1;";
    let rows = run_select(&mut db, sql);
    assert!(rows.contains("1 row returned"), "got: {rows}");
}
4553
/// A projection column carrying an unknown qualifier is rejected.
#[test]
fn unknown_qualifier_in_projection_errors() {
    let mut db = seed_sqlr2();
    let sql = "SELECT x.id FROM t;";
    assert_unknown_qualifier(&mut db, sql, "x");
}
4559
/// A WHERE column carrying an unknown qualifier is rejected.
#[test]
fn unknown_qualifier_in_where_errors() {
    let mut db = seed_sqlr2();
    let sql = "SELECT id FROM t WHERE bogus.id = 1;";
    assert_unknown_qualifier(&mut db, sql, "bogus");
}
4565
/// The unknown-qualifier error must still fire when the WHERE column has an
/// explicit index on it.
#[test]
fn unknown_qualifier_in_indexed_where_errors() {
    let mut db = seed_sqlr2();
    let create_index = "CREATE INDEX idx_name ON t (name);";
    crate::sql::process_command(create_index, &mut db).unwrap();

    let sql = "SELECT id FROM t WHERE bogus.name = 'alice';";
    assert_unknown_qualifier(&mut db, sql, "bogus");
}
4579
/// An ORDER BY column carrying an unknown qualifier is rejected with the
/// unknown-qualifier message.
#[test]
fn unknown_qualifier_in_order_by_errors() {
    let mut db = seed_sqlr2();
    let err = crate::sql::process_command("SELECT id FROM t ORDER BY x.id;", &mut db)
        .expect_err("ORDER BY with a bogus qualifier must error");
    let rendered = err.to_string();
    assert!(
        rendered.contains("unknown table qualifier 'x'"),
        "got: {err}"
    );
}
4590
/// Once the table is aliased (`t AS a`), qualifying a column with the
/// original name `t` is rejected in both the projection and WHERE.
#[test]
fn alias_shadows_table_name_as_qualifier() {
    let mut db = seed_sqlr2();
    let statements = [
        "SELECT t.id FROM t AS a;",
        "SELECT id FROM t AS a WHERE t.id = 1;",
    ];
    for sql in statements {
        assert_unknown_qualifier(&mut db, sql, "t");
    }
}
4599
/// Unknown qualifiers inside aggregate arguments and GROUP BY keys are
/// rejected, while the correctly qualified form of the same query runs.
#[test]
fn unknown_qualifier_in_group_by_and_aggregates_errors() {
    let mut db = seed_sqlr2();
    let rejected = [
        "SELECT COUNT(x.id) FROM t;",
        "SELECT x.name, COUNT(*) FROM t GROUP BY x.name;",
    ];
    for sql in rejected {
        assert_unknown_qualifier(&mut db, sql, "x");
    }

    // Positive control: the same shape with the real table name as qualifier.
    let accepted = "SELECT t.name, COUNT(t.id) FROM t GROUP BY t.name;";
    let rows = run_select(&mut db, accepted);
    assert!(rows.contains("2 rows returned"), "got: {rows}");
}
4616
/// An UPDATE whose WHERE uses an unknown qualifier must fail and change no
/// row.
#[test]
fn update_unknown_qualifier_in_where_errors_and_mutates_nothing() {
    let mut db = seed_sqlr2();
    let statement = "UPDATE t SET name = 'x' WHERE x.id = 1;";
    let failed = crate::sql::process_command(statement, &mut db).is_err();
    assert!(failed, "UPDATE with a bogus WHERE qualifier must error");

    // No row may have picked up the new name.
    let rows = run_select(&mut db, "SELECT id FROM t WHERE name = 'x';");
    let untouched = rows.contains("0 rows returned");
    assert!(
        untouched,
        "no row may be updated when the WHERE errors, got: {rows}"
    );
}
4631
/// An unknown qualifier on the right-hand side of `SET` is rejected.
#[test]
fn update_set_rhs_unknown_qualifier_errors() {
    let mut db = seed_sqlr2();
    let sql = "UPDATE t SET name = x.name;";
    assert_unknown_qualifier(&mut db, sql, "x");
}
4637
/// In `UPDATE t AS a`, the alias `a` is a valid qualifier and the original
/// table name `t` is not.
#[test]
fn update_with_alias_validates_against_alias() {
    let mut db = seed_sqlr2();

    let via_alias = "UPDATE t AS a SET name = 'x' WHERE a.id = 1;";
    crate::sql::process_command(via_alias, &mut db)
        .expect("alias qualifier must be accepted in UPDATE");
    let rows = run_select(&mut db, "SELECT id FROM t WHERE name = 'x';");
    assert!(rows.contains("1 row returned"), "got: {rows}");

    let via_table_name = "UPDATE t AS a SET name = 'y' WHERE t.id = 1;";
    assert_unknown_qualifier(&mut db, via_table_name, "t");
}
4649
/// A DELETE whose WHERE uses an unknown qualifier must fail and keep both
/// seeded rows.
#[test]
fn delete_unknown_qualifier_in_where_errors_and_deletes_nothing() {
    let mut db = seed_sqlr2();
    let statement = "DELETE FROM t WHERE x.id = 1;";
    let failed = crate::sql::process_command(statement, &mut db).is_err();
    assert!(failed, "DELETE with a bogus WHERE qualifier must error");

    // Both seeded rows must still be there.
    let rows = run_select(&mut db, "SELECT id FROM t;");
    let intact = rows.contains("2 rows returned");
    assert!(
        intact,
        "no row may be deleted when the WHERE errors, got: {rows}"
    );
}
4664
/// In `DELETE FROM t AS a`, the alias `a` is a valid qualifier and the
/// original table name `t` is not.
#[test]
fn delete_with_alias_validates_against_alias() {
    let mut db = seed_sqlr2();

    let via_alias = "DELETE FROM t AS a WHERE a.id = 2;";
    crate::sql::process_command(via_alias, &mut db)
        .expect("alias qualifier must be accepted in DELETE");
    let rows = run_select(&mut db, "SELECT id FROM t;");
    assert!(rows.contains("1 row returned"), "got: {rows}");

    let via_table_name = "DELETE FROM t AS a WHERE t.id = 1;";
    assert_unknown_qualifier(&mut db, via_table_name, "t");
}
4674
4675 fn seed_employees() -> Database {
4681 let mut db = Database::new("t".to_string());
4682 crate::sql::process_command(
4683 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
4684 &mut db,
4685 )
4686 .unwrap();
4687 let rows = [
4688 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
4689 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
4690 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
4691 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
4692 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
4693 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
4694 ];
4695 for sql in rows {
4696 crate::sql::process_command(sql, &mut db).unwrap();
4697 }
4698 db
4699 }
4700
4701 fn run_rows(db: &Database, sql: &str) -> SelectResult {
4703 let q = parse_select(sql);
4704 execute_select_rows(q, db).expect("select")
4705 }
4706
/// `LIKE 'a%'` matches both 'Alice' and 'alex', i.e. the prefix match ignores
/// letter case.
#[test]
fn like_percent_prefix_case_insensitive() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
    let mut names: Vec<String> = Vec::new();
    for row in &r.rows {
        names.push(row[0].to_display_string());
    }
    assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
    assert!(names.iter().any(|n| n == "Alice"));
    assert!(names.iter().any(|n| n == "alex"));
}
4719
/// `LIKE '_ve'` matches only 'Eve' among the seeded names.
#[test]
fn like_underscore_singlechar() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
    let mut names: Vec<String> = Vec::new();
    for row in &r.rows {
        names.push(row[0].to_display_string());
    }
    assert_eq!(names, vec!["Eve".to_string()]);
}
4728
/// `NOT LIKE 'a%'` leaves four of the six seeded employees.
#[test]
fn not_like_excludes_match() {
    let db = seed_employees();
    let sql = "SELECT name FROM emp WHERE name NOT LIKE 'a%';";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);
}
4736
/// `dept LIKE 'sales' AND salary IS NULL` selects exactly Dave.
#[test]
fn like_with_null_excludes_row() {
    let db = seed_employees();
    let sql = "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    let only_name = r.rows[0][0].to_display_string();
    assert_eq!(only_name, "Dave");
}
4748
/// `id IN (1, 3, 5)` returns Alice, Bob and Dave.
#[test]
fn in_list_positive() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
    let mut names: Vec<String> = Vec::new();
    for row in &r.rows {
        names.push(row[0].to_display_string());
    }
    assert_eq!(names.len(), 3);
    for expected in ["Alice", "Bob", "Dave"] {
        assert!(names.contains(&expected.to_string()));
    }
}
4761
/// `id NOT IN (1, 2)` leaves four of the six seeded employees.
#[test]
fn not_in_excludes_listed() {
    let db = seed_employees();
    let sql = "SELECT name FROM emp WHERE id NOT IN (1, 2);";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);
}
4769
/// `id IN (1, NULL)` returns exactly one row, Alice; the NULL list entry adds
/// no further matches.
#[test]
fn in_list_with_null_three_valued() {
    let db = seed_employees();
    let sql = "SELECT name FROM emp WHERE id IN (1, NULL);";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    let only_name = r.rows[0][0].to_display_string();
    assert_eq!(only_name, "Alice");
}
4779
/// `SELECT DISTINCT dept` collapses the six rows to the three departments.
#[test]
fn distinct_single_column() {
    let db = seed_employees();
    let sql = "SELECT DISTINCT dept FROM emp;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3);
}
4789
/// `SELECT DISTINCT dept, salary` yields five rows: only the duplicated
/// ('eng', 100) pair collapses, and the NULL-salary row is kept.
#[test]
fn distinct_multi_column_with_null() {
    let db = seed_employees();
    let sql = "SELECT DISTINCT dept, salary FROM emp;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 5);
}
4798
/// `COUNT(*)` without GROUP BY emits a single row holding the total of 6.
#[test]
fn count_star_no_groupby() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
    assert_eq!(r.rows.len(), 1);
    let total = &r.rows[0][0];
    assert_eq!(*total, Value::Integer(6));
}
4808
/// `COUNT(salary)` counts 5: the row with a NULL salary is not counted.
#[test]
fn count_col_skips_nulls() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
    let counted = &r.rows[0][0];
    assert_eq!(*counted, Value::Integer(5));
}
4816
/// `COUNT(DISTINCT salary)` is 4: the distinct non-NULL salaries are
/// 100, 120, 90 and 80.
#[test]
fn count_distinct_dedupes_and_skips_nulls() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
    let counted = &r.rows[0][0];
    assert_eq!(*counted, Value::Integer(4));
}
4824
/// `SUM` over an INTEGER column returns `Value::Integer` (490 here), not a
/// real.
#[test]
fn sum_int_stays_integer() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
    let total = &r.rows[0][0];
    assert_eq!(*total, Value::Integer(490));
}
4832
/// `AVG(salary)` returns a `Value::Real` equal to 98.0 (490 over the five
/// non-NULL salaries).
#[test]
fn avg_returns_real() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
    let other = &r.rows[0][0];
    if let Value::Real(avg) = other {
        assert!((avg - 98.0).abs() < 1e-9);
    } else {
        panic!("expected Real, got {other:?}");
    }
}
4843
/// `MIN`/`MAX` over `salary` return 80 and 120; the NULL salary does not
/// affect either result.
#[test]
fn min_max_skip_nulls() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
    let extremes = &r.rows[0];
    assert_eq!(extremes[0], Value::Integer(80));
    assert_eq!(extremes[1], Value::Integer(120));
}
4851
/// Aggregates over an empty table still emit one row: `COUNT(*)` is 0 and
/// SUM / AVG / MIN / MAX are all NULL.
#[test]
fn aggregates_on_empty_table_emit_one_row() {
    let mut db = Database::new("t".to_string());
    crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();

    let sql = "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    assert_eq!(r.rows[0][0], Value::Integer(0));
    // Columns 1..=4 are SUM, AVG, MIN and MAX in projection order.
    for col in 1..=4 {
        assert_eq!(r.rows[0][col], Value::Null);
    }
}
4867
/// `GROUP BY dept` with `COUNT(*)` yields one row per department carrying
/// that department's head count (eng 3, sales 2, ops 1).
#[test]
fn group_by_single_col_with_count() {
    let db = seed_employees();
    let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
    assert_eq!(r.rows.len(), 3);

    // Key the counts by department so the check does not depend on the order
    // in which groups are emitted. `expect_int` is the module's shared helper
    // for unwrapping an integer cell; it panics on any other value.
    let by_dept: std::collections::HashMap<String, i64> = r
        .rows
        .iter()
        .map(|row| (row[0].to_display_string(), expect_int(&row[1])))
        .collect();
    assert_eq!(by_dept["eng"], 3);
    assert_eq!(by_dept["sales"], 2);
    assert_eq!(by_dept["ops"], 1);
}
4889
/// WHERE is applied before grouping: with `salary > 80`, 'ops' (salary 80)
/// drops out entirely, 'eng' sums to 320 and 'sales' to 90.
#[test]
fn group_by_with_where_filter() {
    let db = seed_employees();
    let r = run_rows(
        &db,
        "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
    );

    // Key the sums by department so the check does not depend on group order.
    // `expect_int` is the module's shared helper for unwrapping an integer
    // cell; it panics on any other value.
    let by: std::collections::HashMap<String, i64> = r
        .rows
        .iter()
        .map(|row| (row[0].to_display_string(), expect_int(&row[1])))
        .collect();
    assert_eq!(by.len(), 2);
    assert_eq!(by["eng"], 320);
    assert_eq!(by["sales"], 90);
}
4916
/// GROUP BY with no aggregate in the projection emits one row per distinct
/// department (three here).
#[test]
fn group_by_without_aggregates_is_distinct() {
    let db = seed_employees();
    let sql = "SELECT dept FROM emp GROUP BY dept;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3);
}
4923
/// ORDER BY can reference an aggregate's alias: sorting by `n DESC` with
/// LIMIT 2 puts 'eng' (count 3) first.
#[test]
fn order_by_count_desc() {
    let db = seed_employees();
    let sql = "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);

    let top = &r.rows[0];
    assert_eq!(top[0].to_display_string(), "eng");
    assert_eq!(top[1], Value::Integer(3));
}
4936
/// ORDER BY can repeat the aggregate call itself (`COUNT(*)`) rather than an
/// alias; 'eng' still sorts first.
#[test]
fn order_by_aggregate_call_form() {
    let db = seed_employees();
    let sql = "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3);
    let top_dept = r.rows[0][0].to_display_string();
    assert_eq!(top_dept, "eng");
}
4948
/// Projecting a bare column (`name`) that is not a GROUP BY key must be
/// rejected.
#[test]
fn group_by_invalid_bare_column_errors() {
    let mut db = Database::new("t".to_string());
    let create = "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);";
    crate::sql::process_command(create, &mut db).unwrap();

    let outcome = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
    let rejected = outcome.is_err();
    assert!(rejected, "should reject bare 'name' not in GROUP BY");
}
4961
/// An aggregate call inside WHERE must be rejected.
#[test]
fn aggregate_in_where_errors_friendly() {
    let mut db = Database::new("t".to_string());
    for stmt in ["CREATE TABLE t (x INTEGER);", "INSERT INTO t (x) VALUES (1);"] {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let outcome = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
    let rejected = outcome.is_err();
    assert!(rejected, "aggregates must not be allowed in WHERE");
}
4970
/// `HAVING COUNT(*) > 1` keeps only 'eng' (3) and 'sales' (2), in that
/// order, under the column headers `dept` and `COUNT(*)`.
#[test]
fn having_count_filters_groups() {
    let db = seed_employees();
    let sql = "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 1;";
    let r = run_rows(&db, sql);
    assert_eq!(r.columns, vec!["dept".to_string(), "COUNT(*)".to_string()]);

    let mut got: Vec<(String, i64)> = Vec::new();
    for row in &r.rows {
        got.push((row[0].to_display_string(), expect_int(&row[1])));
    }
    assert_eq!(got, vec![("eng".to_string(), 3), ("sales".to_string(), 2)]);
}
4996
/// `HAVING SUM(salary) > 100` keeps only 'eng', whose salaries sum to 320.
#[test]
fn having_sum_threshold() {
    let db = seed_employees();
    let sql = "SELECT dept, SUM(salary) FROM emp GROUP BY dept HAVING SUM(salary) > 100;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);

    let survivor = &r.rows[0];
    assert_eq!(survivor[0].to_display_string(), "eng");
    assert_eq!(survivor[1], Value::Integer(320));
}
5008
/// HAVING can refer to an aggregate by its projection alias (`total`), and
/// the alias is used as the output column name.
#[test]
fn having_references_aggregate_alias() {
    let db = seed_employees();
    let sql = "SELECT dept, SUM(salary) AS total FROM emp GROUP BY dept HAVING total > 100;";
    let r = run_rows(&db, sql);
    assert_eq!(r.columns, vec!["dept".to_string(), "total".to_string()]);
    assert_eq!(r.rows.len(), 1);
    let total = &r.rows[0][1];
    assert_eq!(*total, Value::Integer(320));
}
5020
/// HAVING may use an aggregate that the projection does not list; the output
/// keeps only the projected `dept` column.
#[test]
fn having_aggregate_not_in_projection() {
    let db = seed_employees();
    let sql = "SELECT dept FROM emp GROUP BY dept HAVING COUNT(*) > 1;";
    let r = run_rows(&db, sql);
    assert_eq!(r.columns, vec!["dept".to_string()]);

    let mut depts: Vec<String> = Vec::new();
    for row in &r.rows {
        depts.push(row[0].to_display_string());
    }
    assert_eq!(depts, vec!["eng".to_string(), "sales".to_string()]);
}
5038
/// HAVING may filter on a GROUP BY key (`dept`) that the projection omits.
#[test]
fn having_group_key_not_in_projection() {
    let db = seed_employees();
    let sql = "SELECT COUNT(*) FROM emp GROUP BY dept HAVING dept = 'eng';";
    let r = run_rows(&db, sql);
    assert_eq!(r.columns, vec!["COUNT(*)".to_string()]);
    assert_eq!(r.rows.len(), 1);
    let eng_count = &r.rows[0][0];
    assert_eq!(*eng_count, Value::Integer(3));
}
5051
/// Two aggregate predicates joined by AND in HAVING: only 'eng' satisfies
/// both `COUNT(*) > 1` and `SUM(salary) > 100`.
#[test]
fn having_compound_and_predicate() {
    let db = seed_employees();
    let sql = "SELECT dept FROM emp GROUP BY dept \
               HAVING COUNT(*) > 1 AND SUM(salary) > 100;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    let survivor = r.rows[0][0].to_display_string();
    assert_eq!(survivor, "eng");
}
5064
/// HAVING, ORDER BY and LIMIT together: all groups pass `n >= 1`, and the
/// descending sort plus LIMIT 2 leaves 'eng' (3) then 'sales' (2).
#[test]
fn having_composes_with_order_by_and_limit() {
    let db = seed_employees();
    let sql = "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept \
               HAVING n >= 1 ORDER BY n DESC LIMIT 2;";
    let r = run_rows(&db, sql);

    let mut got: Vec<(String, i64)> = Vec::new();
    for row in &r.rows {
        got.push((row[0].to_display_string(), expect_int(&row[1])));
    }
    assert_eq!(got, vec![("eng".to_string(), 3), ("sales".to_string(), 2)]);
}
5080
/// A HAVING predicate that no group satisfies produces zero rows.
#[test]
fn having_can_exclude_every_group() {
    let db = seed_employees();
    let sql = "SELECT dept FROM emp GROUP BY dept HAVING COUNT(*) > 99;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 0);
}
5090
/// A group whose only salary is NULL ('mkt') must not pass
/// `HAVING SUM(salary) > 0`; the other three departments remain.
#[test]
fn having_null_aggregate_collapses_to_false() {
    let mut db = seed_employees();
    let add_mkt = "INSERT INTO emp (name, dept, salary) VALUES ('Zoe', 'mkt', NULL);";
    crate::sql::process_command(add_mkt, &mut db).unwrap();

    let sql = "SELECT dept FROM emp GROUP BY dept HAVING SUM(salary) > 0;";
    let r = run_rows(&db, sql);
    let mut depts: Vec<String> = Vec::new();
    for row in &r.rows {
        depts.push(row[0].to_display_string());
    }
    assert_eq!(
        depts,
        vec!["eng".to_string(), "sales".to_string(), "ops".to_string()],
        "mkt (all-NULL salaries) must be filtered out"
    );
}
5116
/// A lower-case `count(*)` in HAVING is accepted and keeps the two
/// departments with more than one employee.
#[test]
fn having_lowercase_function_form_matches() {
    let db = seed_employees();
    let sql = "SELECT dept FROM emp GROUP BY dept HAVING count(*) > 1;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);
}
5126
/// HAVING with no GROUP BY must fail with `SQLRiteError::NotImplemented`
/// whose message mentions "HAVING without GROUP BY".
#[test]
fn having_without_group_by_is_rejected() {
    let mut db = seed_employees();
    let sql = "SELECT COUNT(*) FROM emp HAVING COUNT(*) > 0;";
    let outcome = crate::sql::process_command(sql, &mut db);
    match outcome {
        Err(SQLRiteError::NotImplemented(msg)) => {
            let mentions_having = msg.contains("HAVING without GROUP BY");
            assert!(mentions_having, "unexpected message: {msg}");
        }
        other => panic!("expected NotImplemented, got {other:?}"),
    }
}
5140
/// HAVING on a column that is neither a group key nor an aggregate (`name`)
/// must fail with a message containing "HAVING references".
#[test]
fn having_unknown_column_is_rejected() {
    let mut db = seed_employees();
    let sql = "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING name = 'Alice';";
    let outcome = crate::sql::process_command(sql, &mut db);
    if let Err(e) = outcome {
        let msg = e.to_string();
        assert!(
            msg.contains("HAVING references"),
            "unexpected message: {msg}"
        );
    } else {
        panic!("HAVING on an out-of-scope column must error");
    }
}
5161
/// For each join flavor (INNER, LEFT/RIGHT/FULL OUTER), grouping the joined
/// rows by customer name with `HAVING COUNT(*) > 1` leaves only Alice, who
/// has two orders.
#[test]
fn having_over_join_filters_groups_for_all_flavors() {
    let flavors = ["INNER", "LEFT OUTER", "RIGHT OUTER", "FULL OUTER"];
    for flavor in flavors {
        // A fresh fixture is built for every flavor.
        let db = seed_join_fixture();
        let sql = format!(
            "SELECT customers.name, COUNT(*) FROM customers \
             {flavor} JOIN orders ON customers.id = orders.customer_id \
             GROUP BY customers.name HAVING COUNT(*) > 1;"
        );
        let r = run_rows(&db, &sql);
        assert_eq!(r.rows.len(), 1, "{flavor}: only Alice has >1 order");
        let survivor = &r.rows[0];
        assert_eq!(survivor[0].to_display_string(), "Alice", "{flavor}");
        assert_eq!(expect_int(&survivor[1]), 2, "{flavor}");
    }
}
5180
5181 fn expect_int(v: &Value) -> i64 {
5183 match v {
5184 Value::Integer(i) => *i,
5185 other => panic!("expected integer value, got {other:?}"),
5186 }
5187 }
5188
5189 fn seed_join_fixture() -> Database {
5200 let mut db = Database::new("t".to_string());
5201 for sql in [
5202 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
5203 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
5204 "INSERT INTO customers (name) VALUES ('Alice');",
5205 "INSERT INTO customers (name) VALUES ('Bob');",
5206 "INSERT INTO customers (name) VALUES ('Carol');",
5207 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
5208 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
5209 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
5210 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
5211 ] {
5212 crate::sql::process_command(sql, &mut db).unwrap();
5213 }
5214 db
5215 }
5216
/// INNER JOIN emits only customer/order pairs that satisfy the ON clause:
/// Alice's two orders and Bob's one. Carol (no orders) and the order for
/// customer id 4 (no such customer) are absent.
#[test]
fn inner_join_returns_only_matched_rows() {
    let db = seed_join_fixture();
    let r = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);

    // `expect_int` is the module's shared helper for unwrapping an integer
    // cell; it panics on any other value.
    let pairs: Vec<(String, i64)> = r
        .rows
        .iter()
        .map(|row| (row[0].to_display_string(), expect_int(&row[1])))
        .collect();
    assert_eq!(pairs.len(), 3);
    // Membership checks keep the test independent of row order.
    assert!(pairs.contains(&("Alice".to_string(), 100)));
    assert!(pairs.contains(&("Alice".to_string(), 200)));
    assert!(pairs.contains(&("Bob".to_string(), 50)));
}
5246
/// A bare `JOIN` produces the three matched rows, the same count as the
/// INNER JOIN over this fixture.
#[test]
fn bare_join_defaults_to_inner() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name FROM customers \
               JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3, "JOIN without prefix should be INNER");
}
5257
/// LEFT OUTER JOIN keeps Carol, who has no orders, with a NULL `amount`.
#[test]
fn left_outer_join_preserves_unmatched_left() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               LEFT OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);

    let mut carol_rows = r
        .rows
        .iter()
        .filter(|row| row[0].to_display_string() == "Carol");
    let carol = carol_rows
        .next()
        .expect("Carol should appear with a NULL-padded right side");
    assert_eq!(carol[1], Value::Null);
}
5276
/// RIGHT OUTER JOIN keeps the order with amount 999, whose customer id has
/// no matching customer, with a NULL `name`.
#[test]
fn right_outer_join_preserves_unmatched_right() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);

    let mut dangling_rows = r
        .rows
        .iter()
        .filter(|row| matches!(row[1], Value::Integer(999)));
    let dangling = dangling_rows
        .next()
        .expect("dangling order 999 should appear with a NULL-padded customer name");
    assert_eq!(dangling[0], Value::Null);
}
5296
/// FULL OUTER JOIN yields five rows: the three matches, Carol with a NULL
/// amount, and the 999 order with a NULL name.
#[test]
fn full_outer_join_preserves_both_sides() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               FULL OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 5);

    // Unmatched left side: Carol padded with NULL on the right.
    let carol_padded = r
        .rows
        .iter()
        .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null));
    assert!(carol_padded);

    // Unmatched right side: order 999 padded with NULL on the left.
    let dangling_padded = r
        .rows
        .iter()
        .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null));
    assert!(dangling_padded);
}
5321
/// Table aliases (`c`, `o`) work as qualifiers in the projection and the ON
/// clause; output column names are the bare column names.
#[test]
fn join_with_table_aliases_resolves_qualifiers() {
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3);
    assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
}
5333
/// A WHERE over the joined rows (`orders.amount >= 100`) leaves just Alice's
/// two orders.
#[test]
fn join_with_where_filter_applies_after_join() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               WHERE orders.amount >= 100;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);

    // No surviving row may belong to anyone but Alice.
    let stray = r
        .rows
        .iter()
        .any(|row| row[0].to_display_string() != "Alice");
    assert!(!stray);
}
5352
/// `WHERE orders.amount IS NULL` after a LEFT OUTER JOIN selects exactly the
/// NULL-padded row, so the single surviving row is "Carol" with a NULL
/// amount.
#[test]
fn left_join_with_where_on_right_side_is_not_inner() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
               WHERE orders.amount IS NULL;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 1);

    let only_row = &result.rows[0];
    assert_eq!(only_row[0].to_display_string(), "Carol");
    assert_eq!(only_row[1], Value::Null);
}
5370
/// `SELECT *` over an ON-join emits every column of the left table followed
/// by every column of the right table; the shared name `id` shows up twice.
#[test]
fn select_star_over_join_emits_all_columns_from_both_tables() {
    let db = seed_join_fixture();
    let sql = "SELECT * FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id;";
    let result = run_rows(&db, sql);

    let expected_columns = vec!["id", "name", "id", "customer_id", "amount"];
    assert_eq!(result.columns, expected_columns);
    assert_eq!(result.rows.len(), 3);
}
5394
/// ORDER BY on a right-table column sorts the joined rows: amounts come back
/// ascending as 50, 100, 200.
#[test]
fn join_order_by_sorts_full_joined_rows() {
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount;";
    let result = run_rows(&db, sql);

    // Pull the integer payload out of the second column of every row.
    let mut amounts: Vec<i64> = Vec::with_capacity(result.rows.len());
    for cols in &result.rows {
        match &cols[1] {
            Value::Integer(i) => amounts.push(*i),
            v => panic!("expected integer, got {v:?}"),
        }
    }
    assert_eq!(amounts, vec![50, 100, 200]);
}
5414
/// LIMIT is applied after the join and the sort: with `ORDER BY o.amount
/// DESC LIMIT 2` the two largest amounts (200, 100) are returned in order.
#[test]
fn join_limit_truncates_after_join_and_sort() {
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount DESC LIMIT 2;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 2);

    // Pull the integer payload out of the second column of every row.
    let mut amounts: Vec<i64> = Vec::with_capacity(result.rows.len());
    for cols in &result.rows {
        match &cols[1] {
            Value::Integer(i) => amounts.push(*i),
            v => panic!("expected integer, got {v:?}"),
        }
    }
    assert_eq!(amounts, vec![200, 100]);
}
5436
/// Two chained INNER JOINs (a → b → c) yield only the chain that is complete
/// across all three tables: a-one / b1 / c1.
#[test]
fn three_table_join_chains_correctly() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
        "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
        "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.label, b.tag, c.note FROM a \
               INNER JOIN b ON a.id = b.a_id \
               INNER JOIN c ON b.id = c.b_id;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 1);

    let chain = &result.rows[0];
    assert_eq!(chain[0].to_display_string(), "a-one");
    assert_eq!(chain[1].to_display_string(), "b1");
    assert_eq!(chain[2].to_display_string(), "c1");
}
5464
/// An unqualified `id` in the projection of a customers/orders join must
/// make execution fail rather than silently pick a table.
#[test]
fn ambiguous_unqualified_column_in_join_errors() {
    let db = seed_join_fixture();
    let query = parse_select(
        "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    let outcome = execute_select_rows(query, &db);
    assert!(outcome.is_err(), "unqualified ambiguous 'id' should error");
}
5477
/// Joining a table to itself without aliases gives both sides the same
/// qualifier (`n`); executing such a query must return an error.
#[test]
fn join_self_without_alias_is_rejected() {
    let mut db = Database::new("t".to_string());
    let create = "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);";
    crate::sql::process_command(create, &mut db).unwrap();

    let query = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "self-join without an alias should error on duplicate qualifier"
    );
}
5493
/// `USING (id)` must behave like the explicit `ON customers.id = orders.id`:
/// both queries are run with the same ORDER BY and their rows are compared
/// for equality.
#[test]
fn join_using_matches_same_rows_as_on() {
    let db = seed_join_fixture();
    let using = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders USING (id) ORDER BY orders.amount;",
    );
    let on = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.id ORDER BY orders.amount;",
    );
    // Check the row count on the result directly; collecting (and cloning)
    // every row into a temporary vector just to read its length is wasted
    // work.
    assert_eq!(using.rows.len(), 3);
    assert_eq!(
        using.rows, on.rows,
        "USING must mirror the explicit ON rows"
    );
}
5524
/// `SELECT *` over a `USING (id)` join lists the USING column once: `id`
/// appears a single time in the output and holds an integer in every row.
#[test]
fn select_star_using_dedups_joined_column() {
    let db = seed_join_fixture();
    let result = run_rows(&db, "SELECT * FROM customers INNER JOIN orders USING (id);");

    let expected_columns = vec!["id", "name", "customer_id", "amount"];
    assert_eq!(result.columns, expected_columns);
    assert_eq!(result.rows.len(), 3);

    // The merged `id` column must carry a real integer, never NULL padding.
    assert!(
        result
            .rows
            .iter()
            .all(|cols| matches!(cols[0], Value::Integer(_)))
    );
}
5550
5551 fn seed_natural_fixture() -> Database {
5552 let mut db = Database::new("t".to_string());
5553 for sql in [
5554 "CREATE TABLE l (lid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v1 TEXT);",
5557 "CREATE TABLE r (rid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v2 TEXT);",
5558 "INSERT INTO l (k1, k2, v1) VALUES (1, 1, 'l-a');",
5559 "INSERT INTO l (k1, k2, v1) VALUES (1, 2, 'l-b');",
5560 "INSERT INTO l (k1, k2, v1) VALUES (2, 1, 'l-c');",
5561 "INSERT INTO r (k1, k2, v2) VALUES (1, 1, 'r-a');",
5562 "INSERT INTO r (k1, k2, v2) VALUES (1, 2, 'r-b');",
5563 "INSERT INTO r (k1, k2, v2) VALUES (9, 9, 'r-z');",
5564 ] {
5565 crate::sql::process_command(sql, &mut db).unwrap();
5566 }
5567 db
5568 }
5569
/// NATURAL JOIN matches on every shared column name (`k1` and `k2`): only
/// the two fully matching pairs are returned, and the result equals the
/// hand-written `ON l.k1 = r.k1 AND l.k2 = r.k2` join.
#[test]
fn natural_join_matches_on_all_shared_columns() {
    let db = seed_natural_fixture();
    let natural = run_rows(&db, "SELECT v1, v2 FROM l NATURAL JOIN r ORDER BY v1;");

    let mut pairs: Vec<(String, String)> = Vec::with_capacity(natural.rows.len());
    for cols in &natural.rows {
        pairs.push((cols[0].to_display_string(), cols[1].to_display_string()));
    }
    let expected = vec![
        ("l-a".to_string(), "r-a".to_string()),
        ("l-b".to_string(), "r-b".to_string()),
    ];
    assert_eq!(pairs, expected);

    // Cross-check against the equivalent explicit join.
    let explicit = run_rows(
        &db,
        "SELECT v1, v2 FROM l INNER JOIN r ON l.k1 = r.k1 AND l.k2 = r.k2 ORDER BY v1;",
    );
    assert_eq!(natural.rows, explicit.rows);
}
5596
/// `SELECT *` over a NATURAL JOIN lists each shared column (`k1`, `k2`)
/// once, in the left table's position, followed by the remaining right-table
/// columns.
#[test]
fn select_star_natural_dedups_shared_columns() {
    let db = seed_natural_fixture();
    let result = run_rows(&db, "SELECT * FROM l NATURAL JOIN r;");

    let expected_columns = vec!["lid", "k1", "k2", "v1", "rid", "v2"];
    assert_eq!(result.columns, expected_columns);
    assert_eq!(result.rows.len(), 2);
}
5617
/// With no column names in common, NATURAL JOIN has nothing to match on and
/// must return every combination of rows (2 rows × 3 rows).
#[test]
fn natural_join_without_common_columns_is_cross_product() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE p (pid INTEGER PRIMARY KEY, pa TEXT);",
        "CREATE TABLE q (qid INTEGER PRIMARY KEY, qb TEXT);",
        "INSERT INTO p (pa) VALUES ('p1');",
        "INSERT INTO p (pa) VALUES ('p2');",
        "INSERT INTO q (qb) VALUES ('q1');",
        "INSERT INTO q (qb) VALUES ('q2');",
        "INSERT INTO q (qb) VALUES ('q3');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let result = run_rows(&db, "SELECT p.pa, q.qb FROM p NATURAL JOIN q;");
    assert_eq!(result.rows.len(), 2 * 3, "no shared columns ⇒ cross product");
}
5637
/// CROSS JOIN returns the full cartesian product (12 rows here), the same
/// number of rows as `INNER JOIN ... ON 1`, and `SELECT *` over it exposes
/// all five columns.
#[test]
fn cross_join_produces_cartesian_product() {
    let db = seed_join_fixture();

    let cross_sql = "SELECT customers.name, orders.amount FROM customers CROSS JOIN orders;";
    let cross = run_rows(&db, cross_sql);
    assert_eq!(cross.rows.len(), 12);

    // An always-true ON predicate must produce the same number of rows.
    let on_true_sql = "SELECT customers.name, orders.amount FROM customers INNER JOIN orders ON 1;";
    let on_true = run_rows(&db, on_true_sql);
    assert_eq!(cross.rows.len(), on_true.rows.len());

    let star = run_rows(&db, "SELECT * FROM customers CROSS JOIN orders;");
    assert_eq!(star.columns.len(), 5);
    assert_eq!(star.rows.len(), 12);
}
5659
/// LEFT OUTER JOIN combined with `USING (id)`: the merged `id` column is
/// listed once (four output columns) and three rows are returned.
#[test]
fn left_outer_join_using_preserves_unmatched_left() {
    let db = seed_join_fixture();
    let sql = "SELECT * FROM customers LEFT OUTER JOIN orders USING (id);";
    let result = run_rows(&db, sql);
    assert_eq!(result.columns.len(), 4, "id is shown once");
    assert_eq!(result.rows.len(), 3);
}
5676
/// Naming a column in USING that the joined tables do not have must surface
/// as an execution error.
#[test]
fn using_unknown_column_errors() {
    let db = seed_join_fixture();
    let query = parse_select("SELECT * FROM customers INNER JOIN orders USING (nope);");
    let outcome = execute_select_rows(query, &db);
    assert!(outcome.is_err(), "USING (nope) must error — column absent");
}
5686
/// GROUP BY plus COUNT(*) / SUM over an INNER JOIN: one output row per
/// customer that has orders, sorted by name, with aggregate columns named
/// after their source expressions.
#[test]
fn group_by_with_aggregates_over_inner_join() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, COUNT(*), SUM(orders.amount) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               GROUP BY customers.name ORDER BY customers.name;";
    let result = run_rows(&db, sql);
    assert_eq!(result.columns, vec!["name", "COUNT(*)", "SUM(orders.amount)"]);
    assert_eq!(result.rows.len(), 2);

    let alice = &result.rows[0];
    assert_eq!(alice[0].to_display_string(), "Alice");
    assert_eq!(expect_int(&alice[1]), 2);
    assert_eq!(expect_int(&alice[2]), 300);

    let bob = &result.rows[1];
    assert_eq!(bob[0].to_display_string(), "Bob");
    assert_eq!(expect_int(&bob[1]), 1);
    assert_eq!(expect_int(&bob[2]), 50);
}
5709
/// Aggregates without GROUP BY collapse the whole join into a single row:
/// COUNT(*) is 3 and SUM(orders.amount) is 350.
#[test]
fn aggregates_over_join_without_group_by() {
    let db = seed_join_fixture();
    let sql = "SELECT COUNT(*), SUM(orders.amount) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 1);

    let totals = &result.rows[0];
    assert_eq!(expect_int(&totals[0]), 3);
    assert_eq!(expect_int(&totals[1]), 350);
}
5722
/// For a customer with no orders, the LEFT OUTER JOIN produces one padded
/// row: COUNT(*) counts that row, while COUNT(orders.id) ignores the NULL.
#[test]
fn count_column_skips_outer_join_null_padding() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, COUNT(*), COUNT(orders.id) FROM customers \
               LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
               GROUP BY customers.name ORDER BY customers.name;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 3);

    // Sorted by name, the third group is expected to be Carol.
    let last_group = &result.rows[2];
    assert_eq!(last_group[0].to_display_string(), "Carol");
    assert_eq!(
        expect_int(&last_group[1]),
        1,
        "COUNT(*) counts the padded row"
    );
    assert_eq!(expect_int(&last_group[2]), 0, "COUNT(col) skips the NULL");
}
5741
/// Grouping a FULL OUTER JOIN by a column that is NULL-padded for unmatched
/// right rows yields a dedicated NULL group; here it holds exactly one row.
#[test]
fn outer_join_null_keys_group_together() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, COUNT(*) FROM customers \
               FULL OUTER JOIN orders ON customers.id = orders.customer_id \
               GROUP BY customers.name;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 4, "Alice, Bob, Carol, NULL");

    let null_group = result
        .rows
        .iter()
        .find(|cols| cols[0] == Value::Null)
        .expect("dangling order groups under NULL");
    assert_eq!(expect_int(&null_group[1]), 1);
}
5761
/// COUNT(DISTINCT ...) over a join counts each customer name once even when
/// the customer matches several orders; the expected count is 2.
#[test]
fn count_distinct_over_join() {
    let db = seed_join_fixture();
    let sql = "SELECT COUNT(DISTINCT customers.name) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id;";
    let result = run_rows(&db, sql);
    let distinct_names = expect_int(&result.rows[0][0]);
    assert_eq!(distinct_names, 2);
}
5772
/// `id` exists on both joined tables, but the qualified `customers.id` is
/// accepted as a GROUP BY / ORDER BY key. The first group is customer 1 with
/// two joined rows.
#[test]
fn group_by_qualified_key_resolves_ambiguous_name() {
    let db = seed_join_fixture();
    let sql = "SELECT customers.id, COUNT(*) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               GROUP BY customers.id ORDER BY customers.id;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 2);

    let first_group = &result.rows[0];
    assert_eq!(expect_int(&first_group[0]), 1);
    assert_eq!(expect_int(&first_group[1]), 2);
}
5788
/// An unqualified `GROUP BY id` over a join where both tables have an `id`
/// column must fail with an error whose message mentions "ambiguous".
#[test]
fn group_by_ambiguous_unqualified_key_over_join_errors() {
    let mut db = seed_join_fixture();
    let sql = "SELECT COUNT(*) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id GROUP BY id;";
    let outcome = crate::sql::process_command(sql, &mut db);

    let e = match outcome {
        Err(e) => e,
        Ok(_) => panic!("ambiguous GROUP BY key must error"),
    };
    assert!(
        e.to_string().contains("ambiguous"),
        "unexpected message: {e}"
    );
}
5804
/// Projecting a bare column (`orders.amount`) that is neither aggregated nor
/// listed in GROUP BY must fail with a "must appear in GROUP BY" error.
#[test]
fn bare_column_not_in_group_by_over_join_errors() {
    let mut db = seed_join_fixture();
    let sql = "SELECT orders.amount, COUNT(*) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               GROUP BY customers.name;";
    let outcome = crate::sql::process_command(sql, &mut db);

    let e = match outcome {
        Err(e) => e,
        Ok(_) => panic!("bare column outside GROUP BY must error"),
    };
    assert!(
        e.to_string().contains("must appear in GROUP BY"),
        "unexpected message: {e}"
    );
}
5821
/// An aggregate call inside WHERE must be rejected with
/// `SQLRiteError::NotImplemented` carrying a "not allowed in WHERE" message.
#[test]
fn aggregate_in_where_over_join_errors_cleanly() {
    let mut db = seed_join_fixture();
    let sql = "SELECT COUNT(*) FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               WHERE COUNT(*) > 1;";
    let outcome = crate::sql::process_command(sql, &mut db);

    // Any other outcome (success or a different error variant) is a failure.
    let msg = match outcome {
        Err(SQLRiteError::NotImplemented(msg)) => msg,
        other => panic!("expected NotImplemented, got {other:?}"),
    };
    assert!(
        msg.contains("not allowed in WHERE"),
        "unexpected message: {msg}"
    );
}
5840
/// ORDER BY may reference an aggregate over the join, with either the
/// qualified (`SUM(orders.amount)`) or unqualified (`SUM(amount)`) argument;
/// both spellings must put "Alice" first when sorted descending.
#[test]
fn order_by_aggregate_over_join() {
    let db = seed_join_fixture();
    let queries = [
        "SELECT customers.name, SUM(orders.amount) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id \
         GROUP BY customers.name ORDER BY SUM(orders.amount) DESC;",
        "SELECT customers.name, SUM(orders.amount) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id \
         GROUP BY customers.name ORDER BY SUM(amount) DESC;",
    ];
    for sql in queries {
        let result = run_rows(&db, sql);
        assert_eq!(result.rows[0][0].to_display_string(), "Alice");
    }
}
5861
/// SELECT DISTINCT over a join collapses duplicate output rows: a customer
/// with several orders is listed once, leaving exactly "Alice" and "Bob".
#[test]
fn distinct_over_join_dedupes_output_rows() {
    let db = seed_join_fixture();
    let r = run_rows(
        &db,
        "SELECT DISTINCT customers.name FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(r.rows.len(), 2);
    let mut names: Vec<String> = r
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    // The query has no ORDER BY, so the row order is not part of what this
    // test verifies; sort before comparing so only the set of names matters.
    names.sort();
    assert_eq!(names, vec!["Alice".to_string(), "Bob".to_string()]);
}
5878
/// LIMIT must be applied after DISTINCT has removed duplicates, so
/// `LIMIT 2` still returns two distinct names.
#[test]
fn distinct_over_join_defers_limit_past_dedupe() {
    let db = seed_join_fixture();
    let sql = "SELECT DISTINCT customers.name FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id LIMIT 2;";
    let result = run_rows(&db, sql);
    assert_eq!(
        result.rows.len(),
        2,
        "LIMIT applies after DISTINCT collapses"
    );
}
5891
/// `SELECT *` combined with GROUP BY must be reported as a regular
/// "must appear in GROUP BY" error instead of panicking inside the executor.
#[test]
fn select_star_group_by_errors_instead_of_panicking() {
    let mut db = seed_join_fixture();
    let outcome = crate::sql::process_command("SELECT * FROM orders GROUP BY customer_id;", &mut db);

    let e = match outcome {
        Err(e) => e,
        Ok(_) => panic!("SELECT * with GROUP BY must error, not panic"),
    };
    assert!(
        e.to_string().contains("must appear in GROUP BY"),
        "unexpected message: {e}"
    );
}
5909
/// On a single table, a qualified GROUP BY key (`emp.dept`) works together
/// with the unqualified `dept` in the projection and ORDER BY.
#[test]
fn group_by_qualified_key_single_table_still_works() {
    let db = seed_employees();
    let sql = "SELECT dept, COUNT(*) FROM emp GROUP BY emp.dept ORDER BY dept;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 3, "eng / sales / ops");
}
5921
/// When no right row satisfies the ON predicate (x ∈ {1, 2} vs y = 10),
/// LEFT OUTER JOIN still emits every left row, each padded with NULL.
#[test]
fn left_join_with_no_matches_pads_every_row() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 2);
    result
        .rows
        .iter()
        .for_each(|cols| assert_eq!(cols[1], Value::Null));
}
5944
/// With `ORDER BY o.amount ASC`, the NULL-padded row of the LEFT OUTER JOIN
/// sorts ahead of every non-NULL amount, so "Carol" comes first.
#[test]
fn left_outer_join_order_by_places_nulls_first() {
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount ASC;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 4);

    let first = &result.rows[0];
    assert_eq!(first[0].to_display_string(), "Carol");
    assert_eq!(first[1], Value::Null);
}
5963
/// Two chained LEFT OUTER JOINs keep every row of `a`: a-one matches `b`
/// but has nothing in the empty `c`, and a-two matches nothing at all, so
/// the missing levels are NULL-padded.
#[test]
fn chained_left_outer_join_preserves_left_through_two_levels() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.label, b.tag, c.note FROM a \
               LEFT OUTER JOIN b ON a.id = b.a_id \
               LEFT OUTER JOIN c ON b.id = c.b_id;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 2);

    // Index rows by label so the checks do not depend on output order.
    let mut by_label: std::collections::HashMap<String, &Vec<Value>> =
        std::collections::HashMap::new();
    for cols in &result.rows {
        by_label.insert(cols[0].to_display_string(), cols);
    }

    let a_one = by_label["a-one"];
    assert_eq!(a_one[1].to_display_string(), "b1");
    assert_eq!(a_one[2], Value::Null);

    let a_two = by_label["a-two"];
    assert_eq!(a_two[1], Value::Null);
    assert_eq!(a_two[2], Value::Null);
}
5999
/// The first ON clause refers to table `c`, which only enters the FROM
/// chain in the second join; executing the query must return an error.
#[test]
fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO b (x) VALUES (1);",
        "INSERT INTO c (x) VALUES (1);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let query =
        parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "ON referencing not-yet-joined table 'c' should error"
    );
}
6024
/// A constant integer `ON 1` is accepted as an always-true join predicate,
/// so the 2 × 2 inputs produce four rows.
#[test]
fn join_on_truthy_integer_is_accepted() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
        "INSERT INTO b (y) VALUES (20);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let result = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
    assert_eq!(result.rows.len(), 4);
}
6045
/// FULL OUTER JOIN over two tables that contain no rows returns an empty
/// result set.
#[test]
fn full_join_on_empty_tables_returns_empty() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;";
    let result = run_rows(&db, sql);
    assert!(result.rows.is_empty());
}
6061}