Skip to main content

sqlrite/sql/
executor.rs

1//! Query executors — evaluate parsed SQL statements against the in-memory
2//! storage and produce formatted output.
3
4use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9    FunctionArgExpr, FunctionArguments, ObjectNamePart, Statement, TableFactor, TableWithJoins,
10    UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, Table, Value, parse_vector_literal};
17use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
18
19/// Executes a parsed `SelectQuery` against the database and returns a
20/// human-readable rendering of the result set (prettytable). Also returns
21/// the number of rows produced, for the top-level status message.
22/// Structured result of a SELECT: column names in projection order,
23/// and each matching row as a `Vec<Value>` aligned with the columns.
24/// Phase 5a introduced this so the public `Connection` / `Statement`
25/// API has typed rows to yield; the existing `execute_select` that
26/// returns pre-rendered text is now a thin wrapper on top.
27pub struct SelectResult {
28    pub columns: Vec<String>,
29    pub rows: Vec<Vec<Value>>,
30}
31
32/// Executes a SELECT and returns structured rows. The typed rows are
33/// what the new public API streams to callers; the REPL / Tauri app
34/// pre-render into a prettytable via `execute_select`.
35pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
36    let table = db
37        .get_table(query.table_name.clone())
38        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
39
40    // Resolve projection to a concrete ordered column list.
41    let projected_cols: Vec<String> = match &query.projection {
42        Projection::All => table.column_names(),
43        Projection::Columns(cols) => {
44            for c in cols {
45                if !table.contains_column(c.to_string()) {
46                    return Err(SQLRiteError::Internal(format!(
47                        "Column '{c}' does not exist on table '{}'",
48                        query.table_name
49                    )));
50                }
51            }
52            cols.clone()
53        }
54    };
55
56    // Collect matching rowids. If the WHERE is the shape `col = literal`
57    // and `col` has a secondary index, probe the index for an O(log N)
58    // seek; otherwise fall back to the full table scan.
59    let matching = match select_rowids(table, query.selection.as_ref())? {
60        RowidSource::IndexProbe(rowids) => rowids,
61        RowidSource::FullScan => {
62            let mut out = Vec::new();
63            for rowid in table.rowids() {
64                if let Some(expr) = &query.selection {
65                    if !eval_predicate(expr, table, rowid)? {
66                        continue;
67                    }
68                }
69                out.push(rowid);
70            }
71            out
72        }
73    };
74    let mut matching = matching;
75
76    // Phase 7c — bounded-heap top-k optimization.
77    //
78    // The naive "ORDER BY <expr>" path (Phase 7b) sorts every matching
79    // rowid: O(N log N) sort_by + a truncate. For KNN queries
80    //
81    //     SELECT id FROM docs
82    //     ORDER BY vec_distance_l2(embedding, [...])
83    //     LIMIT 10;
84    //
85    // N is the table row count and k is the LIMIT. With a bounded
86    // max-heap of size k we can find the top-k in O(N log k) — same
87    // sort_by-per-row cost on the heap operations, but k is typically
88    // 10-100 while N can be millions.
89    //
90    // We branch in three cases:
91    //   1. ORDER BY + LIMIT k where k < |matching|  → bounded heap.
92    //   2. ORDER BY without LIMIT, or LIMIT >= |matching| → full sort
93    //      (heap saves nothing when we'd keep everyone anyway).
94    //   3. LIMIT without ORDER BY → just truncate (no sort needed).
95    match (&query.order_by, query.limit) {
96        (Some(order), Some(k)) if k < matching.len() => {
97            matching = select_topk(&matching, table, order, k)?;
98        }
99        (Some(order), _) => {
100            sort_rowids(&mut matching, table, order)?;
101            if let Some(k) = query.limit {
102                matching.truncate(k);
103            }
104        }
105        (None, Some(k)) => {
106            matching.truncate(k);
107        }
108        (None, None) => {}
109    }
110
111    // Build typed rows. Missing cells surface as `Value::Null` — that
112    // maps a column-not-present-for-this-rowid case onto the public
113    // `Row::get` → `Option<T>` surface cleanly.
114    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
115    for rowid in &matching {
116        let row: Vec<Value> = projected_cols
117            .iter()
118            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
119            .collect();
120        rows.push(row);
121    }
122
123    Ok(SelectResult {
124        columns: projected_cols,
125        rows,
126    })
127}
128
129/// Executes a SELECT and returns `(rendered_table, row_count)`. The
130/// REPL and Tauri app use this to keep the table-printing behaviour
131/// the engine has always shipped. Structured callers use
132/// `execute_select_rows` instead.
133pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
134    let result = execute_select_rows(query, db)?;
135    let row_count = result.rows.len();
136
137    let mut print_table = PrintTable::new();
138    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
139    print_table.add_row(PrintRow::new(header_cells));
140
141    for row in &result.rows {
142        let cells: Vec<PrintCell> = row
143            .iter()
144            .map(|v| PrintCell::new(&v.to_display_string()))
145            .collect();
146        print_table.add_row(PrintRow::new(cells));
147    }
148
149    Ok((print_table.to_string(), row_count))
150}
151
152/// Executes a DELETE statement. Returns the number of rows removed.
153pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
154    let Statement::Delete(Delete {
155        from, selection, ..
156    }) = stmt
157    else {
158        return Err(SQLRiteError::Internal(
159            "execute_delete called on a non-DELETE statement".to_string(),
160        ));
161    };
162
163    let tables = match from {
164        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
165    };
166    let table_name = extract_single_table_name(tables)?;
167
168    // Compute matching rowids with an immutable borrow, then mutate.
169    let matching: Vec<i64> = {
170        let table = db
171            .get_table(table_name.clone())
172            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
173        match select_rowids(table, selection.as_ref())? {
174            RowidSource::IndexProbe(rowids) => rowids,
175            RowidSource::FullScan => {
176                let mut out = Vec::new();
177                for rowid in table.rowids() {
178                    if let Some(expr) = selection {
179                        if !eval_predicate(expr, table, rowid)? {
180                            continue;
181                        }
182                    }
183                    out.push(rowid);
184                }
185                out
186            }
187        }
188    };
189
190    let table = db.get_table_mut(table_name)?;
191    for rowid in &matching {
192        table.delete_row(*rowid);
193    }
194    Ok(matching.len())
195}
196
197/// Executes an UPDATE statement. Returns the number of rows updated.
198pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
199    let Statement::Update(Update {
200        table,
201        assignments,
202        from,
203        selection,
204        ..
205    }) = stmt
206    else {
207        return Err(SQLRiteError::Internal(
208            "execute_update called on a non-UPDATE statement".to_string(),
209        ));
210    };
211
212    if from.is_some() {
213        return Err(SQLRiteError::NotImplemented(
214            "UPDATE ... FROM is not supported yet".to_string(),
215        ));
216    }
217
218    let table_name = extract_table_name(table)?;
219
220    // Resolve assignment targets to plain column names and verify they exist.
221    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
222    {
223        let tbl = db
224            .get_table(table_name.clone())
225            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
226        for a in assignments {
227            let col = match &a.target {
228                AssignmentTarget::ColumnName(name) => name
229                    .0
230                    .last()
231                    .map(|p| p.to_string())
232                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
233                AssignmentTarget::Tuple(_) => {
234                    return Err(SQLRiteError::NotImplemented(
235                        "tuple assignment targets are not supported".to_string(),
236                    ));
237                }
238            };
239            if !tbl.contains_column(col.clone()) {
240                return Err(SQLRiteError::Internal(format!(
241                    "UPDATE references unknown column '{col}'"
242                )));
243            }
244            parsed_assignments.push((col, a.value.clone()));
245        }
246    }
247
248    // Gather matching rowids + the new values to write for each assignment, under
249    // an immutable borrow. Uses the index-probe fast path when the WHERE is
250    // `col = literal` on an indexed column.
251    let work: Vec<(i64, Vec<(String, Value)>)> = {
252        let tbl = db.get_table(table_name.clone())?;
253        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
254            RowidSource::IndexProbe(rowids) => rowids,
255            RowidSource::FullScan => {
256                let mut out = Vec::new();
257                for rowid in tbl.rowids() {
258                    if let Some(expr) = selection {
259                        if !eval_predicate(expr, tbl, rowid)? {
260                            continue;
261                        }
262                    }
263                    out.push(rowid);
264                }
265                out
266            }
267        };
268        let mut rows_to_update = Vec::new();
269        for rowid in matched_rowids {
270            let mut values = Vec::with_capacity(parsed_assignments.len());
271            for (col, expr) in &parsed_assignments {
272                // UPDATE's RHS is evaluated in the context of the row being updated,
273                // so column references on the right resolve to the current row's values.
274                let v = eval_expr(expr, tbl, rowid)?;
275                values.push((col.clone(), v));
276            }
277            rows_to_update.push((rowid, values));
278        }
279        rows_to_update
280    };
281
282    let tbl = db.get_table_mut(table_name)?;
283    for (rowid, values) in &work {
284        for (col, v) in values {
285            tbl.set_value(col, *rowid, v.clone())?;
286        }
287    }
288    Ok(work.len())
289}
290
291/// Handles `CREATE INDEX [UNIQUE] <name> ON <table> (<column>)`. Single-
292/// column indexes only; multi-column / composite indexes are future work.
293/// Returns the (possibly synthesized) index name for the status message.
294pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
295    let Statement::CreateIndex(CreateIndex {
296        name,
297        table_name,
298        columns,
299        unique,
300        if_not_exists,
301        predicate,
302        ..
303    }) = stmt
304    else {
305        return Err(SQLRiteError::Internal(
306            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
307        ));
308    };
309
310    if predicate.is_some() {
311        return Err(SQLRiteError::NotImplemented(
312            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
313        ));
314    }
315
316    if columns.len() != 1 {
317        return Err(SQLRiteError::NotImplemented(format!(
318            "multi-column indexes are not supported yet ({} columns given)",
319            columns.len()
320        )));
321    }
322
323    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
324        SQLRiteError::NotImplemented(
325            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
326        )
327    })?;
328
329    let table_name_str = table_name.to_string();
330    let column_name = match &columns[0].column.expr {
331        Expr::Identifier(ident) => ident.value.clone(),
332        Expr::CompoundIdentifier(parts) => parts
333            .last()
334            .map(|p| p.value.clone())
335            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
336        other => {
337            return Err(SQLRiteError::NotImplemented(format!(
338                "CREATE INDEX only supports simple column references, got {other:?}"
339            )));
340        }
341    };
342
343    // Validate: table exists, column exists, type is indexable, name is unique.
344    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
345        let table = db.get_table(table_name_str.clone()).map_err(|_| {
346            SQLRiteError::General(format!(
347                "CREATE INDEX references unknown table '{table_name_str}'"
348            ))
349        })?;
350        if !table.contains_column(column_name.clone()) {
351            return Err(SQLRiteError::General(format!(
352                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
353            )));
354        }
355        let col = table
356            .columns
357            .iter()
358            .find(|c| c.column_name == column_name)
359            .expect("we just verified the column exists");
360        if table.index_by_name(&index_name).is_some() {
361            if *if_not_exists {
362                return Ok(index_name);
363            }
364            return Err(SQLRiteError::General(format!(
365                "index '{index_name}' already exists"
366            )));
367        }
368        let datatype = clone_datatype(&col.datatype);
369
370        // Snapshot (rowid, value) pairs so we can populate the index after
371        // it's attached. Doing this under the immutable borrow of the table
372        // means the mutable attach below can proceed without aliasing.
373        let mut pairs = Vec::new();
374        for rowid in table.rowids() {
375            if let Some(v) = table.get_value(&column_name, rowid) {
376                pairs.push((rowid, v));
377            }
378        }
379        (datatype, pairs)
380    };
381
382    // Build the index.
383    let mut idx = SecondaryIndex::new(
384        index_name.clone(),
385        table_name_str.clone(),
386        column_name.clone(),
387        &datatype,
388        *unique,
389        IndexOrigin::Explicit,
390    )?;
391
392    // Populate from the existing rows. UNIQUE violations here mean the
393    // existing data already breaks the new index's constraint — a common
394    // source of user confusion, so be explicit.
395    for (rowid, v) in &existing_rowids_and_values {
396        if *unique && idx.would_violate_unique(v) {
397            return Err(SQLRiteError::General(format!(
398                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
399                 already contains the duplicate value {}",
400                v.to_display_string()
401            )));
402        }
403        idx.insert(v, *rowid)?;
404    }
405
406    // Attach to the table.
407    let table_mut = db.get_table_mut(table_name_str)?;
408    table_mut.secondary_indexes.push(idx);
409    Ok(index_name)
410}
411
412/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
413/// because the enum has no ergonomic reason to be cloneable elsewhere.
414fn clone_datatype(dt: &DataType) -> DataType {
415    match dt {
416        DataType::Integer => DataType::Integer,
417        DataType::Text => DataType::Text,
418        DataType::Real => DataType::Real,
419        DataType::Bool => DataType::Bool,
420        DataType::Vector(dim) => DataType::Vector(*dim),
421        DataType::None => DataType::None,
422        DataType::Invalid => DataType::Invalid,
423    }
424}
425
426fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
427    if tables.len() != 1 {
428        return Err(SQLRiteError::NotImplemented(
429            "multi-table DELETE is not supported yet".to_string(),
430        ));
431    }
432    extract_table_name(&tables[0])
433}
434
435fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
436    if !twj.joins.is_empty() {
437        return Err(SQLRiteError::NotImplemented(
438            "JOIN is not supported yet".to_string(),
439        ));
440    }
441    match &twj.relation {
442        TableFactor::Table { name, .. } => Ok(name.to_string()),
443        _ => Err(SQLRiteError::NotImplemented(
444            "only plain table references are supported".to_string(),
445        )),
446    }
447}
448
/// Tells the executor how to produce its candidate rowid list.
#[derive(Debug, PartialEq, Eq)]
enum RowidSource {
    /// The whole WHERE clause was answered by probing a secondary index.
    /// The `Vec` already holds exactly the matching rowids, so callers use
    /// it as-is without re-evaluating the WHERE on each row.
    IndexProbe(Vec<i64>),
    /// No applicable index: the caller walks `table.rowids()` and evaluates
    /// the WHERE (if any) on each row.
    FullScan,
}
459
460/// Try to satisfy `WHERE` with an index probe. Currently supports the
461/// simplest shape: a single `col = literal` (or `literal = col`) where
462/// `col` is on a secondary index. AND/OR/range predicates fall back to
463/// full scan — those can be layered on later without changing the caller.
464fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
465    let Some(expr) = selection else {
466        return Ok(RowidSource::FullScan);
467    };
468    let Some((col, literal)) = try_extract_equality(expr) else {
469        return Ok(RowidSource::FullScan);
470    };
471    let Some(idx) = table.index_for_column(&col) else {
472        return Ok(RowidSource::FullScan);
473    };
474
475    // Convert the literal into a runtime Value. If the literal type doesn't
476    // match the column's index we still need correct semantics — evaluate
477    // the WHERE against every row. Fall back to full scan.
478    let literal_value = match convert_literal(&literal) {
479        Ok(v) => v,
480        Err(_) => return Ok(RowidSource::FullScan),
481    };
482
483    // Index lookup returns the full list of rowids matching this equality
484    // predicate. For unique indexes that's at most one; for non-unique it
485    // can be many.
486    let mut rowids = idx.lookup(&literal_value);
487    rowids.sort_unstable();
488    Ok(RowidSource::IndexProbe(rowids))
489}
490
491/// Recognizes `expr` as a simple equality on a column reference against a
492/// literal. Returns `(column_name, literal_value)` if the shape matches;
493/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
494fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
495    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
496    let peeled = match expr {
497        Expr::Nested(inner) => inner.as_ref(),
498        other => other,
499    };
500    let Expr::BinaryOp { left, op, right } = peeled else {
501        return None;
502    };
503    if !matches!(op, BinaryOperator::Eq) {
504        return None;
505    }
506    let col_from = |e: &Expr| -> Option<String> {
507        match e {
508            Expr::Identifier(ident) => Some(ident.value.clone()),
509            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
510            _ => None,
511        }
512    };
513    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
514        if let Expr::Value(v) = e {
515            Some(v.value.clone())
516        } else {
517            None
518        }
519    };
520    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
521        return Some((c, l));
522    }
523    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
524        return Some((c, l));
525    }
526    None
527}
528
529/// One entry in the bounded-heap top-k path. Holds a pre-evaluated
530/// sort key + the rowid it came from. The `asc` flag inverts `Ord`
531/// so a single `BinaryHeap<HeapEntry>` works for both ASC and DESC
532/// without wrapping in `std::cmp::Reverse` at the call site:
533///
534///   - ASC LIMIT k = "k smallest": natural Ord. Max-heap top is the
535///     largest currently kept; new items smaller than top displace.
536///   - DESC LIMIT k = "k largest": Ord reversed. Max-heap top is now
537///     the smallest currently kept (under reversed Ord, smallest
538///     looks largest); new items larger than top displace.
539///
540/// In both cases the displacement test reduces to "new entry < heap top".
541struct HeapEntry {
542    key: Value,
543    rowid: i64,
544    asc: bool,
545}
546
547impl PartialEq for HeapEntry {
548    fn eq(&self, other: &Self) -> bool {
549        self.cmp(other) == Ordering::Equal
550    }
551}
552
553impl Eq for HeapEntry {}
554
555impl PartialOrd for HeapEntry {
556    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
557        Some(self.cmp(other))
558    }
559}
560
561impl Ord for HeapEntry {
562    fn cmp(&self, other: &Self) -> Ordering {
563        let raw = compare_values(Some(&self.key), Some(&other.key));
564        if self.asc { raw } else { raw.reverse() }
565    }
566}
567
568/// Bounded-heap top-k selection. Returns at most `k` rowids in the
569/// caller's desired order (ascending key for `order.ascending`,
570/// descending otherwise).
571///
572/// O(N log k) where N = `matching.len()`. Caller must check
573/// `k < matching.len()` for this to be a win — for k ≥ N the
574/// `sort_rowids` full-sort path is the same asymptotic cost without
575/// the heap overhead.
576fn select_topk(
577    matching: &[i64],
578    table: &Table,
579    order: &OrderByClause,
580    k: usize,
581) -> Result<Vec<i64>> {
582    use std::collections::BinaryHeap;
583
584    if k == 0 || matching.is_empty() {
585        return Ok(Vec::new());
586    }
587
588    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
589
590    for &rowid in matching {
591        let key = eval_expr(&order.expr, table, rowid)?;
592        let entry = HeapEntry {
593            key,
594            rowid,
595            asc: order.ascending,
596        };
597
598        if heap.len() < k {
599            heap.push(entry);
600        } else {
601            // peek() returns the largest under our direction-aware Ord
602            // — the worst entry currently kept. Displace it iff the
603            // new entry is "better" (i.e. compares Less).
604            if entry < *heap.peek().unwrap() {
605                heap.pop();
606                heap.push(entry);
607            }
608        }
609    }
610
611    // `into_sorted_vec` returns ascending under our direction-aware Ord:
612    //   ASC: ascending by raw key (what we want)
613    //   DESC: ascending under reversed Ord = descending by raw key (what
614    //         we want for an ORDER BY DESC LIMIT k result)
615    Ok(heap
616        .into_sorted_vec()
617        .into_iter()
618        .map(|e| e.rowid)
619        .collect())
620}
621
622fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
623    // Phase 7b: ORDER BY now accepts any expression (column ref,
624    // arithmetic, function call, …). Pre-compute the sort key for
625    // every rowid up front so the comparator is called O(N log N)
626    // times against pre-evaluated Values rather than re-evaluating
627    // the expression O(N log N) times. Not strictly necessary today,
628    // but vital once 7d's HNSW index lands and this same code path
629    // could be running tens of millions of distance computations.
630    let mut keys: Vec<(i64, Result<Value>)> = rowids
631        .iter()
632        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
633        .collect();
634
635    // Surface the FIRST evaluation error if any. We could be lazy
636    // and let sort_by encounter it, but `Ord::cmp` can't return a
637    // Result and we'd have to swallow errors silently.
638    for (_, k) in &keys {
639        if let Err(e) = k {
640            return Err(SQLRiteError::General(format!(
641                "ORDER BY expression failed: {e}"
642            )));
643        }
644    }
645
646    keys.sort_by(|(_, ka), (_, kb)| {
647        // Both unwrap()s are safe — we just verified above that
648        // every key Result is Ok.
649        let va = ka.as_ref().unwrap();
650        let vb = kb.as_ref().unwrap();
651        let ord = compare_values(Some(va), Some(vb));
652        if order.ascending { ord } else { ord.reverse() }
653    });
654
655    // Write the sorted rowids back into the caller's slice.
656    for (i, (rowid, _)) in keys.into_iter().enumerate() {
657        rowids[i] = rowid;
658    }
659    Ok(())
660}
661
662fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
663    match (a, b) {
664        (None, None) => Ordering::Equal,
665        (None, _) => Ordering::Less,
666        (_, None) => Ordering::Greater,
667        (Some(a), Some(b)) => match (a, b) {
668            (Value::Null, Value::Null) => Ordering::Equal,
669            (Value::Null, _) => Ordering::Less,
670            (_, Value::Null) => Ordering::Greater,
671            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
672            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
673            (Value::Integer(x), Value::Real(y)) => {
674                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
675            }
676            (Value::Real(x), Value::Integer(y)) => {
677                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
678            }
679            (Value::Text(x), Value::Text(y)) => x.cmp(y),
680            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
681            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
682            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
683        },
684    }
685}
686
687/// Returns `true` if the row at `rowid` matches the predicate expression.
688pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
689    let v = eval_expr(expr, table, rowid)?;
690    match v {
691        Value::Bool(b) => Ok(b),
692        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
693        Value::Integer(i) => Ok(i != 0),
694        other => Err(SQLRiteError::Internal(format!(
695            "WHERE clause must evaluate to boolean, got {}",
696            other.to_display_string()
697        ))),
698    }
699}
700
701fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
702    match expr {
703        Expr::Nested(inner) => eval_expr(inner, table, rowid),
704
705        Expr::Identifier(ident) => {
706            // Phase 7b — sqlparser parses bracket-array literals like
707            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
708            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
709            // in expression-evaluation position (SELECT projection, WHERE,
710            // ORDER BY, function args), parse the bracketed content as a
711            // vector literal so the rest of the executor can compare /
712            // distance-compute against it. Same trick the INSERT parser
713            // uses; the executor needed its own copy because expression
714            // eval runs on a different code path.
715            if ident.quote_style == Some('[') {
716                let raw = format!("[{}]", ident.value);
717                let v = parse_vector_literal(&raw)?;
718                return Ok(Value::Vector(v));
719            }
720            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
721        }
722
723        Expr::CompoundIdentifier(parts) => {
724            // Accept `table.col` — we only have one table in scope, so ignore the qualifier.
725            let col = parts
726                .last()
727                .map(|i| i.value.as_str())
728                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
729            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
730        }
731
732        Expr::Value(v) => convert_literal(&v.value),
733
734        Expr::UnaryOp { op, expr } => {
735            let inner = eval_expr(expr, table, rowid)?;
736            match op {
737                UnaryOperator::Not => match inner {
738                    Value::Bool(b) => Ok(Value::Bool(!b)),
739                    Value::Null => Ok(Value::Null),
740                    other => Err(SQLRiteError::Internal(format!(
741                        "NOT applied to non-boolean value: {}",
742                        other.to_display_string()
743                    ))),
744                },
745                UnaryOperator::Minus => match inner {
746                    Value::Integer(i) => Ok(Value::Integer(-i)),
747                    Value::Real(f) => Ok(Value::Real(-f)),
748                    Value::Null => Ok(Value::Null),
749                    other => Err(SQLRiteError::Internal(format!(
750                        "unary minus on non-numeric value: {}",
751                        other.to_display_string()
752                    ))),
753                },
754                UnaryOperator::Plus => Ok(inner),
755                other => Err(SQLRiteError::NotImplemented(format!(
756                    "unary operator {other:?} is not supported"
757                ))),
758            }
759        }
760
761        Expr::BinaryOp { left, op, right } => match op {
762            BinaryOperator::And => {
763                let l = eval_expr(left, table, rowid)?;
764                let r = eval_expr(right, table, rowid)?;
765                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
766            }
767            BinaryOperator::Or => {
768                let l = eval_expr(left, table, rowid)?;
769                let r = eval_expr(right, table, rowid)?;
770                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
771            }
772            cmp @ (BinaryOperator::Eq
773            | BinaryOperator::NotEq
774            | BinaryOperator::Lt
775            | BinaryOperator::LtEq
776            | BinaryOperator::Gt
777            | BinaryOperator::GtEq) => {
778                let l = eval_expr(left, table, rowid)?;
779                let r = eval_expr(right, table, rowid)?;
780                // Any comparison involving NULL is unknown → false in a WHERE.
781                if matches!(l, Value::Null) || matches!(r, Value::Null) {
782                    return Ok(Value::Bool(false));
783                }
784                let ord = compare_values(Some(&l), Some(&r));
785                let result = match cmp {
786                    BinaryOperator::Eq => ord == Ordering::Equal,
787                    BinaryOperator::NotEq => ord != Ordering::Equal,
788                    BinaryOperator::Lt => ord == Ordering::Less,
789                    BinaryOperator::LtEq => ord != Ordering::Greater,
790                    BinaryOperator::Gt => ord == Ordering::Greater,
791                    BinaryOperator::GtEq => ord != Ordering::Less,
792                    _ => unreachable!(),
793                };
794                Ok(Value::Bool(result))
795            }
796            arith @ (BinaryOperator::Plus
797            | BinaryOperator::Minus
798            | BinaryOperator::Multiply
799            | BinaryOperator::Divide
800            | BinaryOperator::Modulo) => {
801                let l = eval_expr(left, table, rowid)?;
802                let r = eval_expr(right, table, rowid)?;
803                eval_arith(arith, &l, &r)
804            }
805            BinaryOperator::StringConcat => {
806                let l = eval_expr(left, table, rowid)?;
807                let r = eval_expr(right, table, rowid)?;
808                if matches!(l, Value::Null) || matches!(r, Value::Null) {
809                    return Ok(Value::Null);
810                }
811                Ok(Value::Text(format!(
812                    "{}{}",
813                    l.to_display_string(),
814                    r.to_display_string()
815                )))
816            }
817            other => Err(SQLRiteError::NotImplemented(format!(
818                "binary operator {other:?} is not supported yet"
819            ))),
820        },
821
822        // Phase 7b — function-call dispatch. Currently only the three
823        // vector-distance functions; this match arm becomes the single
824        // place to register more SQL functions later (e.g. abs(),
825        // length(), …) without re-touching the rest of the executor.
826        //
827        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
828        // of three don't parse natively in sqlparser (we'd need a
829        // string-preprocessing pass or a sqlparser fork). Deferred to
830        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
831        // corrections" note.
832        Expr::Function(func) => eval_function(func, table, rowid),
833
834        other => Err(SQLRiteError::NotImplemented(format!(
835            "unsupported expression in WHERE/projection: {other:?}"
836        ))),
837    }
838}
839
840/// Dispatches an `Expr::Function` to its built-in implementation.
841/// Currently only the three vec_distance_* functions; other functions
842/// surface as `NotImplemented` errors with the function name in the
843/// message so users see what they tried.
844fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
845    // Function name lives in `name.0[0]` for unqualified calls. Anything
846    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
847    let name = match func.name.0.as_slice() {
848        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
849        _ => {
850            return Err(SQLRiteError::NotImplemented(format!(
851                "qualified function names not supported: {:?}",
852                func.name
853            )));
854        }
855    };
856
857    match name.as_str() {
858        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
859            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
860            let dist = match name.as_str() {
861                "vec_distance_l2" => vec_distance_l2(&a, &b),
862                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
863                "vec_distance_dot" => vec_distance_dot(&a, &b),
864                _ => unreachable!(),
865            };
866            // Widen f32 → f64 for the runtime Value. Vectors are stored
867            // as f32 (consistent with industry convention for embeddings),
868            // but the executor's numeric type is f64 so distances slot
869            // into Value::Real cleanly and can be compared / ordered with
870            // other reals via the existing arithmetic + comparison paths.
871            Ok(Value::Real(dist as f64))
872        }
873        other => Err(SQLRiteError::NotImplemented(format!(
874            "unknown function: {other}(...)"
875        ))),
876    }
877}
878
879/// Extracts exactly two `Vec<f32>` arguments from a function call,
880/// validating arity and that both sides are Vector-typed with matching
881/// dimensions. Used by all three vec_distance_* functions.
882fn extract_two_vector_args(
883    fn_name: &str,
884    args: &FunctionArguments,
885    table: &Table,
886    rowid: i64,
887) -> Result<(Vec<f32>, Vec<f32>)> {
888    let arg_list = match args {
889        FunctionArguments::List(l) => &l.args,
890        _ => {
891            return Err(SQLRiteError::General(format!(
892                "{fn_name}() expects exactly two vector arguments"
893            )));
894        }
895    };
896    if arg_list.len() != 2 {
897        return Err(SQLRiteError::General(format!(
898            "{fn_name}() expects exactly 2 arguments, got {}",
899            arg_list.len()
900        )));
901    }
902    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
903    for (i, arg) in arg_list.iter().enumerate() {
904        let expr = match arg {
905            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
906            other => {
907                return Err(SQLRiteError::NotImplemented(format!(
908                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
909                )));
910            }
911        };
912        let val = eval_expr(expr, table, rowid)?;
913        match val {
914            Value::Vector(v) => out.push(v),
915            other => {
916                return Err(SQLRiteError::General(format!(
917                    "{fn_name}() argument {i} is not a vector: got {}",
918                    other.to_display_string()
919                )));
920            }
921        }
922    }
923    let b = out.pop().unwrap();
924    let a = out.pop().unwrap();
925    if a.len() != b.len() {
926        return Err(SQLRiteError::General(format!(
927            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
928            a.len(),
929            b.len()
930        )));
931    }
932    Ok((a, b))
933}
934
/// Euclidean (L2) distance: √Σ(aᵢ − bᵢ)².
///
/// Smaller means closer, and identical vectors return 0.0. Squared
/// differences are summed in f32, left to right, starting from +0.0.
/// Callers must pass equal-length slices. This is checked only in
/// debug builds; `b` is indexed once per element of `a`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared_sum = a.iter().enumerate().fold(0.0f32, |acc, (i, &x)| {
        let diff = x - b[i];
        acc + diff * diff
    });
    squared_sum.sqrt()
}
946
947/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
948/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
949/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
950///
951/// Errors if either vector has zero magnitude — cosine similarity is
952/// undefined for the zero vector and silently returning NaN would
953/// poison `ORDER BY` ranking. Callers who want the silent-NaN
954/// behavior can compute `vec_distance_dot(a, b) / (norm(a) * norm(b))`
955/// themselves.
956pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
957    debug_assert_eq!(a.len(), b.len());
958    let mut dot = 0.0f32;
959    let mut norm_a_sq = 0.0f32;
960    let mut norm_b_sq = 0.0f32;
961    for i in 0..a.len() {
962        dot += a[i] * b[i];
963        norm_a_sq += a[i] * a[i];
964        norm_b_sq += b[i] * b[i];
965    }
966    let denom = (norm_a_sq * norm_b_sq).sqrt();
967    if denom == 0.0 {
968        return Err(SQLRiteError::General(
969            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
970        ));
971    }
972    Ok(1.0 - dot / denom)
973}
974
/// Negated dot product: −(a·b).
///
/// Negated (the pgvector convention) so that smaller means closer, as
/// with L2 and cosine. For unit-norm vectors,
/// `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`.
/// Callers must pass equal-length slices. This is checked only in
/// debug builds; `b` is indexed once per element of `a`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let dot = a
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (i, &x)| acc + x * b[i]);
    -dot
}
986
987/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
988/// Mixed Integer/Real promotes to Real. Divide/Modulo by zero → error.
989fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
990    if matches!(l, Value::Null) || matches!(r, Value::Null) {
991        return Ok(Value::Null);
992    }
993    match (l, r) {
994        (Value::Integer(a), Value::Integer(b)) => match op {
995            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
996            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
997            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
998            BinaryOperator::Divide => {
999                if *b == 0 {
1000                    Err(SQLRiteError::General("division by zero".to_string()))
1001                } else {
1002                    Ok(Value::Integer(a / b))
1003                }
1004            }
1005            BinaryOperator::Modulo => {
1006                if *b == 0 {
1007                    Err(SQLRiteError::General("modulo by zero".to_string()))
1008                } else {
1009                    Ok(Value::Integer(a % b))
1010                }
1011            }
1012            _ => unreachable!(),
1013        },
1014        // Anything involving a Real promotes both sides to f64.
1015        (a, b) => {
1016            let af = as_number(a)?;
1017            let bf = as_number(b)?;
1018            match op {
1019                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1020                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1021                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1022                BinaryOperator::Divide => {
1023                    if bf == 0.0 {
1024                        Err(SQLRiteError::General("division by zero".to_string()))
1025                    } else {
1026                        Ok(Value::Real(af / bf))
1027                    }
1028                }
1029                BinaryOperator::Modulo => {
1030                    if bf == 0.0 {
1031                        Err(SQLRiteError::General("modulo by zero".to_string()))
1032                    } else {
1033                        Ok(Value::Real(af % bf))
1034                    }
1035                }
1036                _ => unreachable!(),
1037            }
1038        }
1039    }
1040}
1041
1042fn as_number(v: &Value) -> Result<f64> {
1043    match v {
1044        Value::Integer(i) => Ok(*i as f64),
1045        Value::Real(f) => Ok(*f),
1046        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1047        other => Err(SQLRiteError::General(format!(
1048            "arithmetic on non-numeric value '{}'",
1049            other.to_display_string()
1050        ))),
1051    }
1052}
1053
1054fn as_bool(v: &Value) -> Result<bool> {
1055    match v {
1056        Value::Bool(b) => Ok(*b),
1057        Value::Null => Ok(false),
1058        Value::Integer(i) => Ok(*i != 0),
1059        other => Err(SQLRiteError::Internal(format!(
1060            "expected boolean, got {}",
1061            other.to_display_string()
1062        ))),
1063    }
1064}
1065
1066fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1067    use sqlparser::ast::Value as AstValue;
1068    match v {
1069        AstValue::Number(n, _) => {
1070            if let Ok(i) = n.parse::<i64>() {
1071                Ok(Value::Integer(i))
1072            } else if let Ok(f) = n.parse::<f64>() {
1073                Ok(Value::Real(f))
1074            } else {
1075                Err(SQLRiteError::Internal(format!(
1076                    "could not parse numeric literal '{n}'"
1077                )))
1078            }
1079        }
1080        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1081        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1082        AstValue::Null => Ok(Value::Null),
1083        other => Err(SQLRiteError::NotImplemented(format!(
1084            "unsupported literal value: {other:?}"
1085        ))),
1086    }
1087}
1088
// Unit tests for the vec_distance_* math kernels and for the bounded-heap
// top-k path (`select_topk`), checked against full sort + truncate
// (`sort_rowids`).
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------
    // Phase 7b — Vector distance function math
    // -----------------------------------------------------------------

    /// Float comparison helper: true when `a` and `b` differ by less than
    /// `eps`. Distance results need a small epsilon because they
    /// accumulate sums across many f32 multiplies.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        // [1, 0] vs [0, 1]: distance = √((1-0)² + (0-1)²) = √2 ≈ 1.414
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // [0, 0, 0] vs [3, 4, 0]: √(9 + 16 + 0) = 5 (the classic 3-4-5 triangle).
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        // Two orthogonal unit vectors should have cosine distance = 1.0
        // (cosine similarity = 0 → distance = 1 - 0 = 1).
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        // a and -a have cosine similarity = -1 → distance = 1 - (-1) = 2.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // Cosine is undefined for the zero vector, so expect an error rather than NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // a·b = 1*4 + 2*5 + 3*6 = 32. Negated → -32.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        // Orthogonal vectors have dot product 0 → negated is also 0.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // For unit-norm vectors: dot(a,b) = cos(a,b)
        // → -dot(a,b) = -cos(a,b) = (1 - cos(a,b)) - 1 = vec_distance_cosine(a,b) - 1.
        // Useful sanity check that the two functions agree on unit vectors.
        let a = vec![0.6f32, 0.8]; // unit norm: √(0.36+0.64) = 1
        let b = vec![0.8f32, 0.6]; // unit norm too
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // -----------------------------------------------------------------
    // Phase 7c — bounded-heap top-k correctness + benchmark
    // -----------------------------------------------------------------

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Builds a `docs(id INTEGER PK, score REAL)` table with N rows of
    /// distinct non-negative scores. Distinct scores keep the top-k tests
    /// independent of tie-breaking: the heap is unstable and the full
    /// sort is stable, and we want both to agree without arguing about
    /// the order of equal-score rows.
    ///
    /// **Why non-negative scores:** the INSERT parser doesn't currently
    /// handle `Expr::UnaryOp(Minus, …)` for negative number literals.
    /// It would parse `-3.14` as a unary expression, and the value
    /// extractor would skip it. That's a pre-existing bug, out of scope
    /// for 7c. The Knuth multiplicative hash gives distinct scrambled
    /// values in `[0, 999_999]` without touching the negative-literal
    /// limitation.
    ///
    /// NOTE(review): `format!("{score}")` renders a whole-valued f64
    /// without a decimal point (e.g. `435761`), so every INSERT carries an
    /// integer literal. Whether it is stored as `Value::Real` depends on
    /// the INSERT path's REAL-column coercion, which is not visible here.
    /// Confirm.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            // Knuth multiplicative hash mod 1_000_000. Since
            // 2_654_435_761 ≡ 435_761 (mod 1_000_000), and 435_761 is
            // coprime with 1_000_000, i ↦ score is a bijection on
            // [0, 1_000_000). Scores are therefore distinct for every
            // n ≤ 1_000_000, and the u64 product cannot wrap in that range.
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Helper: parses an SQL SELECT (SQLite dialect) into a SelectQuery
    /// so tests can drive `select_topk` / `sort_rowids` directly, without
    /// the rest of the process_command pipeline. Panics if parsing fails.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        // Build N=200, top-k=10. Bounded heap output must equal
        // full-sort-then-truncate output (both produce ASC order).
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Full-sort path
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        // Bounded-heap path
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        // Same with DESC, which verifies the direction-aware Ord wrapper.
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        // The executor branches off to the full-sort path when k >= N.
        // If a caller invokes select_topk directly with k > N, it should
        // still return all N rows in sorted order (there is nothing to
        // truncate when fewer than k rows exist).
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        // All scores in ascending order.
        // NOTE(review): rows whose `score` is not `Value::Real` are silently
        // dropped by this filter_map. If scores end up stored as Integer, the
        // ordering check below passes vacuously. Consider also asserting
        // `scores.len() == 50` once the storage type is confirmed.
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        // Only `order` is taken from the parsed query. The SQL's LIMIT is
        // irrelevant here because k=0 is passed to select_topk directly.
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // Integration check that the executor actually picks the
        // bounded-heap path on a KNN-shaped query and produces the
        // correct top-k.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        // Five rows with distinct distances from probe [1.0, 0.0]:
        //   id=1 [1.0, 0.0]   distance=0
        //   id=2 [2.0, 0.0]   distance=1
        //   id=3 [0.0, 3.0]   distance=√(1+9) = √10 ≈ 3.16
        //   id=4 [1.0, 4.0]   distance=4
        //   id=5 [10.0, 10.0] distance=√(81+100) ≈ 13.45
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        // Top-3 closest to [1.0, 0.0] are id=1, id=2, id=3 (in that order).
        // The status message tells us how many rows came back.
        // NOTE(review): only the row count is asserted. The returned ids and
        // their order are not checked here.
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Manual benchmark, not run by default. Recommended invocation:
    ///
    ///     cargo test -p sqlrite-engine --lib topk_benchmark --release \
    ///         -- --ignored --nocapture
    ///
    /// (`--release` matters: Rust's sort gets very fast under
    /// optimization, so the heap's relative advantage is best measured
    /// against a sort that has also been optimized.)
    ///
    /// Measured numbers on an Apple Silicon laptop with N=10_000 + k=10:
    ///   - bounded heap:    ~820µs
    ///   - full sort+trunc: ~1.5ms
    ///   - ratio:           ~1.8×
    ///
    /// The advantage is real but moderate at this size. The sort key here
    /// is a single REAL column read (cheap), and Rust's sort_by has a very
    /// low constant factor. The asymptotic O(N log k) vs O(N log N)
    /// advantage grows with N and with per-row work. KNN queries where
    /// the sort key is `vec_distance_l2(col, [...])` are where this path
    /// really pays off: each key evaluation is itself O(dim), and the
    /// heap path skips the per-row evaluation in the comparator (see
    /// `sort_rowids` for the contrast).
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Time bounded heap.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Time full sort + truncate (the clone is included in the timing).
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        // Guard against a zero-duration heap measurement (division by zero).
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:   {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:  {ratio:.2}×");

        // Soft assertion. The floor is 1.4× because this cheap-key
        // benchmark hovers around 1.8× empirically, and a floor too close
        // to the measured value risks flaky CI on slower runners. A 1.4×
        // floor still catches a real regression (e.g. if select_topk
        // became O(N²) or stopped using the heap entirely).
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}