1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, JoinType, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
25};
26
/// Resolves column references for one logical row while an expression is evaluated.
///
/// Implemented in this module by `SingleTableScope` (one row of one table) and
/// `JoinedScope` (one combined row of a multi-table JOIN).
pub(crate) trait RowScope {
    /// Returns the value of column `col` for the row this scope represents.
    ///
    /// `qualifier` is the optional table name or alias written before the column
    /// (`t.col`). Implementations may ignore it when only one table is in scope.
    ///
    /// # Errors
    /// Implementations may reject unknown or ambiguous references.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;

    /// Returns the underlying table and rowid when the scope is a single table row,
    /// or `None` when no single table backs it (e.g. a joined row).
    fn single_table_view(&self) -> Option<(&Table, i64)>;
}
65
/// Row scope over a single table: every column lookup reads row `rowid` of `table`.
pub(crate) struct SingleTableScope<'a> {
    /// Table that owns the row being evaluated.
    table: &'a Table,
    /// Rowid of the row being evaluated.
    rowid: i64,
}
71
72impl<'a> SingleTableScope<'a> {
73 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
74 Self { table, rowid }
75 }
76}
77
78impl RowScope for SingleTableScope<'_> {
79 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
80 let _ = qualifier;
85 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
86 }
87
88 fn single_table_view(&self) -> Option<(&Table, i64)> {
89 Some((self.table, self.rowid))
90 }
91}
92
/// One table participating in a joined SELECT, paired with the name it is
/// referenced by in the query.
pub(crate) struct JoinedTableRef<'a> {
    /// The table itself.
    pub table: &'a Table,
    /// The alias if one was given (`AS x`), otherwise the table name. Column
    /// qualifiers are matched against it case-insensitively.
    pub scope_name: String,
}
100
/// Row scope over a partially or fully joined row.
///
/// `tables[i]` and `rowids[i]` describe the same table. A `None` rowid marks the
/// NULL-padded side of an outer join; every column of that table reads as
/// `Value::Null`.
pub(crate) struct JoinedScope<'a> {
    /// Tables visible to this row, in FROM/JOIN order.
    pub tables: &'a [JoinedTableRef<'a>],
    /// One rowid slot per entry of `tables` (`None` = NULL-padded).
    pub rowids: &'a [Option<i64>],
}
108
109impl RowScope for JoinedScope<'_> {
110 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
111 if let Some(q) = qualifier {
112 let pos = self
115 .tables
116 .iter()
117 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
118 .ok_or_else(|| {
119 SQLRiteError::Internal(format!(
120 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
121 ))
122 })?;
123 if !self.tables[pos].table.contains_column(col.to_string()) {
124 return Err(SQLRiteError::Internal(format!(
125 "column '{col}' does not exist on '{}'",
126 self.tables[pos].scope_name
127 )));
128 }
129 return Ok(match self.rowids[pos] {
130 None => Value::Null,
131 Some(r) => self.tables[pos]
132 .table
133 .get_value(col, r)
134 .unwrap_or(Value::Null),
135 });
136 }
137 let mut hit: Option<usize> = None;
141 for (i, t) in self.tables.iter().enumerate() {
142 if t.table.contains_column(col.to_string()) {
143 if hit.is_some() {
144 return Err(SQLRiteError::Internal(format!(
145 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
146 )));
147 }
148 hit = Some(i);
149 }
150 }
151 let i = hit.ok_or_else(|| {
152 SQLRiteError::Internal(format!(
153 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
154 ))
155 })?;
156 Ok(match self.rowids[i] {
157 None => Value::Null,
158 Some(r) => self.tables[i]
159 .table
160 .get_value(col, r)
161 .unwrap_or(Value::Null),
162 })
163 }
164
165 fn single_table_view(&self) -> Option<(&Table, i64)> {
166 None
167 }
168}
169
/// Result of executing a SELECT: output column names plus the projected rows.
///
/// Each row holds its values in the same order as `columns`.
pub struct SelectResult {
    /// Output column names, as produced by `ProjectionItem::output_name`.
    pub columns: Vec<String>,
    /// Result rows after filtering, ordering, DISTINCT and LIMIT have been applied.
    pub rows: Vec<Vec<Value>>,
}
182
183pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
187 if !query.joins.is_empty() {
192 return execute_select_rows_joined(query, db);
193 }
194
195 let table = db
196 .get_table(query.table_name.clone())
197 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
198
199 let proj_items: Vec<ProjectionItem> = match &query.projection {
204 Projection::All => table
205 .column_names()
206 .into_iter()
207 .map(|c| ProjectionItem {
208 kind: ProjectionKind::Column {
209 qualifier: None,
210 name: c,
211 },
212 alias: None,
213 })
214 .collect(),
215 Projection::Items(items) => items.clone(),
216 };
217 let has_aggregates = proj_items
218 .iter()
219 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
220 for item in &proj_items {
222 if let ProjectionKind::Column { name: c, .. } = &item.kind
223 && !table.contains_column(c.clone())
224 {
225 return Err(SQLRiteError::Internal(format!(
226 "Column '{c}' does not exist on table '{}'",
227 query.table_name
228 )));
229 }
230 }
231 for c in &query.group_by {
232 if !table.contains_column(c.clone()) {
233 return Err(SQLRiteError::Internal(format!(
234 "GROUP BY references unknown column '{c}' on table '{}'",
235 query.table_name
236 )));
237 }
238 }
239 let matching = match select_rowids(table, query.selection.as_ref())? {
243 RowidSource::IndexProbe(rowids) => rowids,
244 RowidSource::FullScan => {
245 let mut out = Vec::new();
246 for rowid in table.rowids() {
247 if let Some(expr) = &query.selection
248 && !eval_predicate(expr, table, rowid)?
249 {
250 continue;
251 }
252 out.push(rowid);
253 }
254 out
255 }
256 };
257 let mut matching = matching;
258
259 let aggregating = has_aggregates || !query.group_by.is_empty();
260
261 if aggregating {
267 for item in &proj_items {
269 if let ProjectionKind::Aggregate(call) = &item.kind
270 && let AggregateArg::Column(c) = &call.arg
271 && !table.contains_column(c.clone())
272 {
273 return Err(SQLRiteError::Internal(format!(
274 "{}({}) references unknown column '{c}' on table '{}'",
275 call.func.as_str(),
276 c,
277 query.table_name
278 )));
279 }
280 }
281
282 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
283 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
284
285 if query.distinct {
286 rows = dedupe_rows(rows);
287 }
288
289 if let Some(order) = &query.order_by {
290 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
291 }
292 if let Some(k) = query.limit {
293 rows.truncate(k);
294 }
295
296 return Ok(SelectResult { columns, rows });
297 }
298
299 let defer_limit_for_distinct = query.distinct;
337 match (&query.order_by, query.limit) {
338 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
339 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
340 }
341 (Some(order), Some(k))
342 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
343 {
344 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
345 }
346 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
347 matching = select_topk(&matching, table, order, k)?;
348 }
349 (Some(order), _) => {
350 sort_rowids(&mut matching, table, order)?;
351 if let Some(k) = query.limit
352 && !defer_limit_for_distinct
353 {
354 matching.truncate(k);
355 }
356 }
357 (None, Some(k)) if !defer_limit_for_distinct => {
358 matching.truncate(k);
359 }
360 _ => {}
361 }
362
363 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
364 let projected_cols: Vec<String> = proj_items
365 .iter()
366 .map(|i| match &i.kind {
367 ProjectionKind::Column { name, .. } => name.clone(),
368 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
369 })
370 .collect();
371
372 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
376 for rowid in &matching {
377 let row: Vec<Value> = projected_cols
378 .iter()
379 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
380 .collect();
381 rows.push(row);
382 }
383
384 if query.distinct {
385 rows = dedupe_rows(rows);
386 if let Some(k) = query.limit {
387 rows.truncate(k);
388 }
389 }
390
391 Ok(SelectResult { columns, rows })
392}
393
394fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
421 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
428
429 let primary = db
430 .get_table(query.table_name.clone())
431 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
432 joined_tables.push(JoinedTableRef {
433 table: primary,
434 scope_name: query
435 .table_alias
436 .clone()
437 .unwrap_or_else(|| query.table_name.clone()),
438 });
439 for j in &query.joins {
440 let t = db
441 .get_table(j.right_table.clone())
442 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
443 joined_tables.push(JoinedTableRef {
444 table: t,
445 scope_name: j
446 .right_alias
447 .clone()
448 .unwrap_or_else(|| j.right_table.clone()),
449 });
450 }
451
452 {
457 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
458 for t in &joined_tables {
459 let key = t.scope_name.to_ascii_lowercase();
460 if !seen.insert(key) {
461 return Err(SQLRiteError::Internal(format!(
462 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
463 t.scope_name
464 )));
465 }
466 }
467 }
468
469 let proj_items: Vec<ProjectionItem> = match &query.projection {
475 Projection::All => {
476 let mut all = Vec::new();
485 for t in &joined_tables {
486 for col in t.table.column_names() {
487 all.push(ProjectionItem {
488 kind: ProjectionKind::Column {
489 qualifier: Some(t.scope_name.clone()),
494 name: col,
495 },
496 alias: None,
497 });
498 }
499 }
500 all
501 }
502 Projection::Items(items) => items.clone(),
503 };
504
505 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
506
507 let mut acc: Vec<Vec<Option<i64>>> = primary
512 .rowids()
513 .into_iter()
514 .map(|r| {
515 let mut row = Vec::with_capacity(joined_tables.len());
516 row.push(Some(r));
517 row
518 })
519 .collect();
520
521 for (j_idx, join) in query.joins.iter().enumerate() {
526 let right_pos = j_idx + 1;
527 let right_table = joined_tables[right_pos].table;
528 let right_rowids: Vec<i64> = right_table.rowids();
529
530 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
534
535 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
536
537 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
545
546 for left_row in acc.into_iter() {
547 let mut left_match_count = 0usize;
551 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
552 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
553 on_rowids.push(Some(rrid));
554 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
555 let scope = JoinedScope {
556 tables: on_scope_tables,
557 rowids: &on_rowids,
558 };
559 if eval_predicate_scope(&join.on, &scope)? {
564 left_match_count += 1;
565 right_matched[r_idx] = true;
566 next_acc.push(on_rowids);
571 }
572 }
573
574 if left_match_count == 0
575 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
576 {
577 let mut padded = left_row;
580 padded.push(None);
581 next_acc.push(padded);
582 }
583 }
584
585 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
589 for (r_idx, matched) in right_matched.iter().enumerate() {
590 if *matched {
591 continue;
592 }
593 let mut row: Vec<Option<i64>> = vec![None; right_pos];
594 row.push(Some(right_rowids[r_idx]));
595 next_acc.push(row);
596 }
597 }
598
599 acc = next_acc;
600 }
601
602 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
607 let mut out = Vec::with_capacity(acc.len());
608 for row in acc {
609 let scope = JoinedScope {
610 tables: &joined_tables,
611 rowids: &row,
612 };
613 if eval_predicate_scope(where_expr, &scope)? {
614 out.push(row);
615 }
616 }
617 out
618 } else {
619 acc
620 };
621
622 if let Some(order) = &query.order_by {
626 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
629 for (i, row) in filtered.iter().enumerate() {
630 let scope = JoinedScope {
631 tables: &joined_tables,
632 rowids: row,
633 };
634 let v = eval_expr_scope(&order.expr, &scope)?;
635 keys.push((i, v));
636 }
637 keys.sort_by(|(_, a), (_, b)| {
638 let ord = compare_values(Some(a), Some(b));
639 if order.ascending { ord } else { ord.reverse() }
640 });
641 let mut sorted = Vec::with_capacity(filtered.len());
642 for (i, _) in keys {
643 sorted.push(filtered[i].clone());
644 }
645 filtered = sorted;
646 }
647
648 if let Some(k) = query.limit {
650 filtered.truncate(k);
651 }
652
653 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
656 for row in &filtered {
657 let scope = JoinedScope {
658 tables: &joined_tables,
659 rowids: row,
660 };
661 let mut out_row = Vec::with_capacity(proj_items.len());
662 for item in &proj_items {
663 let v = match &item.kind {
664 ProjectionKind::Column { qualifier, name } => {
665 scope.lookup(qualifier.as_deref(), name)?
666 }
667 ProjectionKind::Aggregate(_) => {
668 return Err(SQLRiteError::Internal(
671 "aggregate functions over JOIN are not supported".to_string(),
672 ));
673 }
674 };
675 out_row.push(v);
676 }
677 rows.push(out_row);
678 }
679
680 Ok(SelectResult { columns, rows })
681}
682
683pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
688 let result = execute_select_rows(query, db)?;
689 let row_count = result.rows.len();
690
691 let mut print_table = PrintTable::new();
692 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
693 print_table.add_row(PrintRow::new(header_cells));
694
695 for row in &result.rows {
696 let cells: Vec<PrintCell> = row
697 .iter()
698 .map(|v| PrintCell::new(&v.to_display_string()))
699 .collect();
700 print_table.add_row(PrintRow::new(cells));
701 }
702
703 Ok((print_table.to_string(), row_count))
704}
705
706pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
708 let Statement::Delete(Delete {
709 from, selection, ..
710 }) = stmt
711 else {
712 return Err(SQLRiteError::Internal(
713 "execute_delete called on a non-DELETE statement".to_string(),
714 ));
715 };
716
717 let tables = match from {
718 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
719 };
720 let table_name = extract_single_table_name(tables)?;
721
722 let matching: Vec<i64> = {
724 let table = db
725 .get_table(table_name.clone())
726 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
727 match select_rowids(table, selection.as_ref())? {
728 RowidSource::IndexProbe(rowids) => rowids,
729 RowidSource::FullScan => {
730 let mut out = Vec::new();
731 for rowid in table.rowids() {
732 if let Some(expr) = selection {
733 if !eval_predicate(expr, table, rowid)? {
734 continue;
735 }
736 }
737 out.push(rowid);
738 }
739 out
740 }
741 }
742 };
743
744 let table = db.get_table_mut(table_name)?;
745 for rowid in &matching {
746 table.delete_row(*rowid);
747 }
748 if !matching.is_empty() {
757 for entry in &mut table.hnsw_indexes {
758 entry.needs_rebuild = true;
759 }
760 for entry in &mut table.fts_indexes {
761 entry.needs_rebuild = true;
762 }
763 }
764 Ok(matching.len())
765}
766
767pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
769 let Statement::Update(Update {
770 table,
771 assignments,
772 from,
773 selection,
774 ..
775 }) = stmt
776 else {
777 return Err(SQLRiteError::Internal(
778 "execute_update called on a non-UPDATE statement".to_string(),
779 ));
780 };
781
782 if from.is_some() {
783 return Err(SQLRiteError::NotImplemented(
784 "UPDATE ... FROM is not supported yet".to_string(),
785 ));
786 }
787
788 let table_name = extract_table_name(table)?;
789
790 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
792 {
793 let tbl = db
794 .get_table(table_name.clone())
795 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
796 for a in assignments {
797 let col = match &a.target {
798 AssignmentTarget::ColumnName(name) => name
799 .0
800 .last()
801 .map(|p| p.to_string())
802 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
803 AssignmentTarget::Tuple(_) => {
804 return Err(SQLRiteError::NotImplemented(
805 "tuple assignment targets are not supported".to_string(),
806 ));
807 }
808 };
809 if !tbl.contains_column(col.clone()) {
810 return Err(SQLRiteError::Internal(format!(
811 "UPDATE references unknown column '{col}'"
812 )));
813 }
814 parsed_assignments.push((col, a.value.clone()));
815 }
816 }
817
818 let work: Vec<(i64, Vec<(String, Value)>)> = {
822 let tbl = db.get_table(table_name.clone())?;
823 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
824 RowidSource::IndexProbe(rowids) => rowids,
825 RowidSource::FullScan => {
826 let mut out = Vec::new();
827 for rowid in tbl.rowids() {
828 if let Some(expr) = selection {
829 if !eval_predicate(expr, tbl, rowid)? {
830 continue;
831 }
832 }
833 out.push(rowid);
834 }
835 out
836 }
837 };
838 let mut rows_to_update = Vec::new();
839 for rowid in matched_rowids {
840 let mut values = Vec::with_capacity(parsed_assignments.len());
841 for (col, expr) in &parsed_assignments {
842 let v = eval_expr(expr, tbl, rowid)?;
845 values.push((col.clone(), v));
846 }
847 rows_to_update.push((rowid, values));
848 }
849 rows_to_update
850 };
851
852 let tbl = db.get_table_mut(table_name)?;
853 for (rowid, values) in &work {
854 for (col, v) in values {
855 tbl.set_value(col, *rowid, v.clone())?;
856 }
857 }
858
859 if !work.is_empty() {
868 let updated_columns: std::collections::HashSet<&str> = work
869 .iter()
870 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
871 .collect();
872 for entry in &mut tbl.hnsw_indexes {
873 if updated_columns.contains(entry.column_name.as_str()) {
874 entry.needs_rebuild = true;
875 }
876 }
877 for entry in &mut tbl.fts_indexes {
878 if updated_columns.contains(entry.column_name.as_str()) {
879 entry.needs_rebuild = true;
880 }
881 }
882 }
883 Ok(work.len())
884}
885
886pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
898 let Statement::CreateIndex(CreateIndex {
899 name,
900 table_name,
901 columns,
902 using,
903 unique,
904 if_not_exists,
905 predicate,
906 with,
907 ..
908 }) = stmt
909 else {
910 return Err(SQLRiteError::Internal(
911 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
912 ));
913 };
914
915 if predicate.is_some() {
916 return Err(SQLRiteError::NotImplemented(
917 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
918 ));
919 }
920
921 if columns.len() != 1 {
922 return Err(SQLRiteError::NotImplemented(format!(
923 "multi-column indexes are not supported yet ({} columns given)",
924 columns.len()
925 )));
926 }
927
928 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
929 SQLRiteError::NotImplemented(
930 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
931 )
932 })?;
933
934 let method = match using {
940 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
941 IndexMethod::Hnsw
942 }
943 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
944 IndexMethod::Fts
945 }
946 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
947 IndexMethod::Btree
948 }
949 Some(other) => {
950 return Err(SQLRiteError::NotImplemented(format!(
951 "CREATE INDEX … USING {other:?} is not supported \
952 (try `hnsw`, `fts`, or no USING clause)"
953 )));
954 }
955 None => IndexMethod::Btree,
956 };
957
958 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
964
965 let table_name_str = table_name.to_string();
966 let column_name = match &columns[0].column.expr {
967 Expr::Identifier(ident) => ident.value.clone(),
968 Expr::CompoundIdentifier(parts) => parts
969 .last()
970 .map(|p| p.value.clone())
971 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
972 other => {
973 return Err(SQLRiteError::NotImplemented(format!(
974 "CREATE INDEX only supports simple column references, got {other:?}"
975 )));
976 }
977 };
978
979 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
984 let table = db.get_table(table_name_str.clone()).map_err(|_| {
985 SQLRiteError::General(format!(
986 "CREATE INDEX references unknown table '{table_name_str}'"
987 ))
988 })?;
989 if !table.contains_column(column_name.clone()) {
990 return Err(SQLRiteError::General(format!(
991 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
992 )));
993 }
994 let col = table
995 .columns
996 .iter()
997 .find(|c| c.column_name == column_name)
998 .expect("we just verified the column exists");
999
1000 if table.index_by_name(&index_name).is_some()
1003 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1004 || table.fts_indexes.iter().any(|i| i.name == index_name)
1005 {
1006 if *if_not_exists {
1007 return Ok(index_name);
1008 }
1009 return Err(SQLRiteError::General(format!(
1010 "index '{index_name}' already exists"
1011 )));
1012 }
1013 let datatype = clone_datatype(&col.datatype);
1014
1015 let mut pairs = Vec::new();
1016 for rowid in table.rowids() {
1017 if let Some(v) = table.get_value(&column_name, rowid) {
1018 pairs.push((rowid, v));
1019 }
1020 }
1021 (datatype, pairs)
1022 };
1023
1024 match method {
1025 IndexMethod::Btree => create_btree_index(
1026 db,
1027 &table_name_str,
1028 &index_name,
1029 &column_name,
1030 &datatype,
1031 *unique,
1032 &existing_rowids_and_values,
1033 ),
1034 IndexMethod::Hnsw => create_hnsw_index(
1035 db,
1036 &table_name_str,
1037 &index_name,
1038 &column_name,
1039 &datatype,
1040 *unique,
1041 hnsw_metric.unwrap_or(DistanceMetric::L2),
1042 &existing_rowids_and_values,
1043 ),
1044 IndexMethod::Fts => create_fts_index(
1045 db,
1046 &table_name_str,
1047 &index_name,
1048 &column_name,
1049 &datatype,
1050 *unique,
1051 &existing_rowids_and_values,
1052 ),
1053 }
1054}
1055
1056pub fn execute_drop_table(
1067 names: &[ObjectName],
1068 if_exists: bool,
1069 db: &mut Database,
1070) -> Result<usize> {
1071 if names.len() != 1 {
1072 return Err(SQLRiteError::NotImplemented(
1073 "DROP TABLE supports a single table per statement".to_string(),
1074 ));
1075 }
1076 let name = names[0].to_string();
1077
1078 if name == crate::sql::pager::MASTER_TABLE_NAME {
1079 return Err(SQLRiteError::General(format!(
1080 "'{}' is a reserved name used by the internal schema catalog",
1081 crate::sql::pager::MASTER_TABLE_NAME
1082 )));
1083 }
1084
1085 if !db.contains_table(name.clone()) {
1086 return if if_exists {
1087 Ok(0)
1088 } else {
1089 Err(SQLRiteError::General(format!(
1090 "Table '{name}' does not exist"
1091 )))
1092 };
1093 }
1094
1095 db.tables.remove(&name);
1096 Ok(1)
1097}
1098
1099pub fn execute_drop_index(
1108 names: &[ObjectName],
1109 if_exists: bool,
1110 db: &mut Database,
1111) -> Result<usize> {
1112 if names.len() != 1 {
1113 return Err(SQLRiteError::NotImplemented(
1114 "DROP INDEX supports a single index per statement".to_string(),
1115 ));
1116 }
1117 let name = names[0].to_string();
1118
1119 for table in db.tables.values_mut() {
1120 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1121 if secondary.origin == IndexOrigin::Auto {
1122 return Err(SQLRiteError::General(format!(
1123 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1124 )));
1125 }
1126 table.secondary_indexes.retain(|i| i.name != name);
1127 return Ok(1);
1128 }
1129 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1130 table.hnsw_indexes.retain(|i| i.name != name);
1131 return Ok(1);
1132 }
1133 if table.fts_indexes.iter().any(|i| i.name == name) {
1134 table.fts_indexes.retain(|i| i.name != name);
1135 return Ok(1);
1136 }
1137 }
1138
1139 if if_exists {
1140 Ok(0)
1141 } else {
1142 Err(SQLRiteError::General(format!(
1143 "Index '{name}' does not exist"
1144 )))
1145 }
1146}
1147
1148pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1160 let table_name = alter.name.to_string();
1161
1162 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1163 return Err(SQLRiteError::General(format!(
1164 "'{}' is a reserved name used by the internal schema catalog",
1165 crate::sql::pager::MASTER_TABLE_NAME
1166 )));
1167 }
1168
1169 if !db.contains_table(table_name.clone()) {
1170 return if alter.if_exists {
1171 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1172 } else {
1173 Err(SQLRiteError::General(format!(
1174 "Table '{table_name}' does not exist"
1175 )))
1176 };
1177 }
1178
1179 if alter.operations.len() != 1 {
1180 return Err(SQLRiteError::NotImplemented(
1181 "ALTER TABLE supports one operation per statement".to_string(),
1182 ));
1183 }
1184
1185 match &alter.operations[0] {
1186 AlterTableOperation::RenameTable { table_name: kind } => {
1187 let new_name = match kind {
1188 RenameTableNameKind::To(name) => name.to_string(),
1189 RenameTableNameKind::As(_) => {
1190 return Err(SQLRiteError::NotImplemented(
1191 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1192 .to_string(),
1193 ));
1194 }
1195 };
1196 alter_rename_table(db, &table_name, &new_name)?;
1197 Ok(format!(
1198 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1199 ))
1200 }
1201 AlterTableOperation::RenameColumn {
1202 old_column_name,
1203 new_column_name,
1204 } => {
1205 let old = old_column_name.value.clone();
1206 let new = new_column_name.value.clone();
1207 db.get_table_mut(table_name.clone())?
1208 .rename_column(&old, &new)?;
1209 Ok(format!(
1210 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1211 ))
1212 }
1213 AlterTableOperation::AddColumn {
1214 column_def,
1215 if_not_exists,
1216 ..
1217 } => {
1218 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1219 let table = db.get_table_mut(table_name.clone())?;
1220 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1221 return Ok(format!(
1222 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1223 parsed.name
1224 ));
1225 }
1226 let col_name = parsed.name.clone();
1227 table.add_column(parsed)?;
1228 Ok(format!(
1229 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1230 ))
1231 }
1232 AlterTableOperation::DropColumn {
1233 column_names,
1234 if_exists,
1235 ..
1236 } => {
1237 if column_names.len() != 1 {
1238 return Err(SQLRiteError::NotImplemented(
1239 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1240 ));
1241 }
1242 let col_name = column_names[0].value.clone();
1243 let table = db.get_table_mut(table_name.clone())?;
1244 if *if_exists && !table.contains_column(col_name.clone()) {
1245 return Ok(format!(
1246 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1247 ));
1248 }
1249 table.drop_column(&col_name)?;
1250 Ok(format!(
1251 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1252 ))
1253 }
1254 other => Err(SQLRiteError::NotImplemented(format!(
1255 "ALTER TABLE operation {other:?} is not supported"
1256 ))),
1257 }
1258}
1259
1260pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1270 if db.in_transaction() {
1271 return Err(SQLRiteError::General(
1272 "VACUUM cannot run inside a transaction".to_string(),
1273 ));
1274 }
1275 let path = match db.source_path.clone() {
1276 Some(p) => p,
1277 None => {
1278 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1279 }
1280 };
1281 if let Some(pager) = db.pager.as_mut() {
1287 let _ = pager.checkpoint();
1288 }
1289 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1290 let pages_before = db
1291 .pager
1292 .as_ref()
1293 .map(|p| p.header().page_count)
1294 .unwrap_or(0);
1295 crate::sql::pager::vacuum_database(db, &path)?;
1296 if let Some(pager) = db.pager.as_mut() {
1299 let _ = pager.checkpoint();
1300 }
1301 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1302 let pages_after = db
1303 .pager
1304 .as_ref()
1305 .map(|p| p.header().page_count)
1306 .unwrap_or(0);
1307 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1308 let bytes_reclaimed = size_before.saturating_sub(size_after);
1309 Ok(format!(
1310 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1311 ))
1312}
1313
1314fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1320 if new == crate::sql::pager::MASTER_TABLE_NAME {
1321 return Err(SQLRiteError::General(format!(
1322 "'{}' is a reserved name used by the internal schema catalog",
1323 crate::sql::pager::MASTER_TABLE_NAME
1324 )));
1325 }
1326 if old == new {
1327 return Ok(());
1328 }
1329 if db.contains_table(new.to_string()) {
1330 return Err(SQLRiteError::General(format!(
1331 "target table '{new}' already exists"
1332 )));
1333 }
1334
1335 let mut table = db
1336 .tables
1337 .remove(old)
1338 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1339 table.tb_name = new.to_string();
1340 for idx in table.secondary_indexes.iter_mut() {
1341 idx.table_name = new.to_string();
1342 if idx.origin == IndexOrigin::Auto
1343 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1344 {
1345 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1346 }
1347 }
1348 db.tables.insert(new.to_string(), table);
1349 Ok(())
1350}
1351
/// Index implementation selected by `CREATE INDEX ... USING <method>`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary B-tree index (`USING btree`, or no USING clause at all).
    Btree,
    /// HNSW approximate-nearest-neighbour index; requires a VECTOR column (`USING hnsw`).
    Hnsw,
    /// Full-text index (`USING fts`).
    Fts,
}
1362
1363fn create_btree_index(
1365 db: &mut Database,
1366 table_name: &str,
1367 index_name: &str,
1368 column_name: &str,
1369 datatype: &DataType,
1370 unique: bool,
1371 existing: &[(i64, Value)],
1372) -> Result<String> {
1373 let mut idx = SecondaryIndex::new(
1374 index_name.to_string(),
1375 table_name.to_string(),
1376 column_name.to_string(),
1377 datatype,
1378 unique,
1379 IndexOrigin::Explicit,
1380 )?;
1381
1382 for (rowid, v) in existing {
1386 if unique && idx.would_violate_unique(v) {
1387 return Err(SQLRiteError::General(format!(
1388 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1389 already contains the duplicate value {}",
1390 v.to_display_string()
1391 )));
1392 }
1393 idx.insert(v, *rowid)?;
1394 }
1395
1396 let table_mut = db.get_table_mut(table_name.to_string())?;
1397 table_mut.secondary_indexes.push(idx);
1398 Ok(index_name.to_string())
1399}
1400
1401fn create_hnsw_index(
1403 db: &mut Database,
1404 table_name: &str,
1405 index_name: &str,
1406 column_name: &str,
1407 datatype: &DataType,
1408 unique: bool,
1409 metric: DistanceMetric,
1410 existing: &[(i64, Value)],
1411) -> Result<String> {
1412 let dim = match datatype {
1415 DataType::Vector(d) => *d,
1416 other => {
1417 return Err(SQLRiteError::General(format!(
1418 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1419 )));
1420 }
1421 };
1422
1423 if unique {
1424 return Err(SQLRiteError::General(
1425 "UNIQUE has no meaning for HNSW indexes".to_string(),
1426 ));
1427 }
1428
1429 let seed = hash_str_to_seed(index_name);
1440 let mut idx = HnswIndex::new(metric, seed);
1441
1442 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1446 std::collections::HashMap::with_capacity(existing.len());
1447 for (rowid, v) in existing {
1448 match v {
1449 Value::Vector(vec) => {
1450 if vec.len() != dim {
1451 return Err(SQLRiteError::Internal(format!(
1452 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1453 declared as VECTOR({dim}) — schema invariant violated",
1454 vec.len()
1455 )));
1456 }
1457 vec_map.insert(*rowid, vec.clone());
1458 }
1459 _ => continue,
1463 }
1464 }
1465
1466 for (rowid, _) in existing {
1467 if let Some(v) = vec_map.get(rowid) {
1468 let v_clone = v.clone();
1469 idx.insert(*rowid, &v_clone, |id| {
1470 vec_map.get(&id).cloned().unwrap_or_default()
1471 });
1472 }
1473 }
1474
1475 let table_mut = db.get_table_mut(table_name.to_string())?;
1476 table_mut.hnsw_indexes.push(HnswIndexEntry {
1477 name: index_name.to_string(),
1478 column_name: column_name.to_string(),
1479 metric,
1480 index: idx,
1481 needs_rebuild: false,
1483 });
1484 Ok(index_name.to_string())
1485}
1486
1487fn parse_hnsw_with_options(
1498 with: &[Expr],
1499 index_name: &str,
1500 method: IndexMethod,
1501) -> Result<Option<DistanceMetric>> {
1502 if with.is_empty() {
1503 return Ok(None);
1504 }
1505 if !matches!(method, IndexMethod::Hnsw) {
1506 return Err(SQLRiteError::General(format!(
1507 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1508 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1509 )));
1510 }
1511
1512 let mut metric: Option<DistanceMetric> = None;
1513 for opt in with {
1514 let Expr::BinaryOp { left, op, right } = opt else {
1515 return Err(SQLRiteError::General(format!(
1516 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1517 (expected `key = 'value'`)"
1518 )));
1519 };
1520 if !matches!(op, BinaryOperator::Eq) {
1521 return Err(SQLRiteError::General(format!(
1522 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1523 )));
1524 }
1525 let key = match left.as_ref() {
1526 Expr::Identifier(ident) => ident.value.clone(),
1527 other => {
1528 return Err(SQLRiteError::General(format!(
1529 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1530 got {other:?}"
1531 )));
1532 }
1533 };
1534 let value = match right.as_ref() {
1535 Expr::Value(v) => match &v.value {
1536 AstValue::SingleQuotedString(s) => s.clone(),
1537 AstValue::DoubleQuotedString(s) => s.clone(),
1538 other => {
1539 return Err(SQLRiteError::General(format!(
1540 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1541 a quoted string, got {other:?}"
1542 )));
1543 }
1544 },
1545 Expr::Identifier(ident) => ident.value.clone(),
1546 other => {
1547 return Err(SQLRiteError::General(format!(
1548 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1549 quoted string, got {other:?}"
1550 )));
1551 }
1552 };
1553
1554 if key.eq_ignore_ascii_case("metric") {
1555 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1556 SQLRiteError::General(format!(
1557 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1558 (try 'l2', 'cosine', or 'dot')"
1559 ))
1560 })?;
1561 if metric.is_some() {
1562 return Err(SQLRiteError::General(format!(
1563 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1564 )));
1565 }
1566 metric = Some(parsed);
1567 } else {
1568 return Err(SQLRiteError::General(format!(
1569 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1570 (only 'metric' is recognised on HNSW indexes)"
1571 )));
1572 }
1573 }
1574
1575 Ok(metric)
1576}
1577
1578fn create_fts_index(
1583 db: &mut Database,
1584 table_name: &str,
1585 index_name: &str,
1586 column_name: &str,
1587 datatype: &DataType,
1588 unique: bool,
1589 existing: &[(i64, Value)],
1590) -> Result<String> {
1591 match datatype {
1596 DataType::Text => {}
1597 other => {
1598 return Err(SQLRiteError::General(format!(
1599 "USING fts requires a TEXT column; '{column_name}' is {other}"
1600 )));
1601 }
1602 }
1603
1604 if unique {
1605 return Err(SQLRiteError::General(
1606 "UNIQUE has no meaning for FTS indexes".to_string(),
1607 ));
1608 }
1609
1610 let mut idx = PostingList::new();
1611 for (rowid, v) in existing {
1612 if let Value::Text(text) = v {
1613 idx.insert(*rowid, text);
1614 }
1615 }
1618
1619 let table_mut = db.get_table_mut(table_name.to_string())?;
1620 table_mut.fts_indexes.push(FtsIndexEntry {
1621 name: index_name.to_string(),
1622 column_name: column_name.to_string(),
1623 index: idx,
1624 needs_rebuild: false,
1625 });
1626 Ok(index_name.to_string())
1627}
1628
/// Hashes `s` with 64-bit FNV-1a, yielding a stable (process- and
/// platform-independent) seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;
    s.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1640
1641fn clone_datatype(dt: &DataType) -> DataType {
1644 match dt {
1645 DataType::Integer => DataType::Integer,
1646 DataType::Text => DataType::Text,
1647 DataType::Real => DataType::Real,
1648 DataType::Bool => DataType::Bool,
1649 DataType::Vector(dim) => DataType::Vector(*dim),
1650 DataType::Json => DataType::Json,
1651 DataType::None => DataType::None,
1652 DataType::Invalid => DataType::Invalid,
1653 }
1654}
1655
1656fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1657 if tables.len() != 1 {
1658 return Err(SQLRiteError::NotImplemented(
1659 "multi-table DELETE is not supported yet".to_string(),
1660 ));
1661 }
1662 extract_table_name(&tables[0])
1663}
1664
1665fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1666 if !twj.joins.is_empty() {
1667 return Err(SQLRiteError::NotImplemented(
1668 "JOIN is not supported yet".to_string(),
1669 ));
1670 }
1671 match &twj.relation {
1672 TableFactor::Table { name, .. } => Ok(name.to_string()),
1673 _ => Err(SQLRiteError::NotImplemented(
1674 "only plain table references are supported".to_string(),
1675 )),
1676 }
1677}
1678
/// How a statement should enumerate candidate rows, as decided by `select_rowids`.
enum RowidSource {
    /// Rowids obtained by probing a secondary index with an equality literal
    /// (`select_rowids` returns them sorted ascending).
    IndexProbe(Vec<i64>),
    /// No usable index: every row of the table must be visited.
    FullScan,
}
1689
1690fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1695 let Some(expr) = selection else {
1696 return Ok(RowidSource::FullScan);
1697 };
1698 let Some((col, literal)) = try_extract_equality(expr) else {
1699 return Ok(RowidSource::FullScan);
1700 };
1701 let Some(idx) = table.index_for_column(&col) else {
1702 return Ok(RowidSource::FullScan);
1703 };
1704
1705 let literal_value = match convert_literal(&literal) {
1709 Ok(v) => v,
1710 Err(_) => return Ok(RowidSource::FullScan),
1711 };
1712
1713 let mut rowids = idx.lookup(&literal_value);
1717 rowids.sort_unstable();
1718 Ok(RowidSource::IndexProbe(rowids))
1719}
1720
1721fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1725 let peeled = match expr {
1727 Expr::Nested(inner) => inner.as_ref(),
1728 other => other,
1729 };
1730 let Expr::BinaryOp { left, op, right } = peeled else {
1731 return None;
1732 };
1733 if !matches!(op, BinaryOperator::Eq) {
1734 return None;
1735 }
1736 let col_from = |e: &Expr| -> Option<String> {
1737 match e {
1738 Expr::Identifier(ident) => Some(ident.value.clone()),
1739 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1740 _ => None,
1741 }
1742 };
1743 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1744 if let Expr::Value(v) = e {
1745 Some(v.value.clone())
1746 } else {
1747 None
1748 }
1749 };
1750 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1751 return Some((c, l));
1752 }
1753 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1754 return Some((c, l));
1755 }
1756 None
1757}
1758
1759fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1784 if k == 0 {
1785 return None;
1786 }
1787
1788 let func = match order_expr {
1791 Expr::Function(f) => f,
1792 _ => return None,
1793 };
1794 let fname = match func.name.0.as_slice() {
1795 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1796 _ => return None,
1797 };
1798 let query_metric = match fname.as_str() {
1799 "vec_distance_l2" => DistanceMetric::L2,
1800 "vec_distance_cosine" => DistanceMetric::Cosine,
1801 "vec_distance_dot" => DistanceMetric::Dot,
1802 _ => return None,
1803 };
1804
1805 let arg_list = match &func.args {
1807 FunctionArguments::List(l) => &l.args,
1808 _ => return None,
1809 };
1810 if arg_list.len() != 2 {
1811 return None;
1812 }
1813 let exprs: Vec<&Expr> = arg_list
1814 .iter()
1815 .filter_map(|a| match a {
1816 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1817 _ => None,
1818 })
1819 .collect();
1820 if exprs.len() != 2 {
1821 return None;
1822 }
1823
1824 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1829 Some(v) => v,
1830 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1831 Some(v) => v,
1832 None => return None,
1833 },
1834 };
1835
1836 let entry = table
1841 .hnsw_indexes
1842 .iter()
1843 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
1844
1845 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1851 Some(c) => match &c.datatype {
1852 DataType::Vector(d) => *d,
1853 _ => return None,
1854 },
1855 None => return None,
1856 };
1857 if query_vec.len() != declared_dim {
1858 return None;
1859 }
1860
1861 let column_for_closure = col_name.clone();
1865 let table_ref = table;
1866 let result = entry.index.search(&query_vec, k, |id| {
1867 match table_ref.get_value(&column_for_closure, id) {
1868 Some(Value::Vector(v)) => v,
1869 _ => Vec::new(),
1870 }
1871 });
1872 Some(result)
1873}
1874
1875fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1891 if k == 0 || ascending {
1892 return None;
1896 }
1897
1898 let func = match order_expr {
1899 Expr::Function(f) => f,
1900 _ => return None,
1901 };
1902 let fname = match func.name.0.as_slice() {
1903 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1904 _ => return None,
1905 };
1906 if fname != "bm25_score" {
1907 return None;
1908 }
1909
1910 let arg_list = match &func.args {
1911 FunctionArguments::List(l) => &l.args,
1912 _ => return None,
1913 };
1914 if arg_list.len() != 2 {
1915 return None;
1916 }
1917 let exprs: Vec<&Expr> = arg_list
1918 .iter()
1919 .filter_map(|a| match a {
1920 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1921 _ => None,
1922 })
1923 .collect();
1924 if exprs.len() != 2 {
1925 return None;
1926 }
1927
1928 let col_name = match exprs[0] {
1930 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1931 _ => return None,
1932 };
1933
1934 let query = match exprs[1] {
1938 Expr::Value(v) => match &v.value {
1939 AstValue::SingleQuotedString(s) => s.clone(),
1940 _ => return None,
1941 },
1942 _ => return None,
1943 };
1944
1945 let entry = table
1946 .fts_indexes
1947 .iter()
1948 .find(|e| e.column_name == col_name)?;
1949
1950 let scored = entry.index.query(&query, &Bm25Params::default());
1951 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1952 if out.len() > k {
1953 out.truncate(k);
1954 }
1955 Some(out)
1956}
1957
1958fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1963 let col_name = match a {
1964 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1965 _ => return None,
1966 };
1967 let lit_str = match b {
1968 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1969 format!("[{}]", ident.value)
1970 }
1971 _ => return None,
1972 };
1973 let v = parse_vector_literal(&lit_str).ok()?;
1974 Some((col_name, v))
1975}
1976
1977struct HeapEntry {
1990 key: Value,
1991 rowid: i64,
1992 asc: bool,
1993}
1994
1995impl PartialEq for HeapEntry {
1996 fn eq(&self, other: &Self) -> bool {
1997 self.cmp(other) == Ordering::Equal
1998 }
1999}
2000
// Marker impl required because `Ord` (needed by `BinaryHeap<HeapEntry>`) has
// `Eq` as a supertrait; equality itself is defined via `cmp` in `PartialEq`.
impl Eq for HeapEntry {}
2002
2003impl PartialOrd for HeapEntry {
2004 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2005 Some(self.cmp(other))
2006 }
2007}
2008
2009impl Ord for HeapEntry {
2010 fn cmp(&self, other: &Self) -> Ordering {
2011 let raw = compare_values(Some(&self.key), Some(&other.key));
2012 if self.asc { raw } else { raw.reverse() }
2013 }
2014}
2015
2016fn select_topk(
2025 matching: &[i64],
2026 table: &Table,
2027 order: &OrderByClause,
2028 k: usize,
2029) -> Result<Vec<i64>> {
2030 use std::collections::BinaryHeap;
2031
2032 if k == 0 || matching.is_empty() {
2033 return Ok(Vec::new());
2034 }
2035
2036 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2037
2038 for &rowid in matching {
2039 let key = eval_expr(&order.expr, table, rowid)?;
2040 let entry = HeapEntry {
2041 key,
2042 rowid,
2043 asc: order.ascending,
2044 };
2045
2046 if heap.len() < k {
2047 heap.push(entry);
2048 } else {
2049 if entry < *heap.peek().unwrap() {
2053 heap.pop();
2054 heap.push(entry);
2055 }
2056 }
2057 }
2058
2059 Ok(heap
2064 .into_sorted_vec()
2065 .into_iter()
2066 .map(|e| e.rowid)
2067 .collect())
2068}
2069
2070fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
2071 let mut keys: Vec<(i64, Result<Value>)> = rowids
2079 .iter()
2080 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
2081 .collect();
2082
2083 for (_, k) in &keys {
2087 if let Err(e) = k {
2088 return Err(SQLRiteError::General(format!(
2089 "ORDER BY expression failed: {e}"
2090 )));
2091 }
2092 }
2093
2094 keys.sort_by(|(_, ka), (_, kb)| {
2095 let va = ka.as_ref().unwrap();
2098 let vb = kb.as_ref().unwrap();
2099 let ord = compare_values(Some(va), Some(vb));
2100 if order.ascending { ord } else { ord.reverse() }
2101 });
2102
2103 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2105 rowids[i] = rowid;
2106 }
2107 Ok(())
2108}
2109
2110fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2111 match (a, b) {
2112 (None, None) => Ordering::Equal,
2113 (None, _) => Ordering::Less,
2114 (_, None) => Ordering::Greater,
2115 (Some(a), Some(b)) => match (a, b) {
2116 (Value::Null, Value::Null) => Ordering::Equal,
2117 (Value::Null, _) => Ordering::Less,
2118 (_, Value::Null) => Ordering::Greater,
2119 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2120 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2121 (Value::Integer(x), Value::Real(y)) => {
2122 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2123 }
2124 (Value::Real(x), Value::Integer(y)) => {
2125 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2126 }
2127 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2128 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2129 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2131 },
2132 }
2133}
2134
2135pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2137 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2138}
2139
2140pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2144 let v = eval_expr_scope(expr, scope)?;
2145 match v {
2146 Value::Bool(b) => Ok(b),
2147 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2149 other => Err(SQLRiteError::Internal(format!(
2150 "WHERE clause must evaluate to boolean, got {}",
2151 other.to_display_string()
2152 ))),
2153 }
2154}
2155
2156fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2158 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2159}
2160
2161fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2162 match expr {
2163 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2164
2165 Expr::Identifier(ident) => {
2166 if ident.quote_style == Some('[') {
2176 let raw = format!("[{}]", ident.value);
2177 let v = parse_vector_literal(&raw)?;
2178 return Ok(Value::Vector(v));
2179 }
2180 scope.lookup(None, &ident.value)
2181 }
2182
2183 Expr::CompoundIdentifier(parts) => {
2184 match parts.as_slice() {
2190 [only] => scope.lookup(None, &only.value),
2191 [q, c] => scope.lookup(Some(&q.value), &c.value),
2192 _ => Err(SQLRiteError::NotImplemented(format!(
2193 "compound identifier with {} parts is not supported",
2194 parts.len()
2195 ))),
2196 }
2197 }
2198
2199 Expr::Value(v) => convert_literal(&v.value),
2200
2201 Expr::UnaryOp { op, expr } => {
2202 let inner = eval_expr_scope(expr, scope)?;
2203 match op {
2204 UnaryOperator::Not => match inner {
2205 Value::Bool(b) => Ok(Value::Bool(!b)),
2206 Value::Null => Ok(Value::Null),
2207 other => Err(SQLRiteError::Internal(format!(
2208 "NOT applied to non-boolean value: {}",
2209 other.to_display_string()
2210 ))),
2211 },
2212 UnaryOperator::Minus => match inner {
2213 Value::Integer(i) => Ok(Value::Integer(-i)),
2214 Value::Real(f) => Ok(Value::Real(-f)),
2215 Value::Null => Ok(Value::Null),
2216 other => Err(SQLRiteError::Internal(format!(
2217 "unary minus on non-numeric value: {}",
2218 other.to_display_string()
2219 ))),
2220 },
2221 UnaryOperator::Plus => Ok(inner),
2222 other => Err(SQLRiteError::NotImplemented(format!(
2223 "unary operator {other:?} is not supported"
2224 ))),
2225 }
2226 }
2227
2228 Expr::BinaryOp { left, op, right } => match op {
2229 BinaryOperator::And => {
2230 let l = eval_expr_scope(left, scope)?;
2231 let r = eval_expr_scope(right, scope)?;
2232 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2233 }
2234 BinaryOperator::Or => {
2235 let l = eval_expr_scope(left, scope)?;
2236 let r = eval_expr_scope(right, scope)?;
2237 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2238 }
2239 cmp @ (BinaryOperator::Eq
2240 | BinaryOperator::NotEq
2241 | BinaryOperator::Lt
2242 | BinaryOperator::LtEq
2243 | BinaryOperator::Gt
2244 | BinaryOperator::GtEq) => {
2245 let l = eval_expr_scope(left, scope)?;
2246 let r = eval_expr_scope(right, scope)?;
2247 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2249 return Ok(Value::Bool(false));
2250 }
2251 let ord = compare_values(Some(&l), Some(&r));
2252 let result = match cmp {
2253 BinaryOperator::Eq => ord == Ordering::Equal,
2254 BinaryOperator::NotEq => ord != Ordering::Equal,
2255 BinaryOperator::Lt => ord == Ordering::Less,
2256 BinaryOperator::LtEq => ord != Ordering::Greater,
2257 BinaryOperator::Gt => ord == Ordering::Greater,
2258 BinaryOperator::GtEq => ord != Ordering::Less,
2259 _ => unreachable!(),
2260 };
2261 Ok(Value::Bool(result))
2262 }
2263 arith @ (BinaryOperator::Plus
2264 | BinaryOperator::Minus
2265 | BinaryOperator::Multiply
2266 | BinaryOperator::Divide
2267 | BinaryOperator::Modulo) => {
2268 let l = eval_expr_scope(left, scope)?;
2269 let r = eval_expr_scope(right, scope)?;
2270 eval_arith(arith, &l, &r)
2271 }
2272 BinaryOperator::StringConcat => {
2273 let l = eval_expr_scope(left, scope)?;
2274 let r = eval_expr_scope(right, scope)?;
2275 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2276 return Ok(Value::Null);
2277 }
2278 Ok(Value::Text(format!(
2279 "{}{}",
2280 l.to_display_string(),
2281 r.to_display_string()
2282 )))
2283 }
2284 other => Err(SQLRiteError::NotImplemented(format!(
2285 "binary operator {other:?} is not supported yet"
2286 ))),
2287 },
2288
2289 Expr::IsNull(inner) => {
2297 let v = eval_expr_scope(inner, scope)?;
2298 Ok(Value::Bool(matches!(v, Value::Null)))
2299 }
2300 Expr::IsNotNull(inner) => {
2301 let v = eval_expr_scope(inner, scope)?;
2302 Ok(Value::Bool(!matches!(v, Value::Null)))
2303 }
2304
2305 Expr::Like {
2312 negated,
2313 any,
2314 expr: lhs,
2315 pattern,
2316 escape_char,
2317 } => eval_like(
2318 scope,
2319 *negated,
2320 *any,
2321 lhs,
2322 pattern,
2323 escape_char.as_ref(),
2324 true,
2325 ),
2326 Expr::ILike {
2327 negated,
2328 any,
2329 expr: lhs,
2330 pattern,
2331 escape_char,
2332 } => eval_like(
2333 scope,
2334 *negated,
2335 *any,
2336 lhs,
2337 pattern,
2338 escape_char.as_ref(),
2339 true,
2340 ),
2341
2342 Expr::InList {
2348 expr: lhs,
2349 list,
2350 negated,
2351 } => eval_in_list(scope, lhs, list, *negated),
2352 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2353 "IN (subquery) is not supported (only literal lists are)".to_string(),
2354 )),
2355
2356 Expr::Function(func) => eval_function(func, scope),
2367
2368 other => Err(SQLRiteError::NotImplemented(format!(
2369 "unsupported expression in WHERE/projection: {other:?}"
2370 ))),
2371 }
2372}
2373
/// Evaluates a scalar function call against the current row scope.
///
/// The function name must be a single unqualified identifier. It is lowercased
/// before dispatch, so matching is case-insensitive.
///
/// Handled families:
/// * vector distances (`vec_distance_l2`, `vec_distance_cosine`,
///   `vec_distance_dot`), returned as REAL;
/// * JSON helpers (`json_extract`, `json_type`, `json_array_length`,
///   `json_object_keys`);
/// * FTS helpers (`fts_match` returning BOOL, `bm25_score` returning REAL).
///
/// Aggregate names and unknown names are rejected with `NotImplemented`.
fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
    // Qualified names such as `schema.fn(...)` are not supported.
    let name = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => {
            return Err(SQLRiteError::NotImplemented(format!(
                "qualified function names not supported: {:?}",
                func.name
            )));
        }
    };

    match name.as_str() {
        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
            // Both arguments must evaluate to vectors of equal length.
            let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
            let dist = match name.as_str() {
                "vec_distance_l2" => vec_distance_l2(&a, &b),
                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
                "vec_distance_dot" => vec_distance_dot(&a, &b),
                _ => unreachable!(),
            };
            // Distances are computed in f32 and widened to the engine's REAL (f64).
            Ok(Value::Real(dist as f64))
        }
        "json_extract" => json_fn_extract(&name, &func.args, scope),
        "json_type" => json_fn_type(&name, &func.args, scope),
        "json_array_length" => json_fn_array_length(&name, &func.args, scope),
        "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
        "fts_match" | "bm25_score" => {
            // FTS helpers need one concrete (table, rowid) pair to consult the
            // table's FTS index. Scopes that cannot provide one (joins, per the
            // error text) are rejected.
            let Some((table, rowid)) = scope.single_table_view() else {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{name}() is not yet supported inside a JOIN query — \
                     use it on a single-table SELECT or move the FTS lookup into a subquery"
                )));
            };
            let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
            Ok(match name.as_str() {
                "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
                "bm25_score" => {
                    Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
                }
                _ => unreachable!(),
            })
        }
        // Aggregates are only valid as top-level projection items, which are
        // handled elsewhere.
        "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
            "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
             use it as a top-level projection item (HAVING is not yet supported)"
        ))),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unknown function: {other}(...)"
        ))),
    }
}
2453
/// Validates the `(column, query_text)` arguments of `fts_match` / `bm25_score`
/// and locates the FTS index they refer to.
///
/// * Argument 0 must be a column reference. For `t.col`, only the last part is
///   used.
/// * Argument 1 may be any expression that evaluates (in `scope`) to TEXT.
///
/// Returns the matching FTS index entry on `table` together with the query
/// string.
///
/// # Errors
/// * Wrong argument count or shape.
/// * A non-TEXT query value.
/// * No FTS index on the named column.
fn resolve_fts_args<'t>(
    fn_name: &str,
    args: &FunctionArguments,
    table: &'t Table,
    scope: &dyn RowScope,
) -> Result<(&'t FtsIndexEntry, String)> {
    let arg_list = match args {
        FunctionArguments::List(l) => &l.args,
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects exactly two arguments: (column, query_text)"
            )));
        }
    };
    if arg_list.len() != 2 {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly 2 arguments, got {}",
            arg_list.len()
        )));
    }

    // Argument 0: the indexed column. Only plain positional expressions are accepted.
    let col_expr = match &arg_list[0] {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 0 must be a column name, got {other:?}"
            )));
        }
    };
    let col_name = match col_expr {
        Expr::Identifier(ident) => ident.value.clone(),
        // The qualifier is dropped: the lookup happens on the single `table` given.
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 must be a column reference, got {other:?}"
            )));
        }
    };

    // Argument 1: the query text. It is evaluated, so it may be a literal or a
    // computed expression, but it must yield TEXT.
    let q_expr = match &arg_list[1] {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 1 must be a text expression, got {other:?}"
            )));
        }
    };
    let query = match eval_expr_scope(q_expr, scope)? {
        Value::Text(s) => s,
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 1 must be TEXT, got {}",
                other.to_display_string()
            )));
        }
    };

    // Exact (case-sensitive) column-name match against the table's FTS indexes.
    let entry = table
        .fts_indexes
        .iter()
        .find(|e| e.column_name == col_name)
        .ok_or_else(|| {
            SQLRiteError::General(format!(
                "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
                 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
            ))
        })?;
    Ok((entry, query))
}
2536
/// Evaluates the `(json_text [, path])` arguments shared by the `json_*` functions.
///
/// * Argument 0 must evaluate to TEXT (the JSON document). NULL and other types
///   are errors.
/// * The optional argument 1 must evaluate to TEXT. It defaults to `"$"` (the
///   document root) when omitted.
///
/// Returns `(json_text, path)`. The JSON is not parsed here.
fn extract_json_and_path(
    fn_name: &str,
    args: &FunctionArguments,
    scope: &dyn RowScope,
) -> Result<(String, String)> {
    let arg_list = match args {
        FunctionArguments::List(l) => &l.args,
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects 1 or 2 arguments"
            )));
        }
    };
    if !(arg_list.len() == 1 || arg_list.len() == 2) {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects 1 or 2 arguments, got {}",
            arg_list.len()
        )));
    }
    // Only plain positional expressions are accepted (no named / wildcard args).
    let first_expr = match &arg_list[0] {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 0 has unsupported shape: {other:?}"
            )));
        }
    };
    // NULL input is an error rather than propagating NULL.
    let json_text = match eval_expr_scope(first_expr, scope)? {
        Value::Text(s) => s,
        Value::Null => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() called on NULL — JSON column has no value for this row"
            )));
        }
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 is not JSON-typed: got {}",
                other.to_display_string()
            )));
        }
    };

    let path = if arg_list.len() == 2 {
        let path_expr = match &arg_list[1] {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
            other => {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{fn_name}() argument 1 has unsupported shape: {other:?}"
                )));
            }
        };
        // The path is evaluated, so any TEXT-valued expression is accepted,
        // despite the wording of the error below.
        match eval_expr_scope(path_expr, scope)? {
            Value::Text(s) => s,
            other => {
                return Err(SQLRiteError::General(format!(
                    "{fn_name}() path argument must be a string literal, got {}",
                    other.to_display_string()
                )));
            }
        }
    } else {
        "$".to_string()
    };

    Ok((json_text, path))
}
2618
/// Resolves a simplified JSONPath against `value`.
///
/// Grammar: `$` followed by zero or more steps.
/// * `.key` is an object member; the key runs until the next `.` or `[`, so
///   quoted keys are not supported.
/// * `[n]` is an array index; `n` is a non-negative integer and surrounding
///   whitespace inside the brackets is trimmed.
///
/// Returns `Ok(None)` when a step does not exist: a missing key, an
/// out-of-range index, or a step applied to the wrong JSON type.
///
/// # Errors
/// Returns an error for syntactically malformed paths:
/// * a path that does not start with `$`;
/// * an empty key after `.`;
/// * an unclosed `[`;
/// * a non-integer index;
/// * any other unexpected character.
fn walk_json_path<'a>(
    value: &'a serde_json::Value,
    path: &str,
) -> Result<Option<&'a serde_json::Value>> {
    let mut chars = path.chars().peekable();
    if chars.next() != Some('$') {
        return Err(SQLRiteError::General(format!(
            "JSON path must start with '$', got `{path}`"
        )));
    }
    let mut current = value;
    while let Some(&c) = chars.peek() {
        match c {
            '.' => {
                chars.next();
                // Collect the member name up to the next step delimiter.
                let mut key = String::new();
                while let Some(&c) = chars.peek() {
                    if c == '.' || c == '[' {
                        break;
                    }
                    key.push(c);
                    chars.next();
                }
                if key.is_empty() {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has empty key after '.' in `{path}`"
                    )));
                }
                // `get(&str)` yields None for missing keys and for non-objects.
                match current.get(&key) {
                    Some(v) => current = v,
                    None => return Ok(None),
                }
            }
            '[' => {
                chars.next();
                let mut idx_str = String::new();
                while let Some(&c) = chars.peek() {
                    if c == ']' {
                        break;
                    }
                    idx_str.push(c);
                    chars.next();
                }
                // Consume the closing bracket; running out of input means it was never closed.
                if chars.next() != Some(']') {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has unclosed `[` in `{path}`"
                    )));
                }
                let idx: usize = idx_str.trim().parse().map_err(|_| {
                    SQLRiteError::General(format!(
                        "JSON path has non-integer index `[{idx_str}]` in `{path}`"
                    ))
                })?;
                // `get(usize)` yields None when out of range and for non-arrays.
                match current.get(idx) {
                    Some(v) => current = v,
                    None => return Ok(None),
                }
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "JSON path has unexpected character `{other}` in `{path}` \
                     (expected `.`, `[`, or end-of-path)"
                )));
            }
        }
    }
    Ok(Some(current))
}
2696
2697fn json_value_to_sql(v: &serde_json::Value) -> Value {
2701 match v {
2702 serde_json::Value::Null => Value::Null,
2703 serde_json::Value::Bool(b) => Value::Bool(*b),
2704 serde_json::Value::Number(n) => {
2705 if let Some(i) = n.as_i64() {
2707 Value::Integer(i)
2708 } else if let Some(f) = n.as_f64() {
2709 Value::Real(f)
2710 } else {
2711 Value::Null
2712 }
2713 }
2714 serde_json::Value::String(s) => Value::Text(s.clone()),
2715 composite => Value::Text(composite.to_string()),
2719 }
2720}
2721
2722fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2723 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2724 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2725 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2726 })?;
2727 match walk_json_path(&parsed, &path)? {
2728 Some(v) => Ok(json_value_to_sql(v)),
2729 None => Ok(Value::Null),
2730 }
2731}
2732
2733fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2734 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2735 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2736 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2737 })?;
2738 let resolved = match walk_json_path(&parsed, &path)? {
2739 Some(v) => v,
2740 None => return Ok(Value::Null),
2741 };
2742 let ty = match resolved {
2743 serde_json::Value::Null => "null",
2744 serde_json::Value::Bool(true) => "true",
2745 serde_json::Value::Bool(false) => "false",
2746 serde_json::Value::Number(n) => {
2747 if n.is_i64() || n.is_u64() {
2748 "integer"
2749 } else {
2750 "real"
2751 }
2752 }
2753 serde_json::Value::String(_) => "text",
2754 serde_json::Value::Array(_) => "array",
2755 serde_json::Value::Object(_) => "object",
2756 };
2757 Ok(Value::Text(ty.to_string()))
2758}
2759
2760fn json_fn_array_length(
2761 name: &str,
2762 args: &FunctionArguments,
2763 scope: &dyn RowScope,
2764) -> Result<Value> {
2765 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2766 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2767 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2768 })?;
2769 let resolved = match walk_json_path(&parsed, &path)? {
2770 Some(v) => v,
2771 None => return Ok(Value::Null),
2772 };
2773 match resolved.as_array() {
2774 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2775 None => Err(SQLRiteError::General(format!(
2776 "{name}() resolved to a non-array value at path `{path}`"
2777 ))),
2778 }
2779}
2780
2781fn json_fn_object_keys(
2782 name: &str,
2783 args: &FunctionArguments,
2784 scope: &dyn RowScope,
2785) -> Result<Value> {
2786 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2787 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2788 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2789 })?;
2790 let resolved = match walk_json_path(&parsed, &path)? {
2791 Some(v) => v,
2792 None => return Ok(Value::Null),
2793 };
2794 let obj = resolved.as_object().ok_or_else(|| {
2795 SQLRiteError::General(format!(
2796 "{name}() resolved to a non-object value at path `{path}`"
2797 ))
2798 })?;
2799 let keys: Vec<serde_json::Value> = obj
2806 .keys()
2807 .map(|k| serde_json::Value::String(k.clone()))
2808 .collect();
2809 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2810}
2811
2812fn extract_two_vector_args(
2816 fn_name: &str,
2817 args: &FunctionArguments,
2818 scope: &dyn RowScope,
2819) -> Result<(Vec<f32>, Vec<f32>)> {
2820 let arg_list = match args {
2821 FunctionArguments::List(l) => &l.args,
2822 _ => {
2823 return Err(SQLRiteError::General(format!(
2824 "{fn_name}() expects exactly two vector arguments"
2825 )));
2826 }
2827 };
2828 if arg_list.len() != 2 {
2829 return Err(SQLRiteError::General(format!(
2830 "{fn_name}() expects exactly 2 arguments, got {}",
2831 arg_list.len()
2832 )));
2833 }
2834 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2835 for (i, arg) in arg_list.iter().enumerate() {
2836 let expr = match arg {
2837 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2838 other => {
2839 return Err(SQLRiteError::NotImplemented(format!(
2840 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2841 )));
2842 }
2843 };
2844 let val = eval_expr_scope(expr, scope)?;
2845 match val {
2846 Value::Vector(v) => out.push(v),
2847 other => {
2848 return Err(SQLRiteError::General(format!(
2849 "{fn_name}() argument {i} is not a vector: got {}",
2850 other.to_display_string()
2851 )));
2852 }
2853 }
2854 }
2855 let b = out.pop().unwrap();
2856 let a = out.pop().unwrap();
2857 if a.len() != b.len() {
2858 return Err(SQLRiteError::General(format!(
2859 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2860 a.len(),
2861 b.len()
2862 )));
2863 }
2864 Ok((a, b))
2865}
2866
/// Euclidean (L2) distance `sqrt(Σ (aᵢ - bᵢ)²)` between two vectors.
///
/// Callers are expected to pass equal-length slices; this is checked only in
/// debug builds. In release builds a `b` shorter than `a` panics.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slicing keeps the panic-on-short-`b` behaviour and lets the zip below
    // run without per-element bounds checks.
    let b = &b[..a.len()];
    let squared_sum = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    squared_sum.sqrt()
}
2878
2879pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2889 debug_assert_eq!(a.len(), b.len());
2890 let mut dot = 0.0f32;
2891 let mut norm_a_sq = 0.0f32;
2892 let mut norm_b_sq = 0.0f32;
2893 for i in 0..a.len() {
2894 dot += a[i] * b[i];
2895 norm_a_sq += a[i] * a[i];
2896 norm_b_sq += b[i] * b[i];
2897 }
2898 let denom = (norm_a_sq * norm_b_sq).sqrt();
2899 if denom == 0.0 {
2900 return Err(SQLRiteError::General(
2901 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2902 ));
2903 }
2904 Ok(1.0 - dot / denom)
2905}
2906
/// Negated dot product `-(a·b)`, usable as a "distance" where smaller means
/// more similar.
///
/// Callers are expected to pass equal-length slices; this is checked only in
/// debug builds. In release builds a `b` shorter than `a` panics. Empty input
/// yields `-0.0`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let similarity = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -similarity
}
2918
2919fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2922 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2923 return Ok(Value::Null);
2924 }
2925 match (l, r) {
2926 (Value::Integer(a), Value::Integer(b)) => match op {
2927 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2928 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2929 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2930 BinaryOperator::Divide => {
2931 if *b == 0 {
2932 Err(SQLRiteError::General("division by zero".to_string()))
2933 } else {
2934 Ok(Value::Integer(a / b))
2935 }
2936 }
2937 BinaryOperator::Modulo => {
2938 if *b == 0 {
2939 Err(SQLRiteError::General("modulo by zero".to_string()))
2940 } else {
2941 Ok(Value::Integer(a % b))
2942 }
2943 }
2944 _ => unreachable!(),
2945 },
2946 (a, b) => {
2948 let af = as_number(a)?;
2949 let bf = as_number(b)?;
2950 match op {
2951 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2952 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2953 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2954 BinaryOperator::Divide => {
2955 if bf == 0.0 {
2956 Err(SQLRiteError::General("division by zero".to_string()))
2957 } else {
2958 Ok(Value::Real(af / bf))
2959 }
2960 }
2961 BinaryOperator::Modulo => {
2962 if bf == 0.0 {
2963 Err(SQLRiteError::General("modulo by zero".to_string()))
2964 } else {
2965 Ok(Value::Real(af % bf))
2966 }
2967 }
2968 _ => unreachable!(),
2969 }
2970 }
2971 }
2972}
2973
2974fn as_number(v: &Value) -> Result<f64> {
2975 match v {
2976 Value::Integer(i) => Ok(*i as f64),
2977 Value::Real(f) => Ok(*f),
2978 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2979 other => Err(SQLRiteError::General(format!(
2980 "arithmetic on non-numeric value '{}'",
2981 other.to_display_string()
2982 ))),
2983 }
2984}
2985
2986fn as_bool(v: &Value) -> Result<bool> {
2987 match v {
2988 Value::Bool(b) => Ok(*b),
2989 Value::Null => Ok(false),
2990 Value::Integer(i) => Ok(*i != 0),
2991 other => Err(SQLRiteError::Internal(format!(
2992 "expected boolean, got {}",
2993 other.to_display_string()
2994 ))),
2995 }
2996}
2997
2998#[allow(clippy::too_many_arguments)]
3003fn eval_like(
3004 scope: &dyn RowScope,
3005 negated: bool,
3006 any: bool,
3007 lhs: &Expr,
3008 pattern: &Expr,
3009 escape_char: Option<&AstValue>,
3010 case_insensitive: bool,
3011) -> Result<Value> {
3012 if any {
3013 return Err(SQLRiteError::NotImplemented(
3014 "LIKE ANY (...) is not supported".to_string(),
3015 ));
3016 }
3017 if escape_char.is_some() {
3018 return Err(SQLRiteError::NotImplemented(
3019 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3020 ));
3021 }
3022
3023 let l = eval_expr_scope(lhs, scope)?;
3024 let p = eval_expr_scope(pattern, scope)?;
3025 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3026 return Ok(Value::Null);
3027 }
3028 let text = match l {
3029 Value::Text(s) => s,
3030 other => other.to_display_string(),
3031 };
3032 let pat = match p {
3033 Value::Text(s) => s,
3034 other => other.to_display_string(),
3035 };
3036 let m = like_match(&text, &pat, case_insensitive);
3037 Ok(Value::Bool(if negated { !m } else { m }))
3038}
3039
3040fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3041 let l = eval_expr_scope(lhs, scope)?;
3042 if matches!(l, Value::Null) {
3043 return Ok(Value::Null);
3044 }
3045 let mut saw_null = false;
3046 for item in list {
3047 let r = eval_expr_scope(item, scope)?;
3048 if matches!(r, Value::Null) {
3049 saw_null = true;
3050 continue;
3051 }
3052 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3053 return Ok(Value::Bool(!negated));
3054 }
3055 }
3056 if saw_null {
3057 Ok(Value::Null)
3060 } else {
3061 Ok(Value::Bool(negated))
3062 }
3063}
3064
3065fn aggregate_rows(
3076 table: &Table,
3077 matching: &[i64],
3078 group_by: &[String],
3079 proj_items: &[ProjectionItem],
3080) -> Result<Vec<Vec<Value>>> {
3081 let template: Vec<Option<AggState>> = proj_items
3085 .iter()
3086 .map(|i| match &i.kind {
3087 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3088 ProjectionKind::Column { .. } => None,
3089 })
3090 .collect();
3091
3092 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3098 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3099 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3100
3101 for &rowid in matching {
3102 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3103 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3104 for col in group_by {
3105 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
3106 key.push(DistinctKey::from_value(&v));
3107 key_values.push(v);
3108 }
3109 let idx = match keys.iter().position(|k| k == &key) {
3110 Some(i) => i,
3111 None => {
3112 keys.push(key);
3113 group_states.push(template.clone());
3114 group_key_values.push(key_values);
3115 keys.len() - 1
3116 }
3117 };
3118
3119 for (slot, item) in proj_items.iter().enumerate() {
3120 if let ProjectionKind::Aggregate(call) = &item.kind {
3121 let v = match &call.arg {
3122 AggregateArg::Star => Value::Null,
3123 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3124 };
3125 if let Some(state) = group_states[idx][slot].as_mut() {
3126 state.update(&v)?;
3127 }
3128 }
3129 }
3130 }
3131
3132 if keys.is_empty() && group_by.is_empty() {
3138 keys.push(Vec::new());
3141 group_states.push(template.clone());
3142 group_key_values.push(Vec::new());
3143 }
3144
3145 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3147 for (group_idx, _) in keys.iter().enumerate() {
3148 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3149 for (slot, item) in proj_items.iter().enumerate() {
3150 match &item.kind {
3151 ProjectionKind::Column { name: c, .. } => {
3152 let pos = group_by
3155 .iter()
3156 .position(|g| g == c)
3157 .expect("validated to be in GROUP BY");
3158 row.push(group_key_values[group_idx][pos].clone());
3159 }
3160 ProjectionKind::Aggregate(_) => {
3161 let state = group_states[group_idx][slot]
3162 .as_ref()
3163 .expect("aggregate slot has state");
3164 row.push(state.finalize());
3165 }
3166 }
3167 }
3168 rows.push(row);
3169 }
3170 Ok(rows)
3171}
3172
3173fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3177 use std::collections::HashSet;
3178 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3179 let mut out = Vec::with_capacity(rows.len());
3180 for row in rows {
3181 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3182 if seen.insert(key) {
3183 out.push(row);
3184 }
3185 }
3186 out
3187}
3188
3189fn sort_output_rows(
3193 rows: &mut [Vec<Value>],
3194 columns: &[String],
3195 proj_items: &[ProjectionItem],
3196 order: &OrderByClause,
3197) -> Result<()> {
3198 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3199 rows.sort_by(|a, b| {
3200 let va = &a[target_idx];
3201 let vb = &b[target_idx];
3202 let ord = compare_values(Some(va), Some(vb));
3203 if order.ascending { ord } else { ord.reverse() }
3204 });
3205 Ok(())
3206}
3207
3208fn resolve_order_by_index(
3211 expr: &Expr,
3212 columns: &[String],
3213 proj_items: &[ProjectionItem],
3214) -> Result<usize> {
3215 let target_name: Option<String> = match expr {
3217 Expr::Identifier(ident) => Some(ident.value.clone()),
3218 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3219 Expr::Function(_) => None,
3220 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3221 other => {
3222 return Err(SQLRiteError::NotImplemented(format!(
3223 "ORDER BY expression not supported on aggregating queries: {other:?}"
3224 )));
3225 }
3226 };
3227 if let Some(name) = target_name {
3228 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3229 return Ok(i);
3230 }
3231 return Err(SQLRiteError::Internal(format!(
3232 "ORDER BY references unknown column '{name}' in the SELECT output"
3233 )));
3234 }
3235 if let Expr::Function(func) = expr {
3239 let user_disp = format_function_display(func);
3240 for (i, item) in proj_items.iter().enumerate() {
3241 if let ProjectionKind::Aggregate(call) = &item.kind
3242 && call.display_name().eq_ignore_ascii_case(&user_disp)
3243 {
3244 return Ok(i);
3245 }
3246 }
3247 return Err(SQLRiteError::Internal(format!(
3248 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3249 )));
3250 }
3251 Err(SQLRiteError::Internal(
3252 "ORDER BY expression could not be resolved against the output columns".to_string(),
3253 ))
3254}
3255
3256fn format_function_display(func: &sqlparser::ast::Function) -> String {
3260 let name = match func.name.0.as_slice() {
3261 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3262 _ => format!("{:?}", func.name).to_uppercase(),
3263 };
3264 let inner = match &func.args {
3265 FunctionArguments::List(l) => {
3266 let distinct = matches!(
3267 l.duplicate_treatment,
3268 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3269 );
3270 let arg = l.args.first().map(|a| match a {
3271 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3272 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3273 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3274 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3275 }
3276 _ => String::new(),
3277 });
3278 match (distinct, arg) {
3279 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3280 (_, Some(a)) => a,
3281 _ => String::new(),
3282 }
3283 }
3284 _ => String::new(),
3285 };
3286 format!("{name}({inner})")
3287}
3288
3289fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3290 use sqlparser::ast::Value as AstValue;
3291 match v {
3292 AstValue::Number(n, _) => {
3293 if let Ok(i) = n.parse::<i64>() {
3294 Ok(Value::Integer(i))
3295 } else if let Ok(f) = n.parse::<f64>() {
3296 Ok(Value::Real(f))
3297 } else {
3298 Err(SQLRiteError::Internal(format!(
3299 "could not parse numeric literal '{n}'"
3300 )))
3301 }
3302 }
3303 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3304 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3305 AstValue::Null => Ok(Value::Null),
3306 other => Err(SQLRiteError::NotImplemented(format!(
3307 "unsupported literal value: {other:?}"
3308 ))),
3309 }
3310}
3311
3312#[cfg(test)]
3313mod tests {
3314 use super::*;
3315
3316 fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
3323 (a - b).abs() < eps
3324 }
3325
3326 #[test]
3327 fn vec_distance_l2_identical_is_zero() {
3328 let v = vec![0.1, 0.2, 0.3];
3329 assert_eq!(vec_distance_l2(&v, &v), 0.0);
3330 }
3331
3332 #[test]
3333 fn vec_distance_l2_unit_basis_is_sqrt2() {
3334 let a = vec![1.0, 0.0];
3336 let b = vec![0.0, 1.0];
3337 assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
3338 }
3339
3340 #[test]
3341 fn vec_distance_l2_known_value() {
3342 let a = vec![0.0, 0.0, 0.0];
3344 let b = vec![3.0, 4.0, 0.0];
3345 assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
3346 }
3347
3348 #[test]
3349 fn vec_distance_cosine_identical_is_zero() {
3350 let v = vec![0.1, 0.2, 0.3];
3351 let d = vec_distance_cosine(&v, &v).unwrap();
3352 assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
3353 }
3354
3355 #[test]
3356 fn vec_distance_cosine_orthogonal_is_one() {
3357 let a = vec![1.0, 0.0];
3360 let b = vec![0.0, 1.0];
3361 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
3362 }
3363
3364 #[test]
3365 fn vec_distance_cosine_opposite_is_two() {
3366 let a = vec![1.0, 0.0, 0.0];
3368 let b = vec![-1.0, 0.0, 0.0];
3369 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
3370 }
3371
3372 #[test]
3373 fn vec_distance_cosine_zero_magnitude_errors() {
3374 let a = vec![0.0, 0.0];
3376 let b = vec![1.0, 0.0];
3377 let err = vec_distance_cosine(&a, &b).unwrap_err();
3378 assert!(format!("{err}").contains("zero-magnitude"));
3379 }
3380
3381 #[test]
3382 fn vec_distance_dot_negates() {
3383 let a = vec![1.0, 2.0, 3.0];
3385 let b = vec![4.0, 5.0, 6.0];
3386 assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
3387 }
3388
3389 #[test]
3390 fn vec_distance_dot_orthogonal_is_zero() {
3391 let a = vec![1.0, 0.0];
3393 let b = vec![0.0, 1.0];
3394 assert_eq!(vec_distance_dot(&a, &b), 0.0);
3395 }
3396
3397 #[test]
3398 fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
3399 let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
3405 let cos = vec_distance_cosine(&a, &b).unwrap();
3406 assert!(approx_eq(dot, cos - 1.0, 1e-5));
3407 }
3408
3409 use crate::sql::db::database::Database;
3414 use crate::sql::dialect::SqlriteDialect;
3415 use crate::sql::parser::select::SelectQuery;
3416 use sqlparser::parser::Parser;
3417
3418 fn seed_score_table(n: usize) -> Database {
3431 let mut db = Database::new("tempdb".to_string());
3432 crate::sql::process_command(
3433 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3434 &mut db,
3435 )
3436 .expect("create");
3437 for i in 0..n {
3438 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3442 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3443 crate::sql::process_command(&sql, &mut db).expect("insert");
3444 }
3445 db
3446 }
3447
3448 fn parse_select(sql: &str) -> SelectQuery {
3452 let dialect = SqlriteDialect::new();
3453 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3454 let stmt = ast.pop().expect("one statement");
3455 SelectQuery::new(&stmt).expect("select-query")
3456 }
3457
3458 #[test]
3459 fn topk_matches_full_sort_asc() {
3460 let db = seed_score_table(200);
3463 let table = db.get_table("docs".to_string()).unwrap();
3464 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3465 let order = q.order_by.as_ref().unwrap();
3466 let all_rowids = table.rowids();
3467
3468 let mut full = all_rowids.clone();
3470 sort_rowids(&mut full, table, order).unwrap();
3471 full.truncate(10);
3472
3473 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3475
3476 assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
3477 }
3478
3479 #[test]
3480 fn topk_matches_full_sort_desc() {
3481 let db = seed_score_table(200);
3483 let table = db.get_table("docs".to_string()).unwrap();
3484 let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
3485 let order = q.order_by.as_ref().unwrap();
3486 let all_rowids = table.rowids();
3487
3488 let mut full = all_rowids.clone();
3489 sort_rowids(&mut full, table, order).unwrap();
3490 full.truncate(10);
3491
3492 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3493
3494 assert_eq!(
3495 topk, full,
3496 "top-k DESC via heap should match full-sort+truncate"
3497 );
3498 }
3499
3500 #[test]
3501 fn topk_k_larger_than_n_returns_everything_sorted() {
3502 let db = seed_score_table(50);
3507 let table = db.get_table("docs".to_string()).unwrap();
3508 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
3509 let order = q.order_by.as_ref().unwrap();
3510 let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
3511 assert_eq!(topk.len(), 50);
3512 let scores: Vec<f64> = topk
3514 .iter()
3515 .filter_map(|r| match table.get_value("score", *r) {
3516 Some(Value::Real(f)) => Some(f),
3517 _ => None,
3518 })
3519 .collect();
3520 assert!(scores.windows(2).all(|w| w[0] <= w[1]));
3521 }
3522
3523 #[test]
3524 fn topk_k_zero_returns_empty() {
3525 let db = seed_score_table(10);
3526 let table = db.get_table("docs".to_string()).unwrap();
3527 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
3528 let order = q.order_by.as_ref().unwrap();
3529 let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
3530 assert!(topk.is_empty());
3531 }
3532
3533 #[test]
3534 fn topk_empty_input_returns_empty() {
3535 let db = seed_score_table(0);
3536 let table = db.get_table("docs".to_string()).unwrap();
3537 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
3538 let order = q.order_by.as_ref().unwrap();
3539 let topk = select_topk(&[], table, order, 5).unwrap();
3540 assert!(topk.is_empty());
3541 }
3542
3543 #[test]
3544 fn topk_works_through_select_executor_with_distance_function() {
3545 let mut db = Database::new("tempdb".to_string());
3549 crate::sql::process_command(
3550 "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
3551 &mut db,
3552 )
3553 .unwrap();
3554 for v in &[
3561 "[1.0, 0.0]",
3562 "[2.0, 0.0]",
3563 "[0.0, 3.0]",
3564 "[1.0, 4.0]",
3565 "[10.0, 10.0]",
3566 ] {
3567 crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
3568 .unwrap();
3569 }
3570 let resp = crate::sql::process_command(
3571 "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
3572 &mut db,
3573 )
3574 .unwrap();
3575 assert!(resp.contains("3 rows returned"), "got: {resp}");
3578 }
3579
3580 #[test]
3603 #[ignore]
3604 fn topk_benchmark() {
3605 use std::time::Instant;
3606 const N: usize = 10_000;
3607 const K: usize = 10;
3608
3609 let db = seed_score_table(N);
3610 let table = db.get_table("docs".to_string()).unwrap();
3611 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3612 let order = q.order_by.as_ref().unwrap();
3613 let all_rowids = table.rowids();
3614
3615 let t0 = Instant::now();
3617 let _topk = select_topk(&all_rowids, table, order, K).unwrap();
3618 let heap_dur = t0.elapsed();
3619
3620 let t1 = Instant::now();
3622 let mut full = all_rowids.clone();
3623 sort_rowids(&mut full, table, order).unwrap();
3624 full.truncate(K);
3625 let sort_dur = t1.elapsed();
3626
3627 let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
3628 println!("\n--- topk_benchmark (N={N}, k={K}) ---");
3629 println!(" bounded heap: {heap_dur:?}");
3630 println!(" full sort+trunc: {sort_dur:?}");
3631 println!(" speedup ratio: {ratio:.2}×");
3632
3633 assert!(
3640 ratio > 1.4,
3641 "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
3642 );
3643 }
3644
3645 fn run_select(db: &mut Database, sql: &str) -> String {
3653 crate::sql::process_command(sql, db).expect("select")
3654 }
3655
3656 #[test]
3657 fn where_is_null_returns_null_rows() {
3658 let mut db = Database::new("t".to_string());
3659 crate::sql::process_command(
3660 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3661 &mut db,
3662 )
3663 .unwrap();
3664 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3665 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3666 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3667 crate::sql::process_command("INSERT INTO t (id, n) VALUES (4, NULL);", &mut db).unwrap();
3668
3669 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
3670 assert!(
3671 response.contains("2 rows returned"),
3672 "IS NULL should return 2 rows, got: {response}"
3673 );
3674 }
3675
3676 #[test]
3677 fn where_is_not_null_returns_non_null_rows() {
3678 let mut db = Database::new("t".to_string());
3679 crate::sql::process_command(
3680 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3681 &mut db,
3682 )
3683 .unwrap();
3684 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3685 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3686 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3687
3688 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
3689 assert!(
3690 response.contains("2 rows returned"),
3691 "IS NOT NULL should return 2 rows, got: {response}"
3692 );
3693 }
3694
3695 #[test]
3696 fn where_is_null_on_indexed_column() {
3697 let mut db = Database::new("t".to_string());
3702 crate::sql::process_command(
3703 "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
3704 &mut db,
3705 )
3706 .unwrap();
3707 crate::sql::process_command("INSERT INTO t (id, name) VALUES (1, 'alice');", &mut db)
3708 .unwrap();
3709 crate::sql::process_command("INSERT INTO t (id, name) VALUES (2, NULL);", &mut db).unwrap();
3710 crate::sql::process_command("INSERT INTO t (id, name) VALUES (3, 'bob');", &mut db)
3711 .unwrap();
3712
3713 let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
3714 assert!(
3715 null_rows.contains("1 row returned"),
3716 "indexed IS NULL should return 1 row, got: {null_rows}"
3717 );
3718 let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
3719 assert!(
3720 not_null_rows.contains("2 rows returned"),
3721 "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
3722 );
3723 }
3724
3725 #[test]
3726 fn where_is_null_works_on_omitted_column() {
3727 let mut db = Database::new("t".to_string());
3731 crate::sql::process_command(
3732 "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
3733 &mut db,
3734 )
3735 .unwrap();
3736 crate::sql::process_command(
3737 "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
3738 &mut db,
3739 )
3740 .unwrap();
3741 crate::sql::process_command("INSERT INTO t (id, label) VALUES (2, 'b');", &mut db).unwrap();
3743
3744 let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
3745 assert!(
3746 response.contains("1 row returned"),
3747 "IS NULL should match the omitted-column row, got: {response}"
3748 );
3749 }
3750
3751 #[test]
3752 fn where_is_null_combines_with_and_or() {
3753 let mut db = Database::new("t".to_string());
3757 crate::sql::process_command(
3758 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3759 &mut db,
3760 )
3761 .unwrap();
3762 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, NULL);", &mut db).unwrap();
3763 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3764 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3765
3766 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
3767 assert!(
3768 response.contains("1 row returned"),
3769 "IS NULL combined with AND should match exactly row 2, got: {response}"
3770 );
3771 }
3772
3773 fn seed_employees() -> Database {
3779 let mut db = Database::new("t".to_string());
3780 crate::sql::process_command(
3781 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
3782 &mut db,
3783 )
3784 .unwrap();
3785 let rows = [
3786 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
3787 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
3788 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
3789 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
3790 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
3791 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
3792 ];
3793 for sql in rows {
3794 crate::sql::process_command(sql, &mut db).unwrap();
3795 }
3796 db
3797 }
3798
3799 fn run_rows(db: &Database, sql: &str) -> SelectResult {
3801 let q = parse_select(sql);
3802 execute_select_rows(q, db).expect("select")
3803 }
3804
3805 #[test]
3808 fn like_percent_prefix_case_insensitive() {
3809 let db = seed_employees();
3810 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
3811 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3813 assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
3814 assert!(names.contains(&"Alice".to_string()));
3815 assert!(names.contains(&"alex".to_string()));
3816 }
3817
3818 #[test]
3819 fn like_underscore_singlechar() {
3820 let db = seed_employees();
3821 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
3822 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3824 assert_eq!(names, vec!["Eve".to_string()]);
3825 }
3826
3827 #[test]
3828 fn not_like_excludes_match() {
3829 let db = seed_employees();
3830 let r = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
3831 assert_eq!(r.rows.len(), 4);
3833 }
3834
3835 #[test]
3836 fn like_with_null_excludes_row() {
3837 let db = seed_employees();
3838 let r = run_rows(
3840 &db,
3841 "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
3842 );
3843 assert_eq!(r.rows.len(), 1);
3844 assert_eq!(r.rows[0][0].to_display_string(), "Dave");
3845 }
3846
3847 #[test]
3850 fn in_list_positive() {
3851 let db = seed_employees();
3852 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
3853 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3854 assert_eq!(names.len(), 3);
3855 assert!(names.contains(&"Alice".to_string()));
3856 assert!(names.contains(&"Bob".to_string()));
3857 assert!(names.contains(&"Dave".to_string()));
3858 }
3859
3860 #[test]
3861 fn not_in_excludes_listed() {
3862 let db = seed_employees();
3863 let r = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
3864 assert_eq!(r.rows.len(), 4);
3866 }
3867
3868 #[test]
3869 fn in_list_with_null_three_valued() {
3870 let db = seed_employees();
3871 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
3874 assert_eq!(r.rows.len(), 1);
3875 assert_eq!(r.rows[0][0].to_display_string(), "Alice");
3876 }
3877
3878 #[test]
3881 fn distinct_single_column() {
3882 let db = seed_employees();
3883 let r = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
3884 assert_eq!(r.rows.len(), 3);
3886 }
3887
3888 #[test]
3889 fn distinct_multi_column_with_null() {
3890 let db = seed_employees();
3891 let r = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
3893 assert_eq!(r.rows.len(), 5);
3895 }
3896
3897 #[test]
3900 fn count_star_no_groupby() {
3901 let db = seed_employees();
3902 let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
3903 assert_eq!(r.rows.len(), 1);
3904 assert_eq!(r.rows[0][0], Value::Integer(6));
3905 }
3906
3907 #[test]
3908 fn count_col_skips_nulls() {
3909 let db = seed_employees();
3910 let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
3911 assert_eq!(r.rows[0][0], Value::Integer(5));
3913 }
3914
3915 #[test]
3916 fn count_distinct_dedupes_and_skips_nulls() {
3917 let db = seed_employees();
3918 let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
3919 assert_eq!(r.rows[0][0], Value::Integer(4));
3921 }
3922
3923 #[test]
3924 fn sum_int_stays_integer() {
3925 let db = seed_employees();
3926 let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
3927 assert_eq!(r.rows[0][0], Value::Integer(490));
3929 }
3930
3931 #[test]
3932 fn avg_returns_real() {
3933 let db = seed_employees();
3934 let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
3935 match &r.rows[0][0] {
3937 Value::Real(v) => assert!((v - 98.0).abs() < 1e-9),
3938 other => panic!("expected Real, got {other:?}"),
3939 }
3940 }
3941
3942 #[test]
3943 fn min_max_skip_nulls() {
3944 let db = seed_employees();
3945 let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
3946 assert_eq!(r.rows[0][0], Value::Integer(80));
3947 assert_eq!(r.rows[0][1], Value::Integer(120));
3948 }
3949
3950 #[test]
3951 fn aggregates_on_empty_table_emit_one_row() {
3952 let mut db = Database::new("t".to_string());
3953 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
3954 let r = run_rows(
3955 &db,
3956 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
3957 );
3958 assert_eq!(r.rows.len(), 1);
3959 assert_eq!(r.rows[0][0], Value::Integer(0));
3960 assert_eq!(r.rows[0][1], Value::Null);
3961 assert_eq!(r.rows[0][2], Value::Null);
3962 assert_eq!(r.rows[0][3], Value::Null);
3963 assert_eq!(r.rows[0][4], Value::Null);
3964 }
3965
3966 #[test]
3969 fn group_by_single_col_with_count() {
3970 let db = seed_employees();
3971 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
3972 assert_eq!(r.rows.len(), 3);
3973 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
3975 for row in &r.rows {
3976 let d = row[0].to_display_string();
3977 let c = match &row[1] {
3978 Value::Integer(i) => *i,
3979 v => panic!("expected Integer count, got {v:?}"),
3980 };
3981 by_dept.insert(d, c);
3982 }
3983 assert_eq!(by_dept["eng"], 3);
3984 assert_eq!(by_dept["sales"], 2);
3985 assert_eq!(by_dept["ops"], 1);
3986 }
3987
3988 #[test]
3989 fn group_by_with_where_filter() {
3990 let db = seed_employees();
3991 let r = run_rows(
3992 &db,
3993 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
3994 );
3995 let by: std::collections::HashMap<String, i64> = r
3998 .rows
3999 .iter()
4000 .map(|row| {
4001 (
4002 row[0].to_display_string(),
4003 match &row[1] {
4004 Value::Integer(i) => *i,
4005 v => panic!("expected Integer sum, got {v:?}"),
4006 },
4007 )
4008 })
4009 .collect();
4010 assert_eq!(by.len(), 2);
4011 assert_eq!(by["eng"], 320);
4012 assert_eq!(by["sales"], 90);
4013 }
4014
4015 #[test]
4016 fn group_by_without_aggregates_is_distinct() {
4017 let db = seed_employees();
4018 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
4019 assert_eq!(r.rows.len(), 3);
4020 }
4021
4022 #[test]
4023 fn order_by_count_desc() {
4024 let db = seed_employees();
4025 let r = run_rows(
4026 &db,
4027 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
4028 );
4029 assert_eq!(r.rows.len(), 2);
4030 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4032 assert_eq!(r.rows[0][1], Value::Integer(3));
4033 }
4034
4035 #[test]
4036 fn order_by_aggregate_call_form() {
4037 let db = seed_employees();
4038 let r = run_rows(
4040 &db,
4041 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
4042 );
4043 assert_eq!(r.rows.len(), 3);
4044 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4045 }
4046
4047 #[test]
4048 fn group_by_invalid_bare_column_errors() {
4049 let mut db = Database::new("t".to_string());
4051 crate::sql::process_command(
4052 "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);",
4053 &mut db,
4054 )
4055 .unwrap();
4056 let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
4057 assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
4058 }
4059
#[test]
fn aggregate_in_where_errors_friendly() {
    // WHERE filters individual rows before any grouping happens, so an
    // aggregate there has nothing to aggregate over and must be an error.
    let mut db = Database::new("t".to_string());
    for setup in ["CREATE TABLE t (x INTEGER);", "INSERT INTO t (x) VALUES (1);"] {
        crate::sql::process_command(setup, &mut db).unwrap();
    }
    let outcome = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
    assert!(outcome.is_err(), "aggregates must not be allowed in WHERE");
}
4068
4069 fn seed_join_fixture() -> Database {
4080 let mut db = Database::new("t".to_string());
4081 for sql in [
4082 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
4083 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
4084 "INSERT INTO customers (name) VALUES ('Alice');",
4085 "INSERT INTO customers (name) VALUES ('Bob');",
4086 "INSERT INTO customers (name) VALUES ('Carol');",
4087 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
4088 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
4089 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
4090 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
4091 ] {
4092 crate::sql::process_command(sql, &mut db).unwrap();
4093 }
4094 db
4095 }
4096
#[test]
fn inner_join_returns_only_matched_rows() {
    // Only orders whose customer_id matches a customer survive. Carol (no
    // orders) and the dangling customer_id 4 order are both dropped.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);

    let mut pairs: Vec<(String, i64)> = Vec::new();
    for row in &r.rows {
        let name = row[0].to_display_string();
        let amount = match &row[1] {
            Value::Integer(i) => *i,
            v => panic!("expected integer amount, got {v:?}"),
        };
        pairs.push((name, amount));
    }

    assert_eq!(pairs.len(), 3);
    for expected in [("Alice", 100), ("Alice", 200), ("Bob", 50)] {
        assert!(pairs.contains(&(expected.0.to_string(), expected.1)));
    }
}
4126
#[test]
fn bare_join_defaults_to_inner() {
    // A plain `JOIN` keyword must behave like `INNER JOIN`, producing the same
    // three matched rows as the explicit INNER JOIN test.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name FROM customers \
               JOIN orders ON customers.id = orders.customer_id;";
    let row_count = run_rows(&db, sql).rows.len();
    assert_eq!(row_count, 3, "JOIN without prefix should be INNER");
}
4137
#[test]
fn left_outer_join_preserves_unmatched_left() {
    // Alice contributes two rows and Bob one. Carol has no orders but must
    // still appear once, with the order side filled with NULL.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               LEFT OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);

    let carol_idx = r
        .rows
        .iter()
        .position(|row| row[0].to_display_string() == "Carol")
        .expect("Carol should appear with a NULL-padded right side");
    assert_eq!(r.rows[carol_idx][1], Value::Null);
}
4156
#[test]
fn right_outer_join_preserves_unmatched_right() {
    // The three matched orders plus the dangling order 999 (customer_id 4 has
    // no customer) give four rows; the dangling one has a NULL customer side.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);

    let orphan_idx = r
        .rows
        .iter()
        .position(|row| matches!(row[1], Value::Integer(999)))
        .expect("dangling order 999 should appear with a NULL-padded customer name");
    assert_eq!(r.rows[orphan_idx][0], Value::Null);
}
4176
#[test]
fn full_outer_join_preserves_both_sides() {
    // Expect three matched rows, plus Carol padded on the right and order 999
    // padded on the left.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               FULL OUTER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 5);

    let carol_without_order = |row: &Vec<Value>| {
        row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null)
    };
    let order_without_customer =
        |row: &Vec<Value>| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null);

    assert!(r.rows.iter().any(carol_without_order));
    assert!(r.rows.iter().any(order_without_customer));
}
4201
#[test]
fn join_with_table_aliases_resolves_qualifiers() {
    // `c.` / `o.` qualifiers must resolve through the AS aliases. The output
    // column headers are the bare column names, without the alias prefix.
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id;";
    let r = run_rows(&db, sql);
    let expected_columns = vec!["name".to_string(), "amount".to_string()];
    assert_eq!(r.rows.len(), 3);
    assert_eq!(r.columns, expected_columns);
}
4213
#[test]
fn join_with_where_filter_applies_after_join() {
    // WHERE runs over the joined rows: only Alice's 100 and 200 orders pass
    // `amount >= 100`. Bob's 50 is filtered out, and order 999 never joined.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id \
               WHERE orders.amount >= 100;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);
    let only_alice = r
        .rows
        .iter()
        .all(|row| row[0].to_display_string() == "Alice");
    assert!(only_alice);
}
4232
#[test]
fn left_join_with_where_on_right_side_is_not_inner() {
    // The WHERE is evaluated after the LEFT JOIN has NULL-padded Carol's row.
    // Filtering for a NULL amount must therefore find exactly that padded
    // row, not behave as if the join were INNER.
    let db = seed_join_fixture();
    let sql = "SELECT customers.name, orders.amount FROM customers \
               LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
               WHERE orders.amount IS NULL;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    let only_row = &r.rows[0];
    assert_eq!(only_row[0].to_display_string(), "Carol");
    assert_eq!(only_row[1], Value::Null);
}
4250
#[test]
fn select_star_over_join_emits_all_columns_from_both_tables() {
    // `*` expands to every column of `customers` followed by every column of
    // `orders`, in declaration order. The duplicate `id` header is kept as-is.
    let db = seed_join_fixture();
    let sql = "SELECT * FROM customers \
               INNER JOIN orders ON customers.id = orders.customer_id;";
    let r = run_rows(&db, sql);
    let expected: Vec<String> = ["id", "name", "id", "customer_id", "amount"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    assert_eq!(r.columns, expected);
    assert_eq!(r.rows.len(), 3);
}
4274
#[test]
fn join_order_by_sorts_full_joined_rows() {
    // ORDER BY on a right-hand column sorts the joined rows (default ASC).
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount;";
    let r = run_rows(&db, sql);

    let mut amounts: Vec<i64> = Vec::new();
    for row in &r.rows {
        match row[1] {
            Value::Integer(i) => amounts.push(i),
            ref v => panic!("expected integer, got {v:?}"),
        }
    }
    assert_eq!(amounts, vec![50, 100, 200]);
}
4294
#[test]
fn join_limit_truncates_after_join_and_sort() {
    // LIMIT must be applied last: join, then sort DESC, then keep the top two.
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               INNER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount DESC LIMIT 2;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);

    let mut amounts: Vec<i64> = Vec::with_capacity(r.rows.len());
    for row in &r.rows {
        let amount = match row[1] {
            Value::Integer(i) => i,
            ref v => panic!("expected integer, got {v:?}"),
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![200, 100]);
}
4316
#[test]
fn three_table_join_chains_correctly() {
    // a(1) -> b1 -> c1 is the only complete chain. a-two/b2 has no `c` row,
    // so the second INNER JOIN drops it.
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
        "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
        "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
    ];
    let mut db = Database::new("t".to_string());
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.label, b.tag, c.note FROM a \
               INNER JOIN b ON a.id = b.a_id \
               INNER JOIN c ON b.id = c.b_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 1);
    let row = &r.rows[0];
    assert_eq!(row[0].to_display_string(), "a-one");
    assert_eq!(row[1].to_display_string(), "b1");
    assert_eq!(row[2].to_display_string(), "c1");
}
4344
#[test]
fn ambiguous_unqualified_column_in_join_errors() {
    // Both `customers` and `orders` have an `id` column, so a bare `id`
    // cannot be resolved and execution must fail rather than pick one.
    let db = seed_join_fixture();
    let sql = "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;";
    let outcome = execute_select_rows(parse_select(sql), &db);
    assert!(outcome.is_err(), "unqualified ambiguous 'id' should error");
}
4357
#[test]
fn join_self_without_alias_is_rejected() {
    // Joining `n` to itself without aliases gives two scopes with the same
    // qualifier `n`. Every `n.` reference would then be ambiguous.
    let mut db = Database::new("t".to_string());
    let create = "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);";
    crate::sql::process_command(create, &mut db).unwrap();

    let query = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "self-join without an alias should error on duplicate qualifier"
    );
}
4373
#[test]
fn using_or_natural_join_returns_not_implemented() {
    // Only explicit `ON` join constraints are supported; `USING (...)` and
    // `NATURAL` joins must be rejected with an error.
    let mut db = Database::new("t".to_string());
    for ddl in [
        "CREATE TABLE a (id INTEGER PRIMARY KEY);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY);",
    ] {
        crate::sql::process_command(ddl, &mut db).unwrap();
    }

    let using = crate::sql::process_command("SELECT * FROM a INNER JOIN b USING (id);", &mut db);
    assert!(using.is_err(), "USING is not yet supported");

    let natural = crate::sql::process_command("SELECT * FROM a NATURAL JOIN b;", &mut db);
    assert!(natural.is_err(), "NATURAL is not supported");
}
4385
#[test]
fn aggregates_over_join_are_rejected() {
    // Aggregation over a joined row set is not supported yet, so the engine
    // must return an error for this query.
    //
    // The previous version built one fixture it never queried, ran the query
    // against a second throw-away fixture, then discarded the first with
    // `let _ = db;`. Use one mutable fixture instead.
    let mut db = seed_join_fixture();
    let err = crate::sql::process_command(
        "SELECT COUNT(*) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
        &mut db,
    );
    assert!(err.is_err(), "aggregates over JOIN are not yet supported");
}
4397
#[test]
fn left_join_with_no_matches_pads_every_row() {
    // No `a.x` (1, 2) equals any `b.y` (10). Every left row must still be
    // emitted, with the right-hand column NULL.
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
    ];
    let mut db = Database::new("t".to_string());
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let r = run_rows(&db, "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;");
    assert_eq!(r.rows.len(), 2);
    r.rows
        .iter()
        .for_each(|row| assert_eq!(row[1], Value::Null));
}
4420
#[test]
fn left_outer_join_order_by_places_nulls_first() {
    // Carol's NULL-padded amount must sort before every non-NULL amount under
    // ascending order.
    let db = seed_join_fixture();
    let sql = "SELECT c.name, o.amount FROM customers AS c \
               LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
               ORDER BY o.amount ASC;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 4);
    let first = &r.rows[0];
    assert_eq!(first[0].to_display_string(), "Carol");
    assert_eq!(first[1], Value::Null);
}
4439
#[test]
fn chained_left_outer_join_preserves_left_through_two_levels() {
    // `b` has one row (for a-one) and `c` is empty. Both `a` rows must
    // survive two LEFT JOINs:
    //   a-one -> b1, c NULL
    //   a-two -> b NULL, c NULL
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
    ];
    let mut db = Database::new("t".to_string());
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.label, b.tag, c.note FROM a \
               LEFT OUTER JOIN b ON a.id = b.a_id \
               LEFT OUTER JOIN c ON b.id = c.b_id;";
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 2);

    let mut by_label: std::collections::HashMap<String, &Vec<Value>> =
        std::collections::HashMap::new();
    for row in &r.rows {
        by_label.insert(row[0].to_display_string(), row);
    }
    let one = by_label["a-one"];
    let two = by_label["a-two"];
    assert_eq!(one[1].to_display_string(), "b1");
    assert_eq!(one[2], Value::Null);
    assert_eq!(two[1], Value::Null);
    assert_eq!(two[2], Value::Null);
}
4475
#[test]
fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
    // The first ON clause mentions `c`, which only enters scope in the second
    // join. That reference is invalid and must produce an error.
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO b (x) VALUES (1);",
        "INSERT INTO c (x) VALUES (1);",
    ];
    let mut db = Database::new("t".to_string());
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let sql = "SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;";
    let outcome = execute_select_rows(parse_select(sql), &db);
    assert!(
        outcome.is_err(),
        "ON referencing not-yet-joined table 'c' should error"
    );
}
4500
#[test]
fn join_on_truthy_integer_is_accepted() {
    // `ON 1` is always true, so the INNER JOIN degenerates into a cross
    // product: 2 rows of `a` times 2 rows of `b` gives 4 rows.
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
        "INSERT INTO b (y) VALUES (20);",
    ];
    let mut db = Database::new("t".to_string());
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let row_count = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;")
        .rows
        .len();
    assert_eq!(row_count, 4);
}
4521
#[test]
fn full_join_on_empty_tables_returns_empty() {
    // With no rows on either side there is nothing to preserve or pad, so a
    // FULL OUTER JOIN must yield an empty result rather than error.
    let mut db = Database::new("t".to_string());
    let ddl = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
    ];
    for stmt in ddl {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let r = run_rows(&db, "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;");
    assert!(r.rows.is_empty());
}
4537}