1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9 FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
10 TableWithJoins, UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
17use crate::sql::hnsw::{DistanceMetric, HnswIndex};
18use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
19
20pub struct SelectResult {
29 pub columns: Vec<String>,
30 pub rows: Vec<Vec<Value>>,
31}
32
33pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
37 let table = db
38 .get_table(query.table_name.clone())
39 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
40
41 let projected_cols: Vec<String> = match &query.projection {
43 Projection::All => table.column_names(),
44 Projection::Columns(cols) => {
45 for c in cols {
46 if !table.contains_column(c.to_string()) {
47 return Err(SQLRiteError::Internal(format!(
48 "Column '{c}' does not exist on table '{}'",
49 query.table_name
50 )));
51 }
52 }
53 cols.clone()
54 }
55 };
56
57 let matching = match select_rowids(table, query.selection.as_ref())? {
61 RowidSource::IndexProbe(rowids) => rowids,
62 RowidSource::FullScan => {
63 let mut out = Vec::new();
64 for rowid in table.rowids() {
65 if let Some(expr) = &query.selection {
66 if !eval_predicate(expr, table, rowid)? {
67 continue;
68 }
69 }
70 out.push(rowid);
71 }
72 out
73 }
74 };
75 let mut matching = matching;
76
77 match (&query.order_by, query.limit) {
106 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
107 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
108 }
109 (Some(order), Some(k)) if k < matching.len() => {
110 matching = select_topk(&matching, table, order, k)?;
111 }
112 (Some(order), _) => {
113 sort_rowids(&mut matching, table, order)?;
114 if let Some(k) = query.limit {
115 matching.truncate(k);
116 }
117 }
118 (None, Some(k)) => {
119 matching.truncate(k);
120 }
121 (None, None) => {}
122 }
123
124 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
128 for rowid in &matching {
129 let row: Vec<Value> = projected_cols
130 .iter()
131 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
132 .collect();
133 rows.push(row);
134 }
135
136 Ok(SelectResult {
137 columns: projected_cols,
138 rows,
139 })
140}
141
142pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
147 let result = execute_select_rows(query, db)?;
148 let row_count = result.rows.len();
149
150 let mut print_table = PrintTable::new();
151 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
152 print_table.add_row(PrintRow::new(header_cells));
153
154 for row in &result.rows {
155 let cells: Vec<PrintCell> = row
156 .iter()
157 .map(|v| PrintCell::new(&v.to_display_string()))
158 .collect();
159 print_table.add_row(PrintRow::new(cells));
160 }
161
162 Ok((print_table.to_string(), row_count))
163}
164
165pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
167 let Statement::Delete(Delete {
168 from, selection, ..
169 }) = stmt
170 else {
171 return Err(SQLRiteError::Internal(
172 "execute_delete called on a non-DELETE statement".to_string(),
173 ));
174 };
175
176 let tables = match from {
177 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
178 };
179 let table_name = extract_single_table_name(tables)?;
180
181 let matching: Vec<i64> = {
183 let table = db
184 .get_table(table_name.clone())
185 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
186 match select_rowids(table, selection.as_ref())? {
187 RowidSource::IndexProbe(rowids) => rowids,
188 RowidSource::FullScan => {
189 let mut out = Vec::new();
190 for rowid in table.rowids() {
191 if let Some(expr) = selection {
192 if !eval_predicate(expr, table, rowid)? {
193 continue;
194 }
195 }
196 out.push(rowid);
197 }
198 out
199 }
200 }
201 };
202
203 let table = db.get_table_mut(table_name)?;
204 for rowid in &matching {
205 table.delete_row(*rowid);
206 }
207 if !matching.is_empty() {
212 for entry in &mut table.hnsw_indexes {
213 entry.needs_rebuild = true;
214 }
215 }
216 Ok(matching.len())
217}
218
219pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
221 let Statement::Update(Update {
222 table,
223 assignments,
224 from,
225 selection,
226 ..
227 }) = stmt
228 else {
229 return Err(SQLRiteError::Internal(
230 "execute_update called on a non-UPDATE statement".to_string(),
231 ));
232 };
233
234 if from.is_some() {
235 return Err(SQLRiteError::NotImplemented(
236 "UPDATE ... FROM is not supported yet".to_string(),
237 ));
238 }
239
240 let table_name = extract_table_name(table)?;
241
242 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
244 {
245 let tbl = db
246 .get_table(table_name.clone())
247 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
248 for a in assignments {
249 let col = match &a.target {
250 AssignmentTarget::ColumnName(name) => name
251 .0
252 .last()
253 .map(|p| p.to_string())
254 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
255 AssignmentTarget::Tuple(_) => {
256 return Err(SQLRiteError::NotImplemented(
257 "tuple assignment targets are not supported".to_string(),
258 ));
259 }
260 };
261 if !tbl.contains_column(col.clone()) {
262 return Err(SQLRiteError::Internal(format!(
263 "UPDATE references unknown column '{col}'"
264 )));
265 }
266 parsed_assignments.push((col, a.value.clone()));
267 }
268 }
269
270 let work: Vec<(i64, Vec<(String, Value)>)> = {
274 let tbl = db.get_table(table_name.clone())?;
275 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
276 RowidSource::IndexProbe(rowids) => rowids,
277 RowidSource::FullScan => {
278 let mut out = Vec::new();
279 for rowid in tbl.rowids() {
280 if let Some(expr) = selection {
281 if !eval_predicate(expr, tbl, rowid)? {
282 continue;
283 }
284 }
285 out.push(rowid);
286 }
287 out
288 }
289 };
290 let mut rows_to_update = Vec::new();
291 for rowid in matched_rowids {
292 let mut values = Vec::with_capacity(parsed_assignments.len());
293 for (col, expr) in &parsed_assignments {
294 let v = eval_expr(expr, tbl, rowid)?;
297 values.push((col.clone(), v));
298 }
299 rows_to_update.push((rowid, values));
300 }
301 rows_to_update
302 };
303
304 let tbl = db.get_table_mut(table_name)?;
305 for (rowid, values) in &work {
306 for (col, v) in values {
307 tbl.set_value(col, *rowid, v.clone())?;
308 }
309 }
310
311 if !work.is_empty() {
318 let updated_columns: std::collections::HashSet<&str> = work
319 .iter()
320 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
321 .collect();
322 for entry in &mut tbl.hnsw_indexes {
323 if updated_columns.contains(entry.column_name.as_str()) {
324 entry.needs_rebuild = true;
325 }
326 }
327 }
328 Ok(work.len())
329}
330
331pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
343 let Statement::CreateIndex(CreateIndex {
344 name,
345 table_name,
346 columns,
347 using,
348 unique,
349 if_not_exists,
350 predicate,
351 ..
352 }) = stmt
353 else {
354 return Err(SQLRiteError::Internal(
355 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
356 ));
357 };
358
359 if predicate.is_some() {
360 return Err(SQLRiteError::NotImplemented(
361 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
362 ));
363 }
364
365 if columns.len() != 1 {
366 return Err(SQLRiteError::NotImplemented(format!(
367 "multi-column indexes are not supported yet ({} columns given)",
368 columns.len()
369 )));
370 }
371
372 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
373 SQLRiteError::NotImplemented(
374 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
375 )
376 })?;
377
378 let method = match using {
384 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
385 IndexMethod::Hnsw
386 }
387 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
388 IndexMethod::Btree
389 }
390 Some(other) => {
391 return Err(SQLRiteError::NotImplemented(format!(
392 "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
393 )));
394 }
395 None => IndexMethod::Btree,
396 };
397
398 let table_name_str = table_name.to_string();
399 let column_name = match &columns[0].column.expr {
400 Expr::Identifier(ident) => ident.value.clone(),
401 Expr::CompoundIdentifier(parts) => parts
402 .last()
403 .map(|p| p.value.clone())
404 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
405 other => {
406 return Err(SQLRiteError::NotImplemented(format!(
407 "CREATE INDEX only supports simple column references, got {other:?}"
408 )));
409 }
410 };
411
412 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
417 let table = db.get_table(table_name_str.clone()).map_err(|_| {
418 SQLRiteError::General(format!(
419 "CREATE INDEX references unknown table '{table_name_str}'"
420 ))
421 })?;
422 if !table.contains_column(column_name.clone()) {
423 return Err(SQLRiteError::General(format!(
424 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
425 )));
426 }
427 let col = table
428 .columns
429 .iter()
430 .find(|c| c.column_name == column_name)
431 .expect("we just verified the column exists");
432
433 if table.index_by_name(&index_name).is_some()
436 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
437 {
438 if *if_not_exists {
439 return Ok(index_name);
440 }
441 return Err(SQLRiteError::General(format!(
442 "index '{index_name}' already exists"
443 )));
444 }
445 let datatype = clone_datatype(&col.datatype);
446
447 let mut pairs = Vec::new();
448 for rowid in table.rowids() {
449 if let Some(v) = table.get_value(&column_name, rowid) {
450 pairs.push((rowid, v));
451 }
452 }
453 (datatype, pairs)
454 };
455
456 match method {
457 IndexMethod::Btree => create_btree_index(
458 db,
459 &table_name_str,
460 &index_name,
461 &column_name,
462 &datatype,
463 *unique,
464 &existing_rowids_and_values,
465 ),
466 IndexMethod::Hnsw => create_hnsw_index(
467 db,
468 &table_name_str,
469 &index_name,
470 &column_name,
471 &datatype,
472 *unique,
473 &existing_rowids_and_values,
474 ),
475 }
476}
477
/// Access method chosen by `execute_create_index` from the `USING` clause.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary index; also the choice when no `USING` clause is given.
    Btree,
    /// HNSW index, selected by `USING hnsw`.
    Hnsw,
}
486
487fn create_btree_index(
489 db: &mut Database,
490 table_name: &str,
491 index_name: &str,
492 column_name: &str,
493 datatype: &DataType,
494 unique: bool,
495 existing: &[(i64, Value)],
496) -> Result<String> {
497 let mut idx = SecondaryIndex::new(
498 index_name.to_string(),
499 table_name.to_string(),
500 column_name.to_string(),
501 datatype,
502 unique,
503 IndexOrigin::Explicit,
504 )?;
505
506 for (rowid, v) in existing {
510 if unique && idx.would_violate_unique(v) {
511 return Err(SQLRiteError::General(format!(
512 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
513 already contains the duplicate value {}",
514 v.to_display_string()
515 )));
516 }
517 idx.insert(v, *rowid)?;
518 }
519
520 let table_mut = db.get_table_mut(table_name.to_string())?;
521 table_mut.secondary_indexes.push(idx);
522 Ok(index_name.to_string())
523}
524
525fn create_hnsw_index(
527 db: &mut Database,
528 table_name: &str,
529 index_name: &str,
530 column_name: &str,
531 datatype: &DataType,
532 unique: bool,
533 existing: &[(i64, Value)],
534) -> Result<String> {
535 let dim = match datatype {
538 DataType::Vector(d) => *d,
539 other => {
540 return Err(SQLRiteError::General(format!(
541 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
542 )));
543 }
544 };
545
546 if unique {
547 return Err(SQLRiteError::General(
548 "UNIQUE has no meaning for HNSW indexes".to_string(),
549 ));
550 }
551
552 let seed = hash_str_to_seed(index_name);
560 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
561
562 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
566 std::collections::HashMap::with_capacity(existing.len());
567 for (rowid, v) in existing {
568 match v {
569 Value::Vector(vec) => {
570 if vec.len() != dim {
571 return Err(SQLRiteError::Internal(format!(
572 "row {rowid} stores a {}-dim vector in column '{column_name}' \
573 declared as VECTOR({dim}) — schema invariant violated",
574 vec.len()
575 )));
576 }
577 vec_map.insert(*rowid, vec.clone());
578 }
579 _ => continue,
583 }
584 }
585
586 for (rowid, _) in existing {
587 if let Some(v) = vec_map.get(rowid) {
588 let v_clone = v.clone();
589 idx.insert(*rowid, &v_clone, |id| {
590 vec_map.get(&id).cloned().unwrap_or_default()
591 });
592 }
593 }
594
595 let table_mut = db.get_table_mut(table_name.to_string())?;
596 table_mut.hnsw_indexes.push(HnswIndexEntry {
597 name: index_name.to_string(),
598 column_name: column_name.to_string(),
599 index: idx,
600 needs_rebuild: false,
602 });
603 Ok(index_name.to_string())
604}
605
/// Hashes `s` to a 64-bit seed with FNV-1a: start from the offset basis, then
/// for each byte XOR it in and multiply by the FNV prime (wrapping).
///
/// The result depends only on the bytes of `s`.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
617
618fn clone_datatype(dt: &DataType) -> DataType {
621 match dt {
622 DataType::Integer => DataType::Integer,
623 DataType::Text => DataType::Text,
624 DataType::Real => DataType::Real,
625 DataType::Bool => DataType::Bool,
626 DataType::Vector(dim) => DataType::Vector(*dim),
627 DataType::Json => DataType::Json,
628 DataType::None => DataType::None,
629 DataType::Invalid => DataType::Invalid,
630 }
631}
632
633fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
634 if tables.len() != 1 {
635 return Err(SQLRiteError::NotImplemented(
636 "multi-table DELETE is not supported yet".to_string(),
637 ));
638 }
639 extract_table_name(&tables[0])
640}
641
642fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
643 if !twj.joins.is_empty() {
644 return Err(SQLRiteError::NotImplemented(
645 "JOIN is not supported yet".to_string(),
646 ));
647 }
648 match &twj.relation {
649 TableFactor::Table { name, .. } => Ok(name.to_string()),
650 _ => Err(SQLRiteError::NotImplemented(
651 "only plain table references are supported".to_string(),
652 )),
653 }
654}
655
/// How the rowids for a `WHERE` clause are to be obtained, as decided by
/// `select_rowids`.
enum RowidSource {
    /// The rowids came straight from a secondary-index lookup; the caller
    /// uses them as-is.
    IndexProbe(Vec<i64>),
    /// No index applies; the caller must walk every rowid and evaluate the
    /// predicate itself.
    FullScan,
}
666
667fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
672 let Some(expr) = selection else {
673 return Ok(RowidSource::FullScan);
674 };
675 let Some((col, literal)) = try_extract_equality(expr) else {
676 return Ok(RowidSource::FullScan);
677 };
678 let Some(idx) = table.index_for_column(&col) else {
679 return Ok(RowidSource::FullScan);
680 };
681
682 let literal_value = match convert_literal(&literal) {
686 Ok(v) => v,
687 Err(_) => return Ok(RowidSource::FullScan),
688 };
689
690 let mut rowids = idx.lookup(&literal_value);
694 rowids.sort_unstable();
695 Ok(RowidSource::IndexProbe(rowids))
696}
697
698fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
702 let peeled = match expr {
704 Expr::Nested(inner) => inner.as_ref(),
705 other => other,
706 };
707 let Expr::BinaryOp { left, op, right } = peeled else {
708 return None;
709 };
710 if !matches!(op, BinaryOperator::Eq) {
711 return None;
712 }
713 let col_from = |e: &Expr| -> Option<String> {
714 match e {
715 Expr::Identifier(ident) => Some(ident.value.clone()),
716 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
717 _ => None,
718 }
719 };
720 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
721 if let Expr::Value(v) = e {
722 Some(v.value.clone())
723 } else {
724 None
725 }
726 };
727 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
728 return Some((c, l));
729 }
730 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
731 return Some((c, l));
732 }
733 None
734}
735
736fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
758 if k == 0 {
759 return None;
760 }
761
762 let func = match order_expr {
764 Expr::Function(f) => f,
765 _ => return None,
766 };
767 let fname = match func.name.0.as_slice() {
768 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
769 _ => return None,
770 };
771 if fname != "vec_distance_l2" {
772 return None;
773 }
774
775 let arg_list = match &func.args {
777 FunctionArguments::List(l) => &l.args,
778 _ => return None,
779 };
780 if arg_list.len() != 2 {
781 return None;
782 }
783 let exprs: Vec<&Expr> = arg_list
784 .iter()
785 .filter_map(|a| match a {
786 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
787 _ => None,
788 })
789 .collect();
790 if exprs.len() != 2 {
791 return None;
792 }
793
794 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
799 Some(v) => v,
800 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
801 Some(v) => v,
802 None => return None,
803 },
804 };
805
806 let entry = table
808 .hnsw_indexes
809 .iter()
810 .find(|e| e.column_name == col_name)?;
811
812 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
818 Some(c) => match &c.datatype {
819 DataType::Vector(d) => *d,
820 _ => return None,
821 },
822 None => return None,
823 };
824 if query_vec.len() != declared_dim {
825 return None;
826 }
827
828 let column_for_closure = col_name.clone();
832 let table_ref = table;
833 let result = entry.index.search(&query_vec, k, |id| {
834 match table_ref.get_value(&column_for_closure, id) {
835 Some(Value::Vector(v)) => v,
836 _ => Vec::new(),
837 }
838 });
839 Some(result)
840}
841
842fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
847 let col_name = match a {
848 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
849 _ => return None,
850 };
851 let lit_str = match b {
852 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
853 format!("[{}]", ident.value)
854 }
855 _ => return None,
856 };
857 let v = parse_vector_literal(&lit_str).ok()?;
858 Some((col_name, v))
859}
860
861struct HeapEntry {
874 key: Value,
875 rowid: i64,
876 asc: bool,
877}
878
879impl PartialEq for HeapEntry {
880 fn eq(&self, other: &Self) -> bool {
881 self.cmp(other) == Ordering::Equal
882 }
883}
884
885impl Eq for HeapEntry {}
886
887impl PartialOrd for HeapEntry {
888 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
889 Some(self.cmp(other))
890 }
891}
892
893impl Ord for HeapEntry {
894 fn cmp(&self, other: &Self) -> Ordering {
895 let raw = compare_values(Some(&self.key), Some(&other.key));
896 if self.asc { raw } else { raw.reverse() }
897 }
898}
899
900fn select_topk(
909 matching: &[i64],
910 table: &Table,
911 order: &OrderByClause,
912 k: usize,
913) -> Result<Vec<i64>> {
914 use std::collections::BinaryHeap;
915
916 if k == 0 || matching.is_empty() {
917 return Ok(Vec::new());
918 }
919
920 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
921
922 for &rowid in matching {
923 let key = eval_expr(&order.expr, table, rowid)?;
924 let entry = HeapEntry {
925 key,
926 rowid,
927 asc: order.ascending,
928 };
929
930 if heap.len() < k {
931 heap.push(entry);
932 } else {
933 if entry < *heap.peek().unwrap() {
937 heap.pop();
938 heap.push(entry);
939 }
940 }
941 }
942
943 Ok(heap
948 .into_sorted_vec()
949 .into_iter()
950 .map(|e| e.rowid)
951 .collect())
952}
953
954fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
955 let mut keys: Vec<(i64, Result<Value>)> = rowids
963 .iter()
964 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
965 .collect();
966
967 for (_, k) in &keys {
971 if let Err(e) = k {
972 return Err(SQLRiteError::General(format!(
973 "ORDER BY expression failed: {e}"
974 )));
975 }
976 }
977
978 keys.sort_by(|(_, ka), (_, kb)| {
979 let va = ka.as_ref().unwrap();
982 let vb = kb.as_ref().unwrap();
983 let ord = compare_values(Some(va), Some(vb));
984 if order.ascending { ord } else { ord.reverse() }
985 });
986
987 for (i, (rowid, _)) in keys.into_iter().enumerate() {
989 rowids[i] = rowid;
990 }
991 Ok(())
992}
993
994fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
995 match (a, b) {
996 (None, None) => Ordering::Equal,
997 (None, _) => Ordering::Less,
998 (_, None) => Ordering::Greater,
999 (Some(a), Some(b)) => match (a, b) {
1000 (Value::Null, Value::Null) => Ordering::Equal,
1001 (Value::Null, _) => Ordering::Less,
1002 (_, Value::Null) => Ordering::Greater,
1003 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1004 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1005 (Value::Integer(x), Value::Real(y)) => {
1006 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1007 }
1008 (Value::Real(x), Value::Integer(y)) => {
1009 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1010 }
1011 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1012 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1013 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1015 },
1016 }
1017}
1018
1019pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1021 let v = eval_expr(expr, table, rowid)?;
1022 match v {
1023 Value::Bool(b) => Ok(b),
1024 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
1026 other => Err(SQLRiteError::Internal(format!(
1027 "WHERE clause must evaluate to boolean, got {}",
1028 other.to_display_string()
1029 ))),
1030 }
1031}
1032
1033fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1034 match expr {
1035 Expr::Nested(inner) => eval_expr(inner, table, rowid),
1036
1037 Expr::Identifier(ident) => {
1038 if ident.quote_style == Some('[') {
1048 let raw = format!("[{}]", ident.value);
1049 let v = parse_vector_literal(&raw)?;
1050 return Ok(Value::Vector(v));
1051 }
1052 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1053 }
1054
1055 Expr::CompoundIdentifier(parts) => {
1056 let col = parts
1058 .last()
1059 .map(|i| i.value.as_str())
1060 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1061 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1062 }
1063
1064 Expr::Value(v) => convert_literal(&v.value),
1065
1066 Expr::UnaryOp { op, expr } => {
1067 let inner = eval_expr(expr, table, rowid)?;
1068 match op {
1069 UnaryOperator::Not => match inner {
1070 Value::Bool(b) => Ok(Value::Bool(!b)),
1071 Value::Null => Ok(Value::Null),
1072 other => Err(SQLRiteError::Internal(format!(
1073 "NOT applied to non-boolean value: {}",
1074 other.to_display_string()
1075 ))),
1076 },
1077 UnaryOperator::Minus => match inner {
1078 Value::Integer(i) => Ok(Value::Integer(-i)),
1079 Value::Real(f) => Ok(Value::Real(-f)),
1080 Value::Null => Ok(Value::Null),
1081 other => Err(SQLRiteError::Internal(format!(
1082 "unary minus on non-numeric value: {}",
1083 other.to_display_string()
1084 ))),
1085 },
1086 UnaryOperator::Plus => Ok(inner),
1087 other => Err(SQLRiteError::NotImplemented(format!(
1088 "unary operator {other:?} is not supported"
1089 ))),
1090 }
1091 }
1092
1093 Expr::BinaryOp { left, op, right } => match op {
1094 BinaryOperator::And => {
1095 let l = eval_expr(left, table, rowid)?;
1096 let r = eval_expr(right, table, rowid)?;
1097 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1098 }
1099 BinaryOperator::Or => {
1100 let l = eval_expr(left, table, rowid)?;
1101 let r = eval_expr(right, table, rowid)?;
1102 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1103 }
1104 cmp @ (BinaryOperator::Eq
1105 | BinaryOperator::NotEq
1106 | BinaryOperator::Lt
1107 | BinaryOperator::LtEq
1108 | BinaryOperator::Gt
1109 | BinaryOperator::GtEq) => {
1110 let l = eval_expr(left, table, rowid)?;
1111 let r = eval_expr(right, table, rowid)?;
1112 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1114 return Ok(Value::Bool(false));
1115 }
1116 let ord = compare_values(Some(&l), Some(&r));
1117 let result = match cmp {
1118 BinaryOperator::Eq => ord == Ordering::Equal,
1119 BinaryOperator::NotEq => ord != Ordering::Equal,
1120 BinaryOperator::Lt => ord == Ordering::Less,
1121 BinaryOperator::LtEq => ord != Ordering::Greater,
1122 BinaryOperator::Gt => ord == Ordering::Greater,
1123 BinaryOperator::GtEq => ord != Ordering::Less,
1124 _ => unreachable!(),
1125 };
1126 Ok(Value::Bool(result))
1127 }
1128 arith @ (BinaryOperator::Plus
1129 | BinaryOperator::Minus
1130 | BinaryOperator::Multiply
1131 | BinaryOperator::Divide
1132 | BinaryOperator::Modulo) => {
1133 let l = eval_expr(left, table, rowid)?;
1134 let r = eval_expr(right, table, rowid)?;
1135 eval_arith(arith, &l, &r)
1136 }
1137 BinaryOperator::StringConcat => {
1138 let l = eval_expr(left, table, rowid)?;
1139 let r = eval_expr(right, table, rowid)?;
1140 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1141 return Ok(Value::Null);
1142 }
1143 Ok(Value::Text(format!(
1144 "{}{}",
1145 l.to_display_string(),
1146 r.to_display_string()
1147 )))
1148 }
1149 other => Err(SQLRiteError::NotImplemented(format!(
1150 "binary operator {other:?} is not supported yet"
1151 ))),
1152 },
1153
1154 Expr::Function(func) => eval_function(func, table, rowid),
1165
1166 other => Err(SQLRiteError::NotImplemented(format!(
1167 "unsupported expression in WHERE/projection: {other:?}"
1168 ))),
1169 }
1170}
1171
1172fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1177 let name = match func.name.0.as_slice() {
1180 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1181 _ => {
1182 return Err(SQLRiteError::NotImplemented(format!(
1183 "qualified function names not supported: {:?}",
1184 func.name
1185 )));
1186 }
1187 };
1188
1189 match name.as_str() {
1190 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1191 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1192 let dist = match name.as_str() {
1193 "vec_distance_l2" => vec_distance_l2(&a, &b),
1194 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1195 "vec_distance_dot" => vec_distance_dot(&a, &b),
1196 _ => unreachable!(),
1197 };
1198 Ok(Value::Real(dist as f64))
1204 }
1205 "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
1210 "json_type" => json_fn_type(&name, &func.args, table, rowid),
1211 "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
1212 "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
1213 other => Err(SQLRiteError::NotImplemented(format!(
1214 "unknown function: {other}(...)"
1215 ))),
1216 }
1217}
1218
1219fn extract_json_and_path(
1233 fn_name: &str,
1234 args: &FunctionArguments,
1235 table: &Table,
1236 rowid: i64,
1237) -> Result<(String, String)> {
1238 let arg_list = match args {
1239 FunctionArguments::List(l) => &l.args,
1240 _ => {
1241 return Err(SQLRiteError::General(format!(
1242 "{fn_name}() expects 1 or 2 arguments"
1243 )));
1244 }
1245 };
1246 if !(arg_list.len() == 1 || arg_list.len() == 2) {
1247 return Err(SQLRiteError::General(format!(
1248 "{fn_name}() expects 1 or 2 arguments, got {}",
1249 arg_list.len()
1250 )));
1251 }
1252 let first_expr = match &arg_list[0] {
1254 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1255 other => {
1256 return Err(SQLRiteError::NotImplemented(format!(
1257 "{fn_name}() argument 0 has unsupported shape: {other:?}"
1258 )));
1259 }
1260 };
1261 let json_text = match eval_expr(first_expr, table, rowid)? {
1262 Value::Text(s) => s,
1263 Value::Null => {
1264 return Err(SQLRiteError::General(format!(
1265 "{fn_name}() called on NULL — JSON column has no value for this row"
1266 )));
1267 }
1268 other => {
1269 return Err(SQLRiteError::General(format!(
1270 "{fn_name}() argument 0 is not JSON-typed: got {}",
1271 other.to_display_string()
1272 )));
1273 }
1274 };
1275
1276 let path = if arg_list.len() == 2 {
1278 let path_expr = match &arg_list[1] {
1279 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1280 other => {
1281 return Err(SQLRiteError::NotImplemented(format!(
1282 "{fn_name}() argument 1 has unsupported shape: {other:?}"
1283 )));
1284 }
1285 };
1286 match eval_expr(path_expr, table, rowid)? {
1287 Value::Text(s) => s,
1288 other => {
1289 return Err(SQLRiteError::General(format!(
1290 "{fn_name}() path argument must be a string literal, got {}",
1291 other.to_display_string()
1292 )));
1293 }
1294 }
1295 } else {
1296 "$".to_string()
1297 };
1298
1299 Ok((json_text, path))
1300}
1301
1302fn walk_json_path<'a>(
1312 value: &'a serde_json::Value,
1313 path: &str,
1314) -> Result<Option<&'a serde_json::Value>> {
1315 let mut chars = path.chars().peekable();
1316 if chars.next() != Some('$') {
1317 return Err(SQLRiteError::General(format!(
1318 "JSON path must start with '$', got `{path}`"
1319 )));
1320 }
1321 let mut current = value;
1322 while let Some(&c) = chars.peek() {
1323 match c {
1324 '.' => {
1325 chars.next();
1326 let mut key = String::new();
1327 while let Some(&c) = chars.peek() {
1328 if c == '.' || c == '[' {
1329 break;
1330 }
1331 key.push(c);
1332 chars.next();
1333 }
1334 if key.is_empty() {
1335 return Err(SQLRiteError::General(format!(
1336 "JSON path has empty key after '.' in `{path}`"
1337 )));
1338 }
1339 match current.get(&key) {
1340 Some(v) => current = v,
1341 None => return Ok(None),
1342 }
1343 }
1344 '[' => {
1345 chars.next();
1346 let mut idx_str = String::new();
1347 while let Some(&c) = chars.peek() {
1348 if c == ']' {
1349 break;
1350 }
1351 idx_str.push(c);
1352 chars.next();
1353 }
1354 if chars.next() != Some(']') {
1355 return Err(SQLRiteError::General(format!(
1356 "JSON path has unclosed `[` in `{path}`"
1357 )));
1358 }
1359 let idx: usize = idx_str.trim().parse().map_err(|_| {
1360 SQLRiteError::General(format!(
1361 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
1362 ))
1363 })?;
1364 match current.get(idx) {
1365 Some(v) => current = v,
1366 None => return Ok(None),
1367 }
1368 }
1369 other => {
1370 return Err(SQLRiteError::General(format!(
1371 "JSON path has unexpected character `{other}` in `{path}` \
1372 (expected `.`, `[`, or end-of-path)"
1373 )));
1374 }
1375 }
1376 }
1377 Ok(Some(current))
1378}
1379
1380fn json_value_to_sql(v: &serde_json::Value) -> Value {
1384 match v {
1385 serde_json::Value::Null => Value::Null,
1386 serde_json::Value::Bool(b) => Value::Bool(*b),
1387 serde_json::Value::Number(n) => {
1388 if let Some(i) = n.as_i64() {
1390 Value::Integer(i)
1391 } else if let Some(f) = n.as_f64() {
1392 Value::Real(f)
1393 } else {
1394 Value::Null
1395 }
1396 }
1397 serde_json::Value::String(s) => Value::Text(s.clone()),
1398 composite => Value::Text(composite.to_string()),
1402 }
1403}
1404
1405fn json_fn_extract(
1406 name: &str,
1407 args: &FunctionArguments,
1408 table: &Table,
1409 rowid: i64,
1410) -> Result<Value> {
1411 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1412 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1413 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1414 })?;
1415 match walk_json_path(&parsed, &path)? {
1416 Some(v) => Ok(json_value_to_sql(v)),
1417 None => Ok(Value::Null),
1418 }
1419}
1420
1421fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
1422 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1423 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1424 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1425 })?;
1426 let resolved = match walk_json_path(&parsed, &path)? {
1427 Some(v) => v,
1428 None => return Ok(Value::Null),
1429 };
1430 let ty = match resolved {
1431 serde_json::Value::Null => "null",
1432 serde_json::Value::Bool(true) => "true",
1433 serde_json::Value::Bool(false) => "false",
1434 serde_json::Value::Number(n) => {
1435 if n.is_i64() || n.is_u64() {
1436 "integer"
1437 } else {
1438 "real"
1439 }
1440 }
1441 serde_json::Value::String(_) => "text",
1442 serde_json::Value::Array(_) => "array",
1443 serde_json::Value::Object(_) => "object",
1444 };
1445 Ok(Value::Text(ty.to_string()))
1446}
1447
1448fn json_fn_array_length(
1449 name: &str,
1450 args: &FunctionArguments,
1451 table: &Table,
1452 rowid: i64,
1453) -> Result<Value> {
1454 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1455 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1456 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1457 })?;
1458 let resolved = match walk_json_path(&parsed, &path)? {
1459 Some(v) => v,
1460 None => return Ok(Value::Null),
1461 };
1462 match resolved.as_array() {
1463 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
1464 None => Err(SQLRiteError::General(format!(
1465 "{name}() resolved to a non-array value at path `{path}`"
1466 ))),
1467 }
1468}
1469
1470fn json_fn_object_keys(
1471 name: &str,
1472 args: &FunctionArguments,
1473 table: &Table,
1474 rowid: i64,
1475) -> Result<Value> {
1476 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1477 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1478 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1479 })?;
1480 let resolved = match walk_json_path(&parsed, &path)? {
1481 Some(v) => v,
1482 None => return Ok(Value::Null),
1483 };
1484 let obj = resolved.as_object().ok_or_else(|| {
1485 SQLRiteError::General(format!(
1486 "{name}() resolved to a non-object value at path `{path}`"
1487 ))
1488 })?;
1489 let keys: Vec<serde_json::Value> = obj
1496 .keys()
1497 .map(|k| serde_json::Value::String(k.clone()))
1498 .collect();
1499 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
1500}
1501
1502fn extract_two_vector_args(
1506 fn_name: &str,
1507 args: &FunctionArguments,
1508 table: &Table,
1509 rowid: i64,
1510) -> Result<(Vec<f32>, Vec<f32>)> {
1511 let arg_list = match args {
1512 FunctionArguments::List(l) => &l.args,
1513 _ => {
1514 return Err(SQLRiteError::General(format!(
1515 "{fn_name}() expects exactly two vector arguments"
1516 )));
1517 }
1518 };
1519 if arg_list.len() != 2 {
1520 return Err(SQLRiteError::General(format!(
1521 "{fn_name}() expects exactly 2 arguments, got {}",
1522 arg_list.len()
1523 )));
1524 }
1525 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
1526 for (i, arg) in arg_list.iter().enumerate() {
1527 let expr = match arg {
1528 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1529 other => {
1530 return Err(SQLRiteError::NotImplemented(format!(
1531 "{fn_name}() argument {i} has unsupported shape: {other:?}"
1532 )));
1533 }
1534 };
1535 let val = eval_expr(expr, table, rowid)?;
1536 match val {
1537 Value::Vector(v) => out.push(v),
1538 other => {
1539 return Err(SQLRiteError::General(format!(
1540 "{fn_name}() argument {i} is not a vector: got {}",
1541 other.to_display_string()
1542 )));
1543 }
1544 }
1545 }
1546 let b = out.pop().unwrap();
1547 let a = out.pop().unwrap();
1548 if a.len() != b.len() {
1549 return Err(SQLRiteError::General(format!(
1550 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
1551 a.len(),
1552 b.len()
1553 )));
1554 }
1555 Ok((a, b))
1556}
1557
/// Euclidean (L2) distance between two equal-length vectors.
///
/// Computes `sqrt(sum((a[i] - b[i])^2))`, accumulating in `f32` from the
/// first component to the last.
///
/// # Panics
/// Debug builds assert that both slices have the same length. In any build,
/// a `b` shorter than `a` panics; a longer `b` has its extra tail ignored.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` components of `b` take part in the sum.
    let rhs = &b[..a.len()];
    a.iter()
        .zip(rhs)
        .fold(0.0_f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
1569
1570pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
1580 debug_assert_eq!(a.len(), b.len());
1581 let mut dot = 0.0f32;
1582 let mut norm_a_sq = 0.0f32;
1583 let mut norm_b_sq = 0.0f32;
1584 for i in 0..a.len() {
1585 dot += a[i] * b[i];
1586 norm_a_sq += a[i] * a[i];
1587 norm_b_sq += b[i] * b[i];
1588 }
1589 let denom = (norm_a_sq * norm_b_sq).sqrt();
1590 if denom == 0.0 {
1591 return Err(SQLRiteError::General(
1592 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
1593 ));
1594 }
1595 Ok(1.0 - dot / denom)
1596}
1597
/// Negated inner product of two equal-length vectors.
///
/// Returns `-(sum(a[i] * b[i]))`, accumulating in `f32` from the first
/// component to the last, so a larger inner product yields a smaller value.
///
/// # Panics
/// Debug builds assert that both slices have the same length. In any build,
/// a `b` shorter than `a` panics; a longer `b` has its extra tail ignored.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` components of `b` take part in the sum.
    let rhs = &b[..a.len()];
    let inner_product = a
        .iter()
        .zip(rhs)
        .fold(0.0_f32, |acc, (x, y)| acc + x * y);
    -inner_product
}
1609
1610fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
1613 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1614 return Ok(Value::Null);
1615 }
1616 match (l, r) {
1617 (Value::Integer(a), Value::Integer(b)) => match op {
1618 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
1619 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
1620 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
1621 BinaryOperator::Divide => {
1622 if *b == 0 {
1623 Err(SQLRiteError::General("division by zero".to_string()))
1624 } else {
1625 Ok(Value::Integer(a / b))
1626 }
1627 }
1628 BinaryOperator::Modulo => {
1629 if *b == 0 {
1630 Err(SQLRiteError::General("modulo by zero".to_string()))
1631 } else {
1632 Ok(Value::Integer(a % b))
1633 }
1634 }
1635 _ => unreachable!(),
1636 },
1637 (a, b) => {
1639 let af = as_number(a)?;
1640 let bf = as_number(b)?;
1641 match op {
1642 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1643 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1644 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1645 BinaryOperator::Divide => {
1646 if bf == 0.0 {
1647 Err(SQLRiteError::General("division by zero".to_string()))
1648 } else {
1649 Ok(Value::Real(af / bf))
1650 }
1651 }
1652 BinaryOperator::Modulo => {
1653 if bf == 0.0 {
1654 Err(SQLRiteError::General("modulo by zero".to_string()))
1655 } else {
1656 Ok(Value::Real(af % bf))
1657 }
1658 }
1659 _ => unreachable!(),
1660 }
1661 }
1662 }
1663}
1664
1665fn as_number(v: &Value) -> Result<f64> {
1666 match v {
1667 Value::Integer(i) => Ok(*i as f64),
1668 Value::Real(f) => Ok(*f),
1669 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1670 other => Err(SQLRiteError::General(format!(
1671 "arithmetic on non-numeric value '{}'",
1672 other.to_display_string()
1673 ))),
1674 }
1675}
1676
1677fn as_bool(v: &Value) -> Result<bool> {
1678 match v {
1679 Value::Bool(b) => Ok(*b),
1680 Value::Null => Ok(false),
1681 Value::Integer(i) => Ok(*i != 0),
1682 other => Err(SQLRiteError::Internal(format!(
1683 "expected boolean, got {}",
1684 other.to_display_string()
1685 ))),
1686 }
1687}
1688
1689fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1690 use sqlparser::ast::Value as AstValue;
1691 match v {
1692 AstValue::Number(n, _) => {
1693 if let Ok(i) = n.parse::<i64>() {
1694 Ok(Value::Integer(i))
1695 } else if let Ok(f) = n.parse::<f64>() {
1696 Ok(Value::Real(f))
1697 } else {
1698 Err(SQLRiteError::Internal(format!(
1699 "could not parse numeric literal '{n}'"
1700 )))
1701 }
1702 }
1703 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1704 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1705 AstValue::Null => Ok(Value::Null),
1706 other => Err(SQLRiteError::NotImplemented(format!(
1707 "unsupported literal value: {other:?}"
1708 ))),
1709 }
1710}
1711
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /// `true` when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let delta = (a - b).abs();
        delta < eps
    }

    /// Builds a database with a `docs(id, score)` table holding `n` rows.
    ///
    /// Scores come from a fixed multiplicative scramble of the row index, so
    /// they are deterministic but not ordered by insertion.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        let scores = (0..n).map(|i| ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64);
        for score in scores {
            let insert = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&insert, &mut db).expect("insert");
        }
        db
    }

    /// Parses `sql` with the SQLite dialect and wraps the last statement in a
    /// `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let mut statements = Parser::parse_sql(&SQLiteDialect {}, sql).expect("parse");
        let statement = statements.pop().expect("one statement");
        SelectQuery::new(&statement).expect("select-query")
    }

    // ------------------------------------------------------------------
    // vec_distance_l2
    // ------------------------------------------------------------------

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let point = [0.1_f32, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&point, &point), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let expected = 2.0_f32.sqrt();
        assert!(approx_eq(vec_distance_l2(&x_axis, &y_axis), expected, 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // 3-4-5 right triangle.
        let origin = [0.0_f32, 0.0, 0.0];
        let target = [3.0_f32, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&origin, &target), 5.0, 1e-6));
    }

    // ------------------------------------------------------------------
    // vec_distance_cosine
    // ------------------------------------------------------------------

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let point = [0.1_f32, 0.2, 0.3];
        let d = vec_distance_cosine(&point, &point).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let distance = vec_distance_cosine(&x_axis, &y_axis).unwrap();
        assert!(approx_eq(distance, 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let forward = [1.0_f32, 0.0, 0.0];
        let backward = [-1.0_f32, 0.0, 0.0];
        let distance = vec_distance_cosine(&forward, &backward).unwrap();
        assert!(approx_eq(distance, 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let zero = [0.0_f32, 0.0];
        let unit = [1.0_f32, 0.0];
        let err = vec_distance_cosine(&zero, &unit).unwrap_err();
        assert!(err.to_string().contains("zero-magnitude"));
    }

    // ------------------------------------------------------------------
    // vec_distance_dot
    // ------------------------------------------------------------------

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32, returned negated.
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&lhs, &rhs), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        assert_eq!(vec_distance_dot(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both inputs have unit length, so -dot == (1 - cos) - 1.
        let lhs = [0.6_f32, 0.8];
        let rhs = [0.8_f32, 0.6];
        let dot_distance = vec_distance_dot(&lhs, &rhs);
        let cosine_distance = vec_distance_cosine(&lhs, &rhs).unwrap();
        assert!(approx_eq(dot_distance, cosine_distance - 1.0, 1e-5));
    }

    // ------------------------------------------------------------------
    // select_topk vs. sort_rowids
    // ------------------------------------------------------------------

    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        // Reference answer: sort everything, keep the first ten.
        let mut expected = rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let actual = select_topk(&rowids, table, order, 10).unwrap();

        assert_eq!(
            actual, expected,
            "top-k via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        // Reference answer: sort everything, keep the first ten.
        let mut expected = rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let actual = select_topk(&rowids, table, order, 10).unwrap();

        assert_eq!(
            actual, expected,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = query.order_by.as_ref().unwrap();

        let picked = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(picked.len(), 50);

        // Collect the REAL scores of the returned rows and check they ascend.
        let scores: Vec<f64> = picked
            .iter()
            .filter_map(|rowid| {
                if let Some(Value::Real(score)) = table.get_value("score", *rowid) {
                    Some(score)
                } else {
                    None
                }
            })
            .collect();
        assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        // The query's own LIMIT is irrelevant here; k is passed explicitly as 0.
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = query.order_by.as_ref().unwrap();
        let picked = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = query.order_by.as_ref().unwrap();
        let picked = select_topk(&[], table, order, 5).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();

        let embeddings = [
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ];
        for v in &embeddings {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }

        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Timing comparison between `select_topk` and sort-then-truncate.
    /// Ignored by default because it asserts on wall-clock measurements.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        // Candidate: top-k selection.
        let heap_started = Instant::now();
        let _topk = select_topk(&rowids, table, order, K).unwrap();
        let heap_dur = heap_started.elapsed();

        // Baseline: clone, sort all rows, truncate to K.
        let sort_started = Instant::now();
        let mut sorted = rowids.clone();
        sort_rowids(&mut sorted, table, order).unwrap();
        sorted.truncate(K);
        let sort_dur = sort_started.elapsed();

        // Guard the divisor so a zero-length measurement cannot divide by zero.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:    {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:   {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}