1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9 FunctionArgExpr, FunctionArguments, ObjectNamePart, Statement, TableFactor, TableWithJoins,
10 UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, Table, Value, parse_vector_literal};
17use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
18
/// Structured result of a SELECT: the projected column names plus the raw
/// row values, left unrendered so callers can format them however they like.
pub struct SelectResult {
    /// Names of the projected columns, in output order.
    pub columns: Vec<String>,
    /// One entry per returned row; each row holds one value per entry in
    /// `columns`, in the same order.
    pub rows: Vec<Vec<Value>>,
}
31
32pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
36 let table = db
37 .get_table(query.table_name.clone())
38 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
39
40 let projected_cols: Vec<String> = match &query.projection {
42 Projection::All => table.column_names(),
43 Projection::Columns(cols) => {
44 for c in cols {
45 if !table.contains_column(c.to_string()) {
46 return Err(SQLRiteError::Internal(format!(
47 "Column '{c}' does not exist on table '{}'",
48 query.table_name
49 )));
50 }
51 }
52 cols.clone()
53 }
54 };
55
56 let matching = match select_rowids(table, query.selection.as_ref())? {
60 RowidSource::IndexProbe(rowids) => rowids,
61 RowidSource::FullScan => {
62 let mut out = Vec::new();
63 for rowid in table.rowids() {
64 if let Some(expr) = &query.selection {
65 if !eval_predicate(expr, table, rowid)? {
66 continue;
67 }
68 }
69 out.push(rowid);
70 }
71 out
72 }
73 };
74 let mut matching = matching;
75
76 match (&query.order_by, query.limit) {
96 (Some(order), Some(k)) if k < matching.len() => {
97 matching = select_topk(&matching, table, order, k)?;
98 }
99 (Some(order), _) => {
100 sort_rowids(&mut matching, table, order)?;
101 if let Some(k) = query.limit {
102 matching.truncate(k);
103 }
104 }
105 (None, Some(k)) => {
106 matching.truncate(k);
107 }
108 (None, None) => {}
109 }
110
111 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
115 for rowid in &matching {
116 let row: Vec<Value> = projected_cols
117 .iter()
118 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
119 .collect();
120 rows.push(row);
121 }
122
123 Ok(SelectResult {
124 columns: projected_cols,
125 rows,
126 })
127}
128
129pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
134 let result = execute_select_rows(query, db)?;
135 let row_count = result.rows.len();
136
137 let mut print_table = PrintTable::new();
138 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
139 print_table.add_row(PrintRow::new(header_cells));
140
141 for row in &result.rows {
142 let cells: Vec<PrintCell> = row
143 .iter()
144 .map(|v| PrintCell::new(&v.to_display_string()))
145 .collect();
146 print_table.add_row(PrintRow::new(cells));
147 }
148
149 Ok((print_table.to_string(), row_count))
150}
151
152pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
154 let Statement::Delete(Delete {
155 from, selection, ..
156 }) = stmt
157 else {
158 return Err(SQLRiteError::Internal(
159 "execute_delete called on a non-DELETE statement".to_string(),
160 ));
161 };
162
163 let tables = match from {
164 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
165 };
166 let table_name = extract_single_table_name(tables)?;
167
168 let matching: Vec<i64> = {
170 let table = db
171 .get_table(table_name.clone())
172 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
173 match select_rowids(table, selection.as_ref())? {
174 RowidSource::IndexProbe(rowids) => rowids,
175 RowidSource::FullScan => {
176 let mut out = Vec::new();
177 for rowid in table.rowids() {
178 if let Some(expr) = selection {
179 if !eval_predicate(expr, table, rowid)? {
180 continue;
181 }
182 }
183 out.push(rowid);
184 }
185 out
186 }
187 }
188 };
189
190 let table = db.get_table_mut(table_name)?;
191 for rowid in &matching {
192 table.delete_row(*rowid);
193 }
194 Ok(matching.len())
195}
196
197pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
199 let Statement::Update(Update {
200 table,
201 assignments,
202 from,
203 selection,
204 ..
205 }) = stmt
206 else {
207 return Err(SQLRiteError::Internal(
208 "execute_update called on a non-UPDATE statement".to_string(),
209 ));
210 };
211
212 if from.is_some() {
213 return Err(SQLRiteError::NotImplemented(
214 "UPDATE ... FROM is not supported yet".to_string(),
215 ));
216 }
217
218 let table_name = extract_table_name(table)?;
219
220 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
222 {
223 let tbl = db
224 .get_table(table_name.clone())
225 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
226 for a in assignments {
227 let col = match &a.target {
228 AssignmentTarget::ColumnName(name) => name
229 .0
230 .last()
231 .map(|p| p.to_string())
232 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
233 AssignmentTarget::Tuple(_) => {
234 return Err(SQLRiteError::NotImplemented(
235 "tuple assignment targets are not supported".to_string(),
236 ));
237 }
238 };
239 if !tbl.contains_column(col.clone()) {
240 return Err(SQLRiteError::Internal(format!(
241 "UPDATE references unknown column '{col}'"
242 )));
243 }
244 parsed_assignments.push((col, a.value.clone()));
245 }
246 }
247
248 let work: Vec<(i64, Vec<(String, Value)>)> = {
252 let tbl = db.get_table(table_name.clone())?;
253 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
254 RowidSource::IndexProbe(rowids) => rowids,
255 RowidSource::FullScan => {
256 let mut out = Vec::new();
257 for rowid in tbl.rowids() {
258 if let Some(expr) = selection {
259 if !eval_predicate(expr, tbl, rowid)? {
260 continue;
261 }
262 }
263 out.push(rowid);
264 }
265 out
266 }
267 };
268 let mut rows_to_update = Vec::new();
269 for rowid in matched_rowids {
270 let mut values = Vec::with_capacity(parsed_assignments.len());
271 for (col, expr) in &parsed_assignments {
272 let v = eval_expr(expr, tbl, rowid)?;
275 values.push((col.clone(), v));
276 }
277 rows_to_update.push((rowid, values));
278 }
279 rows_to_update
280 };
281
282 let tbl = db.get_table_mut(table_name)?;
283 for (rowid, values) in &work {
284 for (col, v) in values {
285 tbl.set_value(col, *rowid, v.clone())?;
286 }
287 }
288 Ok(work.len())
289}
290
291pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
295 let Statement::CreateIndex(CreateIndex {
296 name,
297 table_name,
298 columns,
299 unique,
300 if_not_exists,
301 predicate,
302 ..
303 }) = stmt
304 else {
305 return Err(SQLRiteError::Internal(
306 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
307 ));
308 };
309
310 if predicate.is_some() {
311 return Err(SQLRiteError::NotImplemented(
312 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
313 ));
314 }
315
316 if columns.len() != 1 {
317 return Err(SQLRiteError::NotImplemented(format!(
318 "multi-column indexes are not supported yet ({} columns given)",
319 columns.len()
320 )));
321 }
322
323 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
324 SQLRiteError::NotImplemented(
325 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
326 )
327 })?;
328
329 let table_name_str = table_name.to_string();
330 let column_name = match &columns[0].column.expr {
331 Expr::Identifier(ident) => ident.value.clone(),
332 Expr::CompoundIdentifier(parts) => parts
333 .last()
334 .map(|p| p.value.clone())
335 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
336 other => {
337 return Err(SQLRiteError::NotImplemented(format!(
338 "CREATE INDEX only supports simple column references, got {other:?}"
339 )));
340 }
341 };
342
343 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
345 let table = db.get_table(table_name_str.clone()).map_err(|_| {
346 SQLRiteError::General(format!(
347 "CREATE INDEX references unknown table '{table_name_str}'"
348 ))
349 })?;
350 if !table.contains_column(column_name.clone()) {
351 return Err(SQLRiteError::General(format!(
352 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
353 )));
354 }
355 let col = table
356 .columns
357 .iter()
358 .find(|c| c.column_name == column_name)
359 .expect("we just verified the column exists");
360 if table.index_by_name(&index_name).is_some() {
361 if *if_not_exists {
362 return Ok(index_name);
363 }
364 return Err(SQLRiteError::General(format!(
365 "index '{index_name}' already exists"
366 )));
367 }
368 let datatype = clone_datatype(&col.datatype);
369
370 let mut pairs = Vec::new();
374 for rowid in table.rowids() {
375 if let Some(v) = table.get_value(&column_name, rowid) {
376 pairs.push((rowid, v));
377 }
378 }
379 (datatype, pairs)
380 };
381
382 let mut idx = SecondaryIndex::new(
384 index_name.clone(),
385 table_name_str.clone(),
386 column_name.clone(),
387 &datatype,
388 *unique,
389 IndexOrigin::Explicit,
390 )?;
391
392 for (rowid, v) in &existing_rowids_and_values {
396 if *unique && idx.would_violate_unique(v) {
397 return Err(SQLRiteError::General(format!(
398 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
399 already contains the duplicate value {}",
400 v.to_display_string()
401 )));
402 }
403 idx.insert(v, *rowid)?;
404 }
405
406 let table_mut = db.get_table_mut(table_name_str)?;
408 table_mut.secondary_indexes.push(idx);
409 Ok(index_name)
410}
411
412fn clone_datatype(dt: &DataType) -> DataType {
415 match dt {
416 DataType::Integer => DataType::Integer,
417 DataType::Text => DataType::Text,
418 DataType::Real => DataType::Real,
419 DataType::Bool => DataType::Bool,
420 DataType::Vector(dim) => DataType::Vector(*dim),
421 DataType::None => DataType::None,
422 DataType::Invalid => DataType::Invalid,
423 }
424}
425
426fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
427 if tables.len() != 1 {
428 return Err(SQLRiteError::NotImplemented(
429 "multi-table DELETE is not supported yet".to_string(),
430 ));
431 }
432 extract_table_name(&tables[0])
433}
434
435fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
436 if !twj.joins.is_empty() {
437 return Err(SQLRiteError::NotImplemented(
438 "JOIN is not supported yet".to_string(),
439 ));
440 }
441 match &twj.relation {
442 TableFactor::Table { name, .. } => Ok(name.to_string()),
443 _ => Err(SQLRiteError::NotImplemented(
444 "only plain table references are supported".to_string(),
445 )),
446 }
447}
448
/// Outcome of planning a WHERE clause in `select_rowids`: either the exact
/// set of matching rowids, or an instruction to scan.
enum RowidSource {
    /// The predicate was answered by an index lookup; the payload is the
    /// complete list of matching rowids, sorted ascending.
    IndexProbe(Vec<i64>),
    /// No index could be used; the caller must visit every row and evaluate
    /// the predicate (if any) itself.
    FullScan,
}
459
460fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
465 let Some(expr) = selection else {
466 return Ok(RowidSource::FullScan);
467 };
468 let Some((col, literal)) = try_extract_equality(expr) else {
469 return Ok(RowidSource::FullScan);
470 };
471 let Some(idx) = table.index_for_column(&col) else {
472 return Ok(RowidSource::FullScan);
473 };
474
475 let literal_value = match convert_literal(&literal) {
479 Ok(v) => v,
480 Err(_) => return Ok(RowidSource::FullScan),
481 };
482
483 let mut rowids = idx.lookup(&literal_value);
487 rowids.sort_unstable();
488 Ok(RowidSource::IndexProbe(rowids))
489}
490
491fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
495 let peeled = match expr {
497 Expr::Nested(inner) => inner.as_ref(),
498 other => other,
499 };
500 let Expr::BinaryOp { left, op, right } = peeled else {
501 return None;
502 };
503 if !matches!(op, BinaryOperator::Eq) {
504 return None;
505 }
506 let col_from = |e: &Expr| -> Option<String> {
507 match e {
508 Expr::Identifier(ident) => Some(ident.value.clone()),
509 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
510 _ => None,
511 }
512 };
513 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
514 if let Expr::Value(v) = e {
515 Some(v.value.clone())
516 } else {
517 None
518 }
519 };
520 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
521 return Some((c, l));
522 }
523 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
524 return Some((c, l));
525 }
526 None
527}
528
529struct HeapEntry {
542 key: Value,
543 rowid: i64,
544 asc: bool,
545}
546
547impl PartialEq for HeapEntry {
548 fn eq(&self, other: &Self) -> bool {
549 self.cmp(other) == Ordering::Equal
550 }
551}
552
553impl Eq for HeapEntry {}
554
555impl PartialOrd for HeapEntry {
556 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
557 Some(self.cmp(other))
558 }
559}
560
561impl Ord for HeapEntry {
562 fn cmp(&self, other: &Self) -> Ordering {
563 let raw = compare_values(Some(&self.key), Some(&other.key));
564 if self.asc { raw } else { raw.reverse() }
565 }
566}
567
568fn select_topk(
577 matching: &[i64],
578 table: &Table,
579 order: &OrderByClause,
580 k: usize,
581) -> Result<Vec<i64>> {
582 use std::collections::BinaryHeap;
583
584 if k == 0 || matching.is_empty() {
585 return Ok(Vec::new());
586 }
587
588 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
589
590 for &rowid in matching {
591 let key = eval_expr(&order.expr, table, rowid)?;
592 let entry = HeapEntry {
593 key,
594 rowid,
595 asc: order.ascending,
596 };
597
598 if heap.len() < k {
599 heap.push(entry);
600 } else {
601 if entry < *heap.peek().unwrap() {
605 heap.pop();
606 heap.push(entry);
607 }
608 }
609 }
610
611 Ok(heap
616 .into_sorted_vec()
617 .into_iter()
618 .map(|e| e.rowid)
619 .collect())
620}
621
622fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
623 let mut keys: Vec<(i64, Result<Value>)> = rowids
631 .iter()
632 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
633 .collect();
634
635 for (_, k) in &keys {
639 if let Err(e) = k {
640 return Err(SQLRiteError::General(format!(
641 "ORDER BY expression failed: {e}"
642 )));
643 }
644 }
645
646 keys.sort_by(|(_, ka), (_, kb)| {
647 let va = ka.as_ref().unwrap();
650 let vb = kb.as_ref().unwrap();
651 let ord = compare_values(Some(va), Some(vb));
652 if order.ascending { ord } else { ord.reverse() }
653 });
654
655 for (i, (rowid, _)) in keys.into_iter().enumerate() {
657 rowids[i] = rowid;
658 }
659 Ok(())
660}
661
662fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
663 match (a, b) {
664 (None, None) => Ordering::Equal,
665 (None, _) => Ordering::Less,
666 (_, None) => Ordering::Greater,
667 (Some(a), Some(b)) => match (a, b) {
668 (Value::Null, Value::Null) => Ordering::Equal,
669 (Value::Null, _) => Ordering::Less,
670 (_, Value::Null) => Ordering::Greater,
671 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
672 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
673 (Value::Integer(x), Value::Real(y)) => {
674 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
675 }
676 (Value::Real(x), Value::Integer(y)) => {
677 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
678 }
679 (Value::Text(x), Value::Text(y)) => x.cmp(y),
680 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
681 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
683 },
684 }
685}
686
687pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
689 let v = eval_expr(expr, table, rowid)?;
690 match v {
691 Value::Bool(b) => Ok(b),
692 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
694 other => Err(SQLRiteError::Internal(format!(
695 "WHERE clause must evaluate to boolean, got {}",
696 other.to_display_string()
697 ))),
698 }
699}
700
701fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
702 match expr {
703 Expr::Nested(inner) => eval_expr(inner, table, rowid),
704
705 Expr::Identifier(ident) => {
706 if ident.quote_style == Some('[') {
716 let raw = format!("[{}]", ident.value);
717 let v = parse_vector_literal(&raw)?;
718 return Ok(Value::Vector(v));
719 }
720 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
721 }
722
723 Expr::CompoundIdentifier(parts) => {
724 let col = parts
726 .last()
727 .map(|i| i.value.as_str())
728 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
729 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
730 }
731
732 Expr::Value(v) => convert_literal(&v.value),
733
734 Expr::UnaryOp { op, expr } => {
735 let inner = eval_expr(expr, table, rowid)?;
736 match op {
737 UnaryOperator::Not => match inner {
738 Value::Bool(b) => Ok(Value::Bool(!b)),
739 Value::Null => Ok(Value::Null),
740 other => Err(SQLRiteError::Internal(format!(
741 "NOT applied to non-boolean value: {}",
742 other.to_display_string()
743 ))),
744 },
745 UnaryOperator::Minus => match inner {
746 Value::Integer(i) => Ok(Value::Integer(-i)),
747 Value::Real(f) => Ok(Value::Real(-f)),
748 Value::Null => Ok(Value::Null),
749 other => Err(SQLRiteError::Internal(format!(
750 "unary minus on non-numeric value: {}",
751 other.to_display_string()
752 ))),
753 },
754 UnaryOperator::Plus => Ok(inner),
755 other => Err(SQLRiteError::NotImplemented(format!(
756 "unary operator {other:?} is not supported"
757 ))),
758 }
759 }
760
761 Expr::BinaryOp { left, op, right } => match op {
762 BinaryOperator::And => {
763 let l = eval_expr(left, table, rowid)?;
764 let r = eval_expr(right, table, rowid)?;
765 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
766 }
767 BinaryOperator::Or => {
768 let l = eval_expr(left, table, rowid)?;
769 let r = eval_expr(right, table, rowid)?;
770 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
771 }
772 cmp @ (BinaryOperator::Eq
773 | BinaryOperator::NotEq
774 | BinaryOperator::Lt
775 | BinaryOperator::LtEq
776 | BinaryOperator::Gt
777 | BinaryOperator::GtEq) => {
778 let l = eval_expr(left, table, rowid)?;
779 let r = eval_expr(right, table, rowid)?;
780 if matches!(l, Value::Null) || matches!(r, Value::Null) {
782 return Ok(Value::Bool(false));
783 }
784 let ord = compare_values(Some(&l), Some(&r));
785 let result = match cmp {
786 BinaryOperator::Eq => ord == Ordering::Equal,
787 BinaryOperator::NotEq => ord != Ordering::Equal,
788 BinaryOperator::Lt => ord == Ordering::Less,
789 BinaryOperator::LtEq => ord != Ordering::Greater,
790 BinaryOperator::Gt => ord == Ordering::Greater,
791 BinaryOperator::GtEq => ord != Ordering::Less,
792 _ => unreachable!(),
793 };
794 Ok(Value::Bool(result))
795 }
796 arith @ (BinaryOperator::Plus
797 | BinaryOperator::Minus
798 | BinaryOperator::Multiply
799 | BinaryOperator::Divide
800 | BinaryOperator::Modulo) => {
801 let l = eval_expr(left, table, rowid)?;
802 let r = eval_expr(right, table, rowid)?;
803 eval_arith(arith, &l, &r)
804 }
805 BinaryOperator::StringConcat => {
806 let l = eval_expr(left, table, rowid)?;
807 let r = eval_expr(right, table, rowid)?;
808 if matches!(l, Value::Null) || matches!(r, Value::Null) {
809 return Ok(Value::Null);
810 }
811 Ok(Value::Text(format!(
812 "{}{}",
813 l.to_display_string(),
814 r.to_display_string()
815 )))
816 }
817 other => Err(SQLRiteError::NotImplemented(format!(
818 "binary operator {other:?} is not supported yet"
819 ))),
820 },
821
822 Expr::Function(func) => eval_function(func, table, rowid),
833
834 other => Err(SQLRiteError::NotImplemented(format!(
835 "unsupported expression in WHERE/projection: {other:?}"
836 ))),
837 }
838}
839
840fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
845 let name = match func.name.0.as_slice() {
848 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
849 _ => {
850 return Err(SQLRiteError::NotImplemented(format!(
851 "qualified function names not supported: {:?}",
852 func.name
853 )));
854 }
855 };
856
857 match name.as_str() {
858 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
859 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
860 let dist = match name.as_str() {
861 "vec_distance_l2" => vec_distance_l2(&a, &b),
862 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
863 "vec_distance_dot" => vec_distance_dot(&a, &b),
864 _ => unreachable!(),
865 };
866 Ok(Value::Real(dist as f64))
872 }
873 other => Err(SQLRiteError::NotImplemented(format!(
874 "unknown function: {other}(...)"
875 ))),
876 }
877}
878
879fn extract_two_vector_args(
883 fn_name: &str,
884 args: &FunctionArguments,
885 table: &Table,
886 rowid: i64,
887) -> Result<(Vec<f32>, Vec<f32>)> {
888 let arg_list = match args {
889 FunctionArguments::List(l) => &l.args,
890 _ => {
891 return Err(SQLRiteError::General(format!(
892 "{fn_name}() expects exactly two vector arguments"
893 )));
894 }
895 };
896 if arg_list.len() != 2 {
897 return Err(SQLRiteError::General(format!(
898 "{fn_name}() expects exactly 2 arguments, got {}",
899 arg_list.len()
900 )));
901 }
902 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
903 for (i, arg) in arg_list.iter().enumerate() {
904 let expr = match arg {
905 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
906 other => {
907 return Err(SQLRiteError::NotImplemented(format!(
908 "{fn_name}() argument {i} has unsupported shape: {other:?}"
909 )));
910 }
911 };
912 let val = eval_expr(expr, table, rowid)?;
913 match val {
914 Value::Vector(v) => out.push(v),
915 other => {
916 return Err(SQLRiteError::General(format!(
917 "{fn_name}() argument {i} is not a vector: got {}",
918 other.to_display_string()
919 )));
920 }
921 }
922 }
923 let b = out.pop().unwrap();
924 let a = out.pop().unwrap();
925 if a.len() != b.len() {
926 return Err(SQLRiteError::General(format!(
927 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
928 a.len(),
929 b.len()
930 )));
931 }
932 Ok((a, b))
933}
934
/// Euclidean (L2) distance between two vectors.
///
/// Callers must pass slices of equal length; this is checked only in debug
/// builds. Elements of `b` beyond `a.len()` are ignored, and the function
/// panics if `b` is shorter than `a`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
946
947pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
957 debug_assert_eq!(a.len(), b.len());
958 let mut dot = 0.0f32;
959 let mut norm_a_sq = 0.0f32;
960 let mut norm_b_sq = 0.0f32;
961 for i in 0..a.len() {
962 dot += a[i] * b[i];
963 norm_a_sq += a[i] * a[i];
964 norm_b_sq += b[i] * b[i];
965 }
966 let denom = (norm_a_sq * norm_b_sq).sqrt();
967 if denom == 0.0 {
968 return Err(SQLRiteError::General(
969 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
970 ));
971 }
972 Ok(1.0 - dot / denom)
973}
974
/// Negated dot product of two vectors.
///
/// The sign is flipped so that, like the other distance functions, a smaller
/// result means "more similar". Callers must pass slices of equal length;
/// this is checked only in debug builds.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
986
987fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
990 if matches!(l, Value::Null) || matches!(r, Value::Null) {
991 return Ok(Value::Null);
992 }
993 match (l, r) {
994 (Value::Integer(a), Value::Integer(b)) => match op {
995 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
996 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
997 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
998 BinaryOperator::Divide => {
999 if *b == 0 {
1000 Err(SQLRiteError::General("division by zero".to_string()))
1001 } else {
1002 Ok(Value::Integer(a / b))
1003 }
1004 }
1005 BinaryOperator::Modulo => {
1006 if *b == 0 {
1007 Err(SQLRiteError::General("modulo by zero".to_string()))
1008 } else {
1009 Ok(Value::Integer(a % b))
1010 }
1011 }
1012 _ => unreachable!(),
1013 },
1014 (a, b) => {
1016 let af = as_number(a)?;
1017 let bf = as_number(b)?;
1018 match op {
1019 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1020 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1021 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1022 BinaryOperator::Divide => {
1023 if bf == 0.0 {
1024 Err(SQLRiteError::General("division by zero".to_string()))
1025 } else {
1026 Ok(Value::Real(af / bf))
1027 }
1028 }
1029 BinaryOperator::Modulo => {
1030 if bf == 0.0 {
1031 Err(SQLRiteError::General("modulo by zero".to_string()))
1032 } else {
1033 Ok(Value::Real(af % bf))
1034 }
1035 }
1036 _ => unreachable!(),
1037 }
1038 }
1039 }
1040}
1041
1042fn as_number(v: &Value) -> Result<f64> {
1043 match v {
1044 Value::Integer(i) => Ok(*i as f64),
1045 Value::Real(f) => Ok(*f),
1046 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1047 other => Err(SQLRiteError::General(format!(
1048 "arithmetic on non-numeric value '{}'",
1049 other.to_display_string()
1050 ))),
1051 }
1052}
1053
1054fn as_bool(v: &Value) -> Result<bool> {
1055 match v {
1056 Value::Bool(b) => Ok(*b),
1057 Value::Null => Ok(false),
1058 Value::Integer(i) => Ok(*i != 0),
1059 other => Err(SQLRiteError::Internal(format!(
1060 "expected boolean, got {}",
1061 other.to_display_string()
1062 ))),
1063 }
1064}
1065
1066fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1067 use sqlparser::ast::Value as AstValue;
1068 match v {
1069 AstValue::Number(n, _) => {
1070 if let Ok(i) = n.parse::<i64>() {
1071 Ok(Value::Integer(i))
1072 } else if let Ok(f) = n.parse::<f64>() {
1073 Ok(Value::Real(f))
1074 } else {
1075 Err(SQLRiteError::Internal(format!(
1076 "could not parse numeric literal '{n}'"
1077 )))
1078 }
1079 }
1080 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1081 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1082 AstValue::Null => Ok(Value::Null),
1083 other => Err(SQLRiteError::NotImplemented(format!(
1084 "unsupported literal value: {other:?}"
1085 ))),
1086 }
1087}
1088
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // ---- helpers -------------------------------------------------------

    /// True when `a` and `b` differ by less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    /// Builds an in-memory `docs(id, score)` table with `n` rows whose
    /// scores come from a fixed multiplicative hash of the row index, so
    /// the data is scrambled but reproducible.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Parses a single SELECT statement into a `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    // ---- vec_distance_l2 -----------------------------------------------

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        assert!(approx_eq(
            vec_distance_l2(&x_axis, &y_axis),
            2.0_f32.sqrt(),
            1e-6
        ));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // 3-4-5 triangle.
        let origin = vec![0.0, 0.0, 0.0];
        let point = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&origin, &point), 5.0, 1e-6));
    }

    // ---- vec_distance_cosine -------------------------------------------

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        assert!(approx_eq(
            vec_distance_cosine(&x_axis, &y_axis).unwrap(),
            1.0,
            1e-6
        ));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let forward = vec![1.0, 0.0, 0.0];
        let backward = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(
            vec_distance_cosine(&forward, &backward).unwrap(),
            2.0,
            1e-6
        ));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let zero = vec![0.0, 0.0];
        let unit = vec![1.0, 0.0];
        let err = vec_distance_cosine(&zero, &unit).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    // ---- vec_distance_dot ----------------------------------------------

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32, reported negated.
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&lhs, &rhs), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // For unit-length inputs, -dot == (1 - cos) - 1.
        let a = vec![0.6f32, 0.8];
        let b = vec![0.8f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ---- bounded-heap top-k --------------------------------------------

    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut expected = all_rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, expected,
            "top-k via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut expected = all_rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, expected,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);

        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|rowid| match table.get_value("score", *rowid) {
                Some(Value::Real(score)) => Some(score),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        // The LIMIT in the SQL is irrelevant here; `k` is passed directly.
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();

        let embeddings = [
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ];
        for v in &embeddings {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }

        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    // ---- benchmark (opt-in) --------------------------------------------

    /// Timing comparison between the bounded heap and a full sort. Ignored
    /// by default because it asserts on wall-clock ratios.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let heap_start = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = heap_start.elapsed();

        let sort_start = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = sort_start.elapsed();

        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:    {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:   {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}