1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9 FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
10 TableWithJoins, UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
17use crate::sql::hnsw::{DistanceMetric, HnswIndex};
18use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
19
/// Materialized output of a `SELECT`: projected column names plus row values.
///
/// Each inner vector of `rows` holds one `Value` per entry of `columns`, in the
/// same order (`execute_select_rows` fills cells it cannot read with `Value::Null`).
pub struct SelectResult {
    /// Projected column names, in output order.
    pub columns: Vec<String>,
    /// Result rows, each aligned position-by-position with `columns`.
    pub rows: Vec<Vec<Value>>,
}
32
33pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
37 let table = db
38 .get_table(query.table_name.clone())
39 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
40
41 let projected_cols: Vec<String> = match &query.projection {
43 Projection::All => table.column_names(),
44 Projection::Columns(cols) => {
45 for c in cols {
46 if !table.contains_column(c.to_string()) {
47 return Err(SQLRiteError::Internal(format!(
48 "Column '{c}' does not exist on table '{}'",
49 query.table_name
50 )));
51 }
52 }
53 cols.clone()
54 }
55 };
56
57 let matching = match select_rowids(table, query.selection.as_ref())? {
61 RowidSource::IndexProbe(rowids) => rowids,
62 RowidSource::FullScan => {
63 let mut out = Vec::new();
64 for rowid in table.rowids() {
65 if let Some(expr) = &query.selection {
66 if !eval_predicate(expr, table, rowid)? {
67 continue;
68 }
69 }
70 out.push(rowid);
71 }
72 out
73 }
74 };
75 let mut matching = matching;
76
77 match (&query.order_by, query.limit) {
106 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
107 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
108 }
109 (Some(order), Some(k)) if k < matching.len() => {
110 matching = select_topk(&matching, table, order, k)?;
111 }
112 (Some(order), _) => {
113 sort_rowids(&mut matching, table, order)?;
114 if let Some(k) = query.limit {
115 matching.truncate(k);
116 }
117 }
118 (None, Some(k)) => {
119 matching.truncate(k);
120 }
121 (None, None) => {}
122 }
123
124 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
128 for rowid in &matching {
129 let row: Vec<Value> = projected_cols
130 .iter()
131 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
132 .collect();
133 rows.push(row);
134 }
135
136 Ok(SelectResult {
137 columns: projected_cols,
138 rows,
139 })
140}
141
142pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
147 let result = execute_select_rows(query, db)?;
148 let row_count = result.rows.len();
149
150 let mut print_table = PrintTable::new();
151 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
152 print_table.add_row(PrintRow::new(header_cells));
153
154 for row in &result.rows {
155 let cells: Vec<PrintCell> = row
156 .iter()
157 .map(|v| PrintCell::new(&v.to_display_string()))
158 .collect();
159 print_table.add_row(PrintRow::new(cells));
160 }
161
162 Ok((print_table.to_string(), row_count))
163}
164
165pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
167 let Statement::Delete(Delete {
168 from, selection, ..
169 }) = stmt
170 else {
171 return Err(SQLRiteError::Internal(
172 "execute_delete called on a non-DELETE statement".to_string(),
173 ));
174 };
175
176 let tables = match from {
177 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
178 };
179 let table_name = extract_single_table_name(tables)?;
180
181 {
188 let table = db.get_table(table_name.clone()).map_err(|_| {
189 SQLRiteError::General(format!("DELETE references unknown table '{table_name}'"))
190 })?;
191 if !table.hnsw_indexes.is_empty() {
192 let names: Vec<&str> = table.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
193 return Err(SQLRiteError::NotImplemented(format!(
194 "DELETE on tables with HNSW indexes is not supported yet \
195 (Phase 7d.3 follow-up). DROP the index first, then DELETE, then re-CREATE. \
196 Table '{table_name}' currently has: {names:?}"
197 )));
198 }
199 }
200
201 let matching: Vec<i64> = {
203 let table = db
204 .get_table(table_name.clone())
205 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
206 match select_rowids(table, selection.as_ref())? {
207 RowidSource::IndexProbe(rowids) => rowids,
208 RowidSource::FullScan => {
209 let mut out = Vec::new();
210 for rowid in table.rowids() {
211 if let Some(expr) = selection {
212 if !eval_predicate(expr, table, rowid)? {
213 continue;
214 }
215 }
216 out.push(rowid);
217 }
218 out
219 }
220 }
221 };
222
223 let table = db.get_table_mut(table_name)?;
224 for rowid in &matching {
225 table.delete_row(*rowid);
226 }
227 Ok(matching.len())
228}
229
230pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
232 let Statement::Update(Update {
233 table,
234 assignments,
235 from,
236 selection,
237 ..
238 }) = stmt
239 else {
240 return Err(SQLRiteError::Internal(
241 "execute_update called on a non-UPDATE statement".to_string(),
242 ));
243 };
244
245 if from.is_some() {
246 return Err(SQLRiteError::NotImplemented(
247 "UPDATE ... FROM is not supported yet".to_string(),
248 ));
249 }
250
251 let table_name = extract_table_name(table)?;
252
253 {
259 let tbl = db.get_table(table_name.clone()).map_err(|_| {
260 SQLRiteError::General(format!("UPDATE references unknown table '{table_name}'"))
261 })?;
262 if !tbl.hnsw_indexes.is_empty() {
263 let names: Vec<&str> = tbl.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
264 return Err(SQLRiteError::NotImplemented(format!(
265 "UPDATE on tables with HNSW indexes is not supported yet \
266 (Phase 7d.3 follow-up). DROP the index first if you need to mutate. \
267 Table '{table_name}' currently has: {names:?}"
268 )));
269 }
270 }
271
272 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
274 {
275 let tbl = db
276 .get_table(table_name.clone())
277 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
278 for a in assignments {
279 let col = match &a.target {
280 AssignmentTarget::ColumnName(name) => name
281 .0
282 .last()
283 .map(|p| p.to_string())
284 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
285 AssignmentTarget::Tuple(_) => {
286 return Err(SQLRiteError::NotImplemented(
287 "tuple assignment targets are not supported".to_string(),
288 ));
289 }
290 };
291 if !tbl.contains_column(col.clone()) {
292 return Err(SQLRiteError::Internal(format!(
293 "UPDATE references unknown column '{col}'"
294 )));
295 }
296 parsed_assignments.push((col, a.value.clone()));
297 }
298 }
299
300 let work: Vec<(i64, Vec<(String, Value)>)> = {
304 let tbl = db.get_table(table_name.clone())?;
305 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
306 RowidSource::IndexProbe(rowids) => rowids,
307 RowidSource::FullScan => {
308 let mut out = Vec::new();
309 for rowid in tbl.rowids() {
310 if let Some(expr) = selection {
311 if !eval_predicate(expr, tbl, rowid)? {
312 continue;
313 }
314 }
315 out.push(rowid);
316 }
317 out
318 }
319 };
320 let mut rows_to_update = Vec::new();
321 for rowid in matched_rowids {
322 let mut values = Vec::with_capacity(parsed_assignments.len());
323 for (col, expr) in &parsed_assignments {
324 let v = eval_expr(expr, tbl, rowid)?;
327 values.push((col.clone(), v));
328 }
329 rows_to_update.push((rowid, values));
330 }
331 rows_to_update
332 };
333
334 let tbl = db.get_table_mut(table_name)?;
335 for (rowid, values) in &work {
336 for (col, v) in values {
337 tbl.set_value(col, *rowid, v.clone())?;
338 }
339 }
340 Ok(work.len())
341}
342
343pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
355 let Statement::CreateIndex(CreateIndex {
356 name,
357 table_name,
358 columns,
359 using,
360 unique,
361 if_not_exists,
362 predicate,
363 ..
364 }) = stmt
365 else {
366 return Err(SQLRiteError::Internal(
367 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
368 ));
369 };
370
371 if predicate.is_some() {
372 return Err(SQLRiteError::NotImplemented(
373 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
374 ));
375 }
376
377 if columns.len() != 1 {
378 return Err(SQLRiteError::NotImplemented(format!(
379 "multi-column indexes are not supported yet ({} columns given)",
380 columns.len()
381 )));
382 }
383
384 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
385 SQLRiteError::NotImplemented(
386 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
387 )
388 })?;
389
390 let method = match using {
396 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
397 IndexMethod::Hnsw
398 }
399 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
400 IndexMethod::Btree
401 }
402 Some(other) => {
403 return Err(SQLRiteError::NotImplemented(format!(
404 "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
405 )));
406 }
407 None => IndexMethod::Btree,
408 };
409
410 let table_name_str = table_name.to_string();
411 let column_name = match &columns[0].column.expr {
412 Expr::Identifier(ident) => ident.value.clone(),
413 Expr::CompoundIdentifier(parts) => parts
414 .last()
415 .map(|p| p.value.clone())
416 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
417 other => {
418 return Err(SQLRiteError::NotImplemented(format!(
419 "CREATE INDEX only supports simple column references, got {other:?}"
420 )));
421 }
422 };
423
424 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
429 let table = db.get_table(table_name_str.clone()).map_err(|_| {
430 SQLRiteError::General(format!(
431 "CREATE INDEX references unknown table '{table_name_str}'"
432 ))
433 })?;
434 if !table.contains_column(column_name.clone()) {
435 return Err(SQLRiteError::General(format!(
436 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
437 )));
438 }
439 let col = table
440 .columns
441 .iter()
442 .find(|c| c.column_name == column_name)
443 .expect("we just verified the column exists");
444
445 if table.index_by_name(&index_name).is_some()
448 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
449 {
450 if *if_not_exists {
451 return Ok(index_name);
452 }
453 return Err(SQLRiteError::General(format!(
454 "index '{index_name}' already exists"
455 )));
456 }
457 let datatype = clone_datatype(&col.datatype);
458
459 let mut pairs = Vec::new();
460 for rowid in table.rowids() {
461 if let Some(v) = table.get_value(&column_name, rowid) {
462 pairs.push((rowid, v));
463 }
464 }
465 (datatype, pairs)
466 };
467
468 match method {
469 IndexMethod::Btree => create_btree_index(
470 db,
471 &table_name_str,
472 &index_name,
473 &column_name,
474 &datatype,
475 *unique,
476 &existing_rowids_and_values,
477 ),
478 IndexMethod::Hnsw => create_hnsw_index(
479 db,
480 &table_name_str,
481 &index_name,
482 &column_name,
483 &datatype,
484 *unique,
485 &existing_rowids_and_values,
486 ),
487 }
488}
489
/// Index structure requested by `CREATE INDEX`, resolved from its `USING` clause.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary index built by `create_btree_index` (the default with no `USING`).
    Btree,
    /// HNSW index over a `VECTOR` column, built by `create_hnsw_index`.
    Hnsw,
}
498
499fn create_btree_index(
501 db: &mut Database,
502 table_name: &str,
503 index_name: &str,
504 column_name: &str,
505 datatype: &DataType,
506 unique: bool,
507 existing: &[(i64, Value)],
508) -> Result<String> {
509 let mut idx = SecondaryIndex::new(
510 index_name.to_string(),
511 table_name.to_string(),
512 column_name.to_string(),
513 datatype,
514 unique,
515 IndexOrigin::Explicit,
516 )?;
517
518 for (rowid, v) in existing {
522 if unique && idx.would_violate_unique(v) {
523 return Err(SQLRiteError::General(format!(
524 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
525 already contains the duplicate value {}",
526 v.to_display_string()
527 )));
528 }
529 idx.insert(v, *rowid)?;
530 }
531
532 let table_mut = db.get_table_mut(table_name.to_string())?;
533 table_mut.secondary_indexes.push(idx);
534 Ok(index_name.to_string())
535}
536
537fn create_hnsw_index(
539 db: &mut Database,
540 table_name: &str,
541 index_name: &str,
542 column_name: &str,
543 datatype: &DataType,
544 unique: bool,
545 existing: &[(i64, Value)],
546) -> Result<String> {
547 let dim = match datatype {
550 DataType::Vector(d) => *d,
551 other => {
552 return Err(SQLRiteError::General(format!(
553 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
554 )));
555 }
556 };
557
558 if unique {
559 return Err(SQLRiteError::General(
560 "UNIQUE has no meaning for HNSW indexes".to_string(),
561 ));
562 }
563
564 let seed = hash_str_to_seed(index_name);
572 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
573
574 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
578 std::collections::HashMap::with_capacity(existing.len());
579 for (rowid, v) in existing {
580 match v {
581 Value::Vector(vec) => {
582 if vec.len() != dim {
583 return Err(SQLRiteError::Internal(format!(
584 "row {rowid} stores a {}-dim vector in column '{column_name}' \
585 declared as VECTOR({dim}) — schema invariant violated",
586 vec.len()
587 )));
588 }
589 vec_map.insert(*rowid, vec.clone());
590 }
591 _ => continue,
595 }
596 }
597
598 for (rowid, _) in existing {
599 if let Some(v) = vec_map.get(rowid) {
600 let v_clone = v.clone();
601 idx.insert(*rowid, &v_clone, |id| {
602 vec_map.get(&id).cloned().unwrap_or_default()
603 });
604 }
605 }
606
607 let table_mut = db.get_table_mut(table_name.to_string())?;
608 table_mut.hnsw_indexes.push(HnswIndexEntry {
609 name: index_name.to_string(),
610 column_name: column_name.to_string(),
611 index: idx,
612 });
613 Ok(index_name.to_string())
614}
615
/// Hashes `s` with 64-bit FNV-1a, giving a stable, platform-independent seed.
///
/// Used to seed HNSW indexes from their name, so the same name always yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
627
628fn clone_datatype(dt: &DataType) -> DataType {
631 match dt {
632 DataType::Integer => DataType::Integer,
633 DataType::Text => DataType::Text,
634 DataType::Real => DataType::Real,
635 DataType::Bool => DataType::Bool,
636 DataType::Vector(dim) => DataType::Vector(*dim),
637 DataType::None => DataType::None,
638 DataType::Invalid => DataType::Invalid,
639 }
640}
641
642fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
643 if tables.len() != 1 {
644 return Err(SQLRiteError::NotImplemented(
645 "multi-table DELETE is not supported yet".to_string(),
646 ));
647 }
648 extract_table_name(&tables[0])
649}
650
651fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
652 if !twj.joins.is_empty() {
653 return Err(SQLRiteError::NotImplemented(
654 "JOIN is not supported yet".to_string(),
655 ));
656 }
657 match &twj.relation {
658 TableFactor::Table { name, .. } => Ok(name.to_string()),
659 _ => Err(SQLRiteError::NotImplemented(
660 "only plain table references are supported".to_string(),
661 )),
662 }
663}
664
/// How a statement should obtain the rowids its WHERE clause selects (see `select_rowids`).
enum RowidSource {
    /// The WHERE clause was answered by a secondary-index probe. These rowids
    /// (sorted ascending) are used by callers as-is, without re-filtering.
    IndexProbe(Vec<i64>),
    /// No usable index: the caller scans every row and evaluates the WHERE
    /// clause (if any) itself.
    FullScan,
}
675
676fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
681 let Some(expr) = selection else {
682 return Ok(RowidSource::FullScan);
683 };
684 let Some((col, literal)) = try_extract_equality(expr) else {
685 return Ok(RowidSource::FullScan);
686 };
687 let Some(idx) = table.index_for_column(&col) else {
688 return Ok(RowidSource::FullScan);
689 };
690
691 let literal_value = match convert_literal(&literal) {
695 Ok(v) => v,
696 Err(_) => return Ok(RowidSource::FullScan),
697 };
698
699 let mut rowids = idx.lookup(&literal_value);
703 rowids.sort_unstable();
704 Ok(RowidSource::IndexProbe(rowids))
705}
706
707fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
711 let peeled = match expr {
713 Expr::Nested(inner) => inner.as_ref(),
714 other => other,
715 };
716 let Expr::BinaryOp { left, op, right } = peeled else {
717 return None;
718 };
719 if !matches!(op, BinaryOperator::Eq) {
720 return None;
721 }
722 let col_from = |e: &Expr| -> Option<String> {
723 match e {
724 Expr::Identifier(ident) => Some(ident.value.clone()),
725 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
726 _ => None,
727 }
728 };
729 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
730 if let Expr::Value(v) = e {
731 Some(v.value.clone())
732 } else {
733 None
734 }
735 };
736 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
737 return Some((c, l));
738 }
739 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
740 return Some((c, l));
741 }
742 None
743}
744
745fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
767 if k == 0 {
768 return None;
769 }
770
771 let func = match order_expr {
773 Expr::Function(f) => f,
774 _ => return None,
775 };
776 let fname = match func.name.0.as_slice() {
777 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
778 _ => return None,
779 };
780 if fname != "vec_distance_l2" {
781 return None;
782 }
783
784 let arg_list = match &func.args {
786 FunctionArguments::List(l) => &l.args,
787 _ => return None,
788 };
789 if arg_list.len() != 2 {
790 return None;
791 }
792 let exprs: Vec<&Expr> = arg_list
793 .iter()
794 .filter_map(|a| match a {
795 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
796 _ => None,
797 })
798 .collect();
799 if exprs.len() != 2 {
800 return None;
801 }
802
803 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
808 Some(v) => v,
809 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
810 Some(v) => v,
811 None => return None,
812 },
813 };
814
815 let entry = table
817 .hnsw_indexes
818 .iter()
819 .find(|e| e.column_name == col_name)?;
820
821 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
827 Some(c) => match &c.datatype {
828 DataType::Vector(d) => *d,
829 _ => return None,
830 },
831 None => return None,
832 };
833 if query_vec.len() != declared_dim {
834 return None;
835 }
836
837 let column_for_closure = col_name.clone();
841 let table_ref = table;
842 let result = entry.index.search(&query_vec, k, |id| {
843 match table_ref.get_value(&column_for_closure, id) {
844 Some(Value::Vector(v)) => v,
845 _ => Vec::new(),
846 }
847 });
848 Some(result)
849}
850
851fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
856 let col_name = match a {
857 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
858 _ => return None,
859 };
860 let lit_str = match b {
861 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
862 format!("[{}]", ident.value)
863 }
864 _ => return None,
865 };
866 let v = parse_vector_literal(&lit_str).ok()?;
867 Some((col_name, v))
868}
869
870struct HeapEntry {
883 key: Value,
884 rowid: i64,
885 asc: bool,
886}
887
888impl PartialEq for HeapEntry {
889 fn eq(&self, other: &Self) -> bool {
890 self.cmp(other) == Ordering::Equal
891 }
892}
893
894impl Eq for HeapEntry {}
895
896impl PartialOrd for HeapEntry {
897 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
898 Some(self.cmp(other))
899 }
900}
901
902impl Ord for HeapEntry {
903 fn cmp(&self, other: &Self) -> Ordering {
904 let raw = compare_values(Some(&self.key), Some(&other.key));
905 if self.asc { raw } else { raw.reverse() }
906 }
907}
908
909fn select_topk(
918 matching: &[i64],
919 table: &Table,
920 order: &OrderByClause,
921 k: usize,
922) -> Result<Vec<i64>> {
923 use std::collections::BinaryHeap;
924
925 if k == 0 || matching.is_empty() {
926 return Ok(Vec::new());
927 }
928
929 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
930
931 for &rowid in matching {
932 let key = eval_expr(&order.expr, table, rowid)?;
933 let entry = HeapEntry {
934 key,
935 rowid,
936 asc: order.ascending,
937 };
938
939 if heap.len() < k {
940 heap.push(entry);
941 } else {
942 if entry < *heap.peek().unwrap() {
946 heap.pop();
947 heap.push(entry);
948 }
949 }
950 }
951
952 Ok(heap
957 .into_sorted_vec()
958 .into_iter()
959 .map(|e| e.rowid)
960 .collect())
961}
962
963fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
964 let mut keys: Vec<(i64, Result<Value>)> = rowids
972 .iter()
973 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
974 .collect();
975
976 for (_, k) in &keys {
980 if let Err(e) = k {
981 return Err(SQLRiteError::General(format!(
982 "ORDER BY expression failed: {e}"
983 )));
984 }
985 }
986
987 keys.sort_by(|(_, ka), (_, kb)| {
988 let va = ka.as_ref().unwrap();
991 let vb = kb.as_ref().unwrap();
992 let ord = compare_values(Some(va), Some(vb));
993 if order.ascending { ord } else { ord.reverse() }
994 });
995
996 for (i, (rowid, _)) in keys.into_iter().enumerate() {
998 rowids[i] = rowid;
999 }
1000 Ok(())
1001}
1002
1003fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1004 match (a, b) {
1005 (None, None) => Ordering::Equal,
1006 (None, _) => Ordering::Less,
1007 (_, None) => Ordering::Greater,
1008 (Some(a), Some(b)) => match (a, b) {
1009 (Value::Null, Value::Null) => Ordering::Equal,
1010 (Value::Null, _) => Ordering::Less,
1011 (_, Value::Null) => Ordering::Greater,
1012 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1013 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1014 (Value::Integer(x), Value::Real(y)) => {
1015 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1016 }
1017 (Value::Real(x), Value::Integer(y)) => {
1018 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1019 }
1020 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1021 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1022 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1024 },
1025 }
1026}
1027
1028pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1030 let v = eval_expr(expr, table, rowid)?;
1031 match v {
1032 Value::Bool(b) => Ok(b),
1033 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
1035 other => Err(SQLRiteError::Internal(format!(
1036 "WHERE clause must evaluate to boolean, got {}",
1037 other.to_display_string()
1038 ))),
1039 }
1040}
1041
1042fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1043 match expr {
1044 Expr::Nested(inner) => eval_expr(inner, table, rowid),
1045
1046 Expr::Identifier(ident) => {
1047 if ident.quote_style == Some('[') {
1057 let raw = format!("[{}]", ident.value);
1058 let v = parse_vector_literal(&raw)?;
1059 return Ok(Value::Vector(v));
1060 }
1061 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1062 }
1063
1064 Expr::CompoundIdentifier(parts) => {
1065 let col = parts
1067 .last()
1068 .map(|i| i.value.as_str())
1069 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1070 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1071 }
1072
1073 Expr::Value(v) => convert_literal(&v.value),
1074
1075 Expr::UnaryOp { op, expr } => {
1076 let inner = eval_expr(expr, table, rowid)?;
1077 match op {
1078 UnaryOperator::Not => match inner {
1079 Value::Bool(b) => Ok(Value::Bool(!b)),
1080 Value::Null => Ok(Value::Null),
1081 other => Err(SQLRiteError::Internal(format!(
1082 "NOT applied to non-boolean value: {}",
1083 other.to_display_string()
1084 ))),
1085 },
1086 UnaryOperator::Minus => match inner {
1087 Value::Integer(i) => Ok(Value::Integer(-i)),
1088 Value::Real(f) => Ok(Value::Real(-f)),
1089 Value::Null => Ok(Value::Null),
1090 other => Err(SQLRiteError::Internal(format!(
1091 "unary minus on non-numeric value: {}",
1092 other.to_display_string()
1093 ))),
1094 },
1095 UnaryOperator::Plus => Ok(inner),
1096 other => Err(SQLRiteError::NotImplemented(format!(
1097 "unary operator {other:?} is not supported"
1098 ))),
1099 }
1100 }
1101
1102 Expr::BinaryOp { left, op, right } => match op {
1103 BinaryOperator::And => {
1104 let l = eval_expr(left, table, rowid)?;
1105 let r = eval_expr(right, table, rowid)?;
1106 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1107 }
1108 BinaryOperator::Or => {
1109 let l = eval_expr(left, table, rowid)?;
1110 let r = eval_expr(right, table, rowid)?;
1111 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1112 }
1113 cmp @ (BinaryOperator::Eq
1114 | BinaryOperator::NotEq
1115 | BinaryOperator::Lt
1116 | BinaryOperator::LtEq
1117 | BinaryOperator::Gt
1118 | BinaryOperator::GtEq) => {
1119 let l = eval_expr(left, table, rowid)?;
1120 let r = eval_expr(right, table, rowid)?;
1121 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1123 return Ok(Value::Bool(false));
1124 }
1125 let ord = compare_values(Some(&l), Some(&r));
1126 let result = match cmp {
1127 BinaryOperator::Eq => ord == Ordering::Equal,
1128 BinaryOperator::NotEq => ord != Ordering::Equal,
1129 BinaryOperator::Lt => ord == Ordering::Less,
1130 BinaryOperator::LtEq => ord != Ordering::Greater,
1131 BinaryOperator::Gt => ord == Ordering::Greater,
1132 BinaryOperator::GtEq => ord != Ordering::Less,
1133 _ => unreachable!(),
1134 };
1135 Ok(Value::Bool(result))
1136 }
1137 arith @ (BinaryOperator::Plus
1138 | BinaryOperator::Minus
1139 | BinaryOperator::Multiply
1140 | BinaryOperator::Divide
1141 | BinaryOperator::Modulo) => {
1142 let l = eval_expr(left, table, rowid)?;
1143 let r = eval_expr(right, table, rowid)?;
1144 eval_arith(arith, &l, &r)
1145 }
1146 BinaryOperator::StringConcat => {
1147 let l = eval_expr(left, table, rowid)?;
1148 let r = eval_expr(right, table, rowid)?;
1149 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1150 return Ok(Value::Null);
1151 }
1152 Ok(Value::Text(format!(
1153 "{}{}",
1154 l.to_display_string(),
1155 r.to_display_string()
1156 )))
1157 }
1158 other => Err(SQLRiteError::NotImplemented(format!(
1159 "binary operator {other:?} is not supported yet"
1160 ))),
1161 },
1162
1163 Expr::Function(func) => eval_function(func, table, rowid),
1174
1175 other => Err(SQLRiteError::NotImplemented(format!(
1176 "unsupported expression in WHERE/projection: {other:?}"
1177 ))),
1178 }
1179}
1180
1181fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1186 let name = match func.name.0.as_slice() {
1189 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1190 _ => {
1191 return Err(SQLRiteError::NotImplemented(format!(
1192 "qualified function names not supported: {:?}",
1193 func.name
1194 )));
1195 }
1196 };
1197
1198 match name.as_str() {
1199 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1200 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1201 let dist = match name.as_str() {
1202 "vec_distance_l2" => vec_distance_l2(&a, &b),
1203 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1204 "vec_distance_dot" => vec_distance_dot(&a, &b),
1205 _ => unreachable!(),
1206 };
1207 Ok(Value::Real(dist as f64))
1213 }
1214 other => Err(SQLRiteError::NotImplemented(format!(
1215 "unknown function: {other}(...)"
1216 ))),
1217 }
1218}
1219
1220fn extract_two_vector_args(
1224 fn_name: &str,
1225 args: &FunctionArguments,
1226 table: &Table,
1227 rowid: i64,
1228) -> Result<(Vec<f32>, Vec<f32>)> {
1229 let arg_list = match args {
1230 FunctionArguments::List(l) => &l.args,
1231 _ => {
1232 return Err(SQLRiteError::General(format!(
1233 "{fn_name}() expects exactly two vector arguments"
1234 )));
1235 }
1236 };
1237 if arg_list.len() != 2 {
1238 return Err(SQLRiteError::General(format!(
1239 "{fn_name}() expects exactly 2 arguments, got {}",
1240 arg_list.len()
1241 )));
1242 }
1243 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
1244 for (i, arg) in arg_list.iter().enumerate() {
1245 let expr = match arg {
1246 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1247 other => {
1248 return Err(SQLRiteError::NotImplemented(format!(
1249 "{fn_name}() argument {i} has unsupported shape: {other:?}"
1250 )));
1251 }
1252 };
1253 let val = eval_expr(expr, table, rowid)?;
1254 match val {
1255 Value::Vector(v) => out.push(v),
1256 other => {
1257 return Err(SQLRiteError::General(format!(
1258 "{fn_name}() argument {i} is not a vector: got {}",
1259 other.to_display_string()
1260 )));
1261 }
1262 }
1263 }
1264 let b = out.pop().unwrap();
1265 let a = out.pop().unwrap();
1266 if a.len() != b.len() {
1267 return Err(SQLRiteError::General(format!(
1268 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
1269 a.len(),
1270 b.len()
1271 )));
1272 }
1273 Ok((a, b))
1274}
1275
/// Euclidean (L2) distance `sqrt(Σ (a_i - b_i)²)` between two equal-length vectors.
///
/// Lengths are checked only by a `debug_assert`. The loop walks `a` and
/// indexes `b`, so a shorter `b` panics on the out-of-bounds index.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .enumerate()
        .map(|(i, &x)| {
            let diff = x - b[i];
            diff * diff
        })
        .fold(0.0f32, |acc, sq| acc + sq)
        .sqrt()
}
1287
1288pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
1298 debug_assert_eq!(a.len(), b.len());
1299 let mut dot = 0.0f32;
1300 let mut norm_a_sq = 0.0f32;
1301 let mut norm_b_sq = 0.0f32;
1302 for i in 0..a.len() {
1303 dot += a[i] * b[i];
1304 norm_a_sq += a[i] * a[i];
1305 norm_b_sq += b[i] * b[i];
1306 }
1307 let denom = (norm_a_sq * norm_b_sq).sqrt();
1308 if denom == 0.0 {
1309 return Err(SQLRiteError::General(
1310 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
1311 ));
1312 }
1313 Ok(1.0 - dot / denom)
1314}
1315
/// Negated dot product `-(a·b)`, so a larger similarity gives a smaller distance.
///
/// Lengths are checked only by a `debug_assert`, and a shorter `b` panics on indexing.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let similarity = a
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (i, &x)| acc + x * b[i]);
    -similarity
}
1327
1328fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
1331 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1332 return Ok(Value::Null);
1333 }
1334 match (l, r) {
1335 (Value::Integer(a), Value::Integer(b)) => match op {
1336 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
1337 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
1338 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
1339 BinaryOperator::Divide => {
1340 if *b == 0 {
1341 Err(SQLRiteError::General("division by zero".to_string()))
1342 } else {
1343 Ok(Value::Integer(a / b))
1344 }
1345 }
1346 BinaryOperator::Modulo => {
1347 if *b == 0 {
1348 Err(SQLRiteError::General("modulo by zero".to_string()))
1349 } else {
1350 Ok(Value::Integer(a % b))
1351 }
1352 }
1353 _ => unreachable!(),
1354 },
1355 (a, b) => {
1357 let af = as_number(a)?;
1358 let bf = as_number(b)?;
1359 match op {
1360 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1361 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1362 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1363 BinaryOperator::Divide => {
1364 if bf == 0.0 {
1365 Err(SQLRiteError::General("division by zero".to_string()))
1366 } else {
1367 Ok(Value::Real(af / bf))
1368 }
1369 }
1370 BinaryOperator::Modulo => {
1371 if bf == 0.0 {
1372 Err(SQLRiteError::General("modulo by zero".to_string()))
1373 } else {
1374 Ok(Value::Real(af % bf))
1375 }
1376 }
1377 _ => unreachable!(),
1378 }
1379 }
1380 }
1381}
1382
1383fn as_number(v: &Value) -> Result<f64> {
1384 match v {
1385 Value::Integer(i) => Ok(*i as f64),
1386 Value::Real(f) => Ok(*f),
1387 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1388 other => Err(SQLRiteError::General(format!(
1389 "arithmetic on non-numeric value '{}'",
1390 other.to_display_string()
1391 ))),
1392 }
1393}
1394
1395fn as_bool(v: &Value) -> Result<bool> {
1396 match v {
1397 Value::Bool(b) => Ok(*b),
1398 Value::Null => Ok(false),
1399 Value::Integer(i) => Ok(*i != 0),
1400 other => Err(SQLRiteError::Internal(format!(
1401 "expected boolean, got {}",
1402 other.to_display_string()
1403 ))),
1404 }
1405}
1406
1407fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1408 use sqlparser::ast::Value as AstValue;
1409 match v {
1410 AstValue::Number(n, _) => {
1411 if let Ok(i) = n.parse::<i64>() {
1412 Ok(Value::Integer(i))
1413 } else if let Ok(f) = n.parse::<f64>() {
1414 Ok(Value::Real(f))
1415 } else {
1416 Err(SQLRiteError::Internal(format!(
1417 "could not parse numeric literal '{n}'"
1418 )))
1419 }
1420 }
1421 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1422 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1423 AstValue::Null => Ok(Value::Null),
1424 other => Err(SQLRiteError::NotImplemented(format!(
1425 "unsupported literal value: {other:?}"
1426 ))),
1427 }
1428}
1429
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Absolute-tolerance comparison for `f32` distance results.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    // ---------------------------------------------------------------------
    // Vector distance functions
    // ---------------------------------------------------------------------

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // 3-4-5 right triangle.
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // Cosine is undefined when one input is the zero vector; expect an error.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32; the distance is the negated dot product.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both inputs are unit vectors (0.6² + 0.8² = 1), so the cosine distance
        // (1 - a·b) minus one equals the negated dot product.
        let a = vec![0.6f32, 0.8];
        let b = vec![0.8f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ---------------------------------------------------------------------
    // Bounded-heap top-k (ORDER BY ... LIMIT k)
    // ---------------------------------------------------------------------

    /// Creates `docs (id INTEGER PRIMARY KEY, score REAL)` with `n` rows.
    ///
    /// Scores are `(i * 2_654_435_761) % 1_000_000`. That is deterministic but
    /// scrambled relative to insertion order. The multiplier is coprime to 10^6,
    /// so scores are pairwise distinct for `n <= 1_000_000` and ties never make
    /// heap-vs-sort comparisons ambiguous.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            // NOTE: `{score}` renders whole-valued f64s without a decimal point
            // (e.g. "42"), so the literal likely parses as an integer. Whether it
            // is stored as Integer or Real depends on the REAL column's coercion.
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Parses exactly one SELECT statement into the executor's `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Reference answer: full sort, then truncate.
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Reference answer: full sort, then truncate.
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);

        // Accept either numeric variant: see the NOTE in `seed_score_table`.
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                Some(Value::Integer(i)) => Some(i as f64),
                _ => None,
            })
            .collect();
        // Without this, a filter that drops every row would make the ordering
        // check below pass vacuously.
        assert_eq!(
            scores.len(),
            topk.len(),
            "every returned row should yield a numeric score"
        );
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        // The SQL only supplies the ORDER BY clause; k = 0 is passed directly below.
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // End-to-end: ORDER BY <distance fn> LIMIT k through `process_command`.
        // Only the returned row count is asserted here, not which rows came back.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Rough wall-clock comparison of `select_topk` against full sort + truncate.
    ///
    /// Ignored by default because timing ratios are noisy and less representative
    /// in debug builds. Run explicitly with
    /// `cargo test --release -- --ignored topk_benchmark --nocapture`.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        // Guard the denominator so a sub-nanosecond heap timing can't divide by zero.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:     {heap_dur:?}");
        println!("  full sort+trunc:  {sort_dur:?}");
        println!("  speedup ratio:    {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}