Skip to main content

sqlrite/sql/
executor.rs

1//! Query executors — evaluate parsed SQL statements against the in-memory
2//! storage and produce formatted output.
3
4use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9    FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
10    TableWithJoins, UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
17use crate::sql::hnsw::{DistanceMetric, HnswIndex};
18use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
19
20/// Executes a parsed `SelectQuery` against the database and returns a
21/// human-readable rendering of the result set (prettytable). Also returns
22/// the number of rows produced, for the top-level status message.
23/// Structured result of a SELECT: column names in projection order,
24/// and each matching row as a `Vec<Value>` aligned with the columns.
25/// Phase 5a introduced this so the public `Connection` / `Statement`
26/// API has typed rows to yield; the existing `execute_select` that
27/// returns pre-rendered text is now a thin wrapper on top.
28pub struct SelectResult {
29    pub columns: Vec<String>,
30    pub rows: Vec<Vec<Value>>,
31}
32
33/// Executes a SELECT and returns structured rows. The typed rows are
34/// what the new public API streams to callers; the REPL / Tauri app
35/// pre-render into a prettytable via `execute_select`.
36pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
37    let table = db
38        .get_table(query.table_name.clone())
39        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
40
41    // Resolve projection to a concrete ordered column list.
42    let projected_cols: Vec<String> = match &query.projection {
43        Projection::All => table.column_names(),
44        Projection::Columns(cols) => {
45            for c in cols {
46                if !table.contains_column(c.to_string()) {
47                    return Err(SQLRiteError::Internal(format!(
48                        "Column '{c}' does not exist on table '{}'",
49                        query.table_name
50                    )));
51                }
52            }
53            cols.clone()
54        }
55    };
56
57    // Collect matching rowids. If the WHERE is the shape `col = literal`
58    // and `col` has a secondary index, probe the index for an O(log N)
59    // seek; otherwise fall back to the full table scan.
60    let matching = match select_rowids(table, query.selection.as_ref())? {
61        RowidSource::IndexProbe(rowids) => rowids,
62        RowidSource::FullScan => {
63            let mut out = Vec::new();
64            for rowid in table.rowids() {
65                if let Some(expr) = &query.selection {
66                    if !eval_predicate(expr, table, rowid)? {
67                        continue;
68                    }
69                }
70                out.push(rowid);
71            }
72            out
73        }
74    };
75    let mut matching = matching;
76
77    // Phase 7c — bounded-heap top-k optimization.
78    //
79    // The naive "ORDER BY <expr>" path (Phase 7b) sorts every matching
80    // rowid: O(N log N) sort_by + a truncate. For KNN queries
81    //
82    //     SELECT id FROM docs
83    //     ORDER BY vec_distance_l2(embedding, [...])
84    //     LIMIT 10;
85    //
86    // N is the table row count and k is the LIMIT. With a bounded
87    // max-heap of size k we can find the top-k in O(N log k) — same
88    // sort_by-per-row cost on the heap operations, but k is typically
89    // 10-100 while N can be millions.
90    //
91    // Phase 7d.2 — HNSW ANN probe.
92    //
93    // Even better than the bounded heap: if the ORDER BY expression is
94    // exactly `vec_distance_l2(<col>, <bracket-array literal>)` AND
95    // `<col>` has an HNSW index attached, skip the linear scan
96    // entirely and probe the graph in O(log N). Approximate but
97    // typically ≥ 0.95 recall (verified by the recall tests in
98    // src/sql/hnsw.rs).
99    //
100    // We branch in cases:
101    //   1. ORDER BY + LIMIT k matches the HNSW probe pattern  → graph probe.
102    //   2. ORDER BY + LIMIT k where k < |matching|            → bounded heap (7c).
103    //   3. ORDER BY without LIMIT, or LIMIT >= |matching|     → full sort.
104    //   4. LIMIT without ORDER BY                              → just truncate.
105    match (&query.order_by, query.limit) {
106        (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
107            matching = try_hnsw_probe(table, &order.expr, k).unwrap();
108        }
109        (Some(order), Some(k)) if k < matching.len() => {
110            matching = select_topk(&matching, table, order, k)?;
111        }
112        (Some(order), _) => {
113            sort_rowids(&mut matching, table, order)?;
114            if let Some(k) = query.limit {
115                matching.truncate(k);
116            }
117        }
118        (None, Some(k)) => {
119            matching.truncate(k);
120        }
121        (None, None) => {}
122    }
123
124    // Build typed rows. Missing cells surface as `Value::Null` — that
125    // maps a column-not-present-for-this-rowid case onto the public
126    // `Row::get` → `Option<T>` surface cleanly.
127    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
128    for rowid in &matching {
129        let row: Vec<Value> = projected_cols
130            .iter()
131            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
132            .collect();
133        rows.push(row);
134    }
135
136    Ok(SelectResult {
137        columns: projected_cols,
138        rows,
139    })
140}
141
142/// Executes a SELECT and returns `(rendered_table, row_count)`. The
143/// REPL and Tauri app use this to keep the table-printing behaviour
144/// the engine has always shipped. Structured callers use
145/// `execute_select_rows` instead.
146pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
147    let result = execute_select_rows(query, db)?;
148    let row_count = result.rows.len();
149
150    let mut print_table = PrintTable::new();
151    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
152    print_table.add_row(PrintRow::new(header_cells));
153
154    for row in &result.rows {
155        let cells: Vec<PrintCell> = row
156            .iter()
157            .map(|v| PrintCell::new(&v.to_display_string()))
158            .collect();
159        print_table.add_row(PrintRow::new(cells));
160    }
161
162    Ok((print_table.to_string(), row_count))
163}
164
165/// Executes a DELETE statement. Returns the number of rows removed.
166pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
167    let Statement::Delete(Delete {
168        from, selection, ..
169    }) = stmt
170    else {
171        return Err(SQLRiteError::Internal(
172            "execute_delete called on a non-DELETE statement".to_string(),
173        ));
174    };
175
176    let tables = match from {
177        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
178    };
179    let table_name = extract_single_table_name(tables)?;
180
181    // Compute matching rowids with an immutable borrow, then mutate.
182    let matching: Vec<i64> = {
183        let table = db
184            .get_table(table_name.clone())
185            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
186        match select_rowids(table, selection.as_ref())? {
187            RowidSource::IndexProbe(rowids) => rowids,
188            RowidSource::FullScan => {
189                let mut out = Vec::new();
190                for rowid in table.rowids() {
191                    if let Some(expr) = selection {
192                        if !eval_predicate(expr, table, rowid)? {
193                            continue;
194                        }
195                    }
196                    out.push(rowid);
197                }
198                out
199            }
200        }
201    };
202
203    let table = db.get_table_mut(table_name)?;
204    for rowid in &matching {
205        table.delete_row(*rowid);
206    }
207    // Phase 7d.3 — any DELETE invalidates every HNSW index on this
208    // table (the deleted node could still appear in other nodes'
209    // neighbor lists, breaking subsequent searches). Mark dirty so
210    // the next save rebuilds from current rows before serializing.
211    if !matching.is_empty() {
212        for entry in &mut table.hnsw_indexes {
213            entry.needs_rebuild = true;
214        }
215    }
216    Ok(matching.len())
217}
218
219/// Executes an UPDATE statement. Returns the number of rows updated.
220pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
221    let Statement::Update(Update {
222        table,
223        assignments,
224        from,
225        selection,
226        ..
227    }) = stmt
228    else {
229        return Err(SQLRiteError::Internal(
230            "execute_update called on a non-UPDATE statement".to_string(),
231        ));
232    };
233
234    if from.is_some() {
235        return Err(SQLRiteError::NotImplemented(
236            "UPDATE ... FROM is not supported yet".to_string(),
237        ));
238    }
239
240    let table_name = extract_table_name(table)?;
241
242    // Resolve assignment targets to plain column names and verify they exist.
243    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
244    {
245        let tbl = db
246            .get_table(table_name.clone())
247            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
248        for a in assignments {
249            let col = match &a.target {
250                AssignmentTarget::ColumnName(name) => name
251                    .0
252                    .last()
253                    .map(|p| p.to_string())
254                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
255                AssignmentTarget::Tuple(_) => {
256                    return Err(SQLRiteError::NotImplemented(
257                        "tuple assignment targets are not supported".to_string(),
258                    ));
259                }
260            };
261            if !tbl.contains_column(col.clone()) {
262                return Err(SQLRiteError::Internal(format!(
263                    "UPDATE references unknown column '{col}'"
264                )));
265            }
266            parsed_assignments.push((col, a.value.clone()));
267        }
268    }
269
270    // Gather matching rowids + the new values to write for each assignment, under
271    // an immutable borrow. Uses the index-probe fast path when the WHERE is
272    // `col = literal` on an indexed column.
273    let work: Vec<(i64, Vec<(String, Value)>)> = {
274        let tbl = db.get_table(table_name.clone())?;
275        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
276            RowidSource::IndexProbe(rowids) => rowids,
277            RowidSource::FullScan => {
278                let mut out = Vec::new();
279                for rowid in tbl.rowids() {
280                    if let Some(expr) = selection {
281                        if !eval_predicate(expr, tbl, rowid)? {
282                            continue;
283                        }
284                    }
285                    out.push(rowid);
286                }
287                out
288            }
289        };
290        let mut rows_to_update = Vec::new();
291        for rowid in matched_rowids {
292            let mut values = Vec::with_capacity(parsed_assignments.len());
293            for (col, expr) in &parsed_assignments {
294                // UPDATE's RHS is evaluated in the context of the row being updated,
295                // so column references on the right resolve to the current row's values.
296                let v = eval_expr(expr, tbl, rowid)?;
297                values.push((col.clone(), v));
298            }
299            rows_to_update.push((rowid, values));
300        }
301        rows_to_update
302    };
303
304    let tbl = db.get_table_mut(table_name)?;
305    for (rowid, values) in &work {
306        for (col, v) in values {
307            tbl.set_value(col, *rowid, v.clone())?;
308        }
309    }
310
311    // Phase 7d.3 — UPDATE may have changed a vector column that an
312    // HNSW index covers. Mark every covering index dirty so save
313    // rebuilds from current rows. (Updates that only touched
314    // non-vector columns also mark dirty, which is over-conservative
315    // but harmless — the rebuild walks rows anyway, and the cost is
316    // only paid on save.)
317    if !work.is_empty() {
318        let updated_columns: std::collections::HashSet<&str> = work
319            .iter()
320            .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
321            .collect();
322        for entry in &mut tbl.hnsw_indexes {
323            if updated_columns.contains(entry.column_name.as_str()) {
324                entry.needs_rebuild = true;
325            }
326        }
327    }
328    Ok(work.len())
329}
330
331/// Handles `CREATE INDEX [UNIQUE] <name> ON <table> [USING <method>] (<column>)`.
332/// Single-column indexes only.
333///
334/// Two flavours, branching on the optional `USING <method>` clause:
335///   - **No USING, or `USING btree`**: regular B-Tree secondary index
336///     (Phase 3e). Indexable types: Integer, Text.
337///   - **`USING hnsw`**: HNSW ANN index (Phase 7d.2). Indexable types:
338///     Vector(N) only. Distance metric is L2 by default; cosine and
339///     dot variants are deferred to Phase 7d.x.
340///
341/// Returns the (possibly synthesized) index name for the status message.
342pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
343    let Statement::CreateIndex(CreateIndex {
344        name,
345        table_name,
346        columns,
347        using,
348        unique,
349        if_not_exists,
350        predicate,
351        ..
352    }) = stmt
353    else {
354        return Err(SQLRiteError::Internal(
355            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
356        ));
357    };
358
359    if predicate.is_some() {
360        return Err(SQLRiteError::NotImplemented(
361            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
362        ));
363    }
364
365    if columns.len() != 1 {
366        return Err(SQLRiteError::NotImplemented(format!(
367            "multi-column indexes are not supported yet ({} columns given)",
368            columns.len()
369        )));
370    }
371
372    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
373        SQLRiteError::NotImplemented(
374            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
375        )
376    })?;
377
378    // Detect USING <method>. The `using` field on CreateIndex covers the
379    // pre-column form `CREATE INDEX … USING hnsw (col)`. (sqlparser also
380    // accepts a post-column form `… (col) USING hnsw` and parks that in
381    // `index_options`; we don't bother with it — the canonical form is
382    // pre-column and matches PG/pgvector convention.)
383    let method = match using {
384        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
385            IndexMethod::Hnsw
386        }
387        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
388            IndexMethod::Btree
389        }
390        Some(other) => {
391            return Err(SQLRiteError::NotImplemented(format!(
392                "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
393            )));
394        }
395        None => IndexMethod::Btree,
396    };
397
398    let table_name_str = table_name.to_string();
399    let column_name = match &columns[0].column.expr {
400        Expr::Identifier(ident) => ident.value.clone(),
401        Expr::CompoundIdentifier(parts) => parts
402            .last()
403            .map(|p| p.value.clone())
404            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
405        other => {
406            return Err(SQLRiteError::NotImplemented(format!(
407                "CREATE INDEX only supports simple column references, got {other:?}"
408            )));
409        }
410    };
411
412    // Validate: table exists, column exists, type matches the index method,
413    // name is unique across both index kinds. Snapshot (rowid, value) pairs
414    // up front under the immutable borrow so the mutable attach later
415    // doesn't fight over `self`.
416    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
417        let table = db.get_table(table_name_str.clone()).map_err(|_| {
418            SQLRiteError::General(format!(
419                "CREATE INDEX references unknown table '{table_name_str}'"
420            ))
421        })?;
422        if !table.contains_column(column_name.clone()) {
423            return Err(SQLRiteError::General(format!(
424                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
425            )));
426        }
427        let col = table
428            .columns
429            .iter()
430            .find(|c| c.column_name == column_name)
431            .expect("we just verified the column exists");
432
433        // Name uniqueness check spans BOTH index kinds — a btree and an
434        // hnsw can't share a name.
435        if table.index_by_name(&index_name).is_some()
436            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
437        {
438            if *if_not_exists {
439                return Ok(index_name);
440            }
441            return Err(SQLRiteError::General(format!(
442                "index '{index_name}' already exists"
443            )));
444        }
445        let datatype = clone_datatype(&col.datatype);
446
447        let mut pairs = Vec::new();
448        for rowid in table.rowids() {
449            if let Some(v) = table.get_value(&column_name, rowid) {
450                pairs.push((rowid, v));
451            }
452        }
453        (datatype, pairs)
454    };
455
456    match method {
457        IndexMethod::Btree => create_btree_index(
458            db,
459            &table_name_str,
460            &index_name,
461            &column_name,
462            &datatype,
463            *unique,
464            &existing_rowids_and_values,
465        ),
466        IndexMethod::Hnsw => create_hnsw_index(
467            db,
468            &table_name_str,
469            &index_name,
470            &column_name,
471            &datatype,
472            *unique,
473            &existing_rowids_and_values,
474        ),
475    }
476}
477
478/// `USING <method>` choices recognized by `execute_create_index`. A
479/// missing USING clause defaults to `Btree` so existing CREATE INDEX
480/// statements (Phase 3e) keep working unchanged.
481#[derive(Debug, Clone, Copy)]
482enum IndexMethod {
483    Btree,
484    Hnsw,
485}
486
487/// Builds a Phase 3e B-Tree secondary index and attaches it to the table.
488fn create_btree_index(
489    db: &mut Database,
490    table_name: &str,
491    index_name: &str,
492    column_name: &str,
493    datatype: &DataType,
494    unique: bool,
495    existing: &[(i64, Value)],
496) -> Result<String> {
497    let mut idx = SecondaryIndex::new(
498        index_name.to_string(),
499        table_name.to_string(),
500        column_name.to_string(),
501        datatype,
502        unique,
503        IndexOrigin::Explicit,
504    )?;
505
506    // Populate from existing rows. UNIQUE violations here mean the
507    // existing data already breaks the new index's constraint — a
508    // common source of user confusion, so be explicit.
509    for (rowid, v) in existing {
510        if unique && idx.would_violate_unique(v) {
511            return Err(SQLRiteError::General(format!(
512                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
513                 already contains the duplicate value {}",
514                v.to_display_string()
515            )));
516        }
517        idx.insert(v, *rowid)?;
518    }
519
520    let table_mut = db.get_table_mut(table_name.to_string())?;
521    table_mut.secondary_indexes.push(idx);
522    Ok(index_name.to_string())
523}
524
525/// Builds a Phase 7d.2 HNSW index and attaches it to the table.
526fn create_hnsw_index(
527    db: &mut Database,
528    table_name: &str,
529    index_name: &str,
530    column_name: &str,
531    datatype: &DataType,
532    unique: bool,
533    existing: &[(i64, Value)],
534) -> Result<String> {
535    // HNSW only makes sense on VECTOR columns. Reject anything else
536    // with a clear message — this is the most likely user error.
537    let dim = match datatype {
538        DataType::Vector(d) => *d,
539        other => {
540            return Err(SQLRiteError::General(format!(
541                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
542            )));
543        }
544    };
545
546    if unique {
547        return Err(SQLRiteError::General(
548            "UNIQUE has no meaning for HNSW indexes".to_string(),
549        ));
550    }
551
552    // Build the in-memory graph. Distance metric is L2 by default
553    // (Phase 7d.2 doesn't yet expose a knob for picking cosine/dot —
554    // see `docs/phase-7-plan.md` for the deferral).
555    //
556    // Seed: hash the index name so different indexes get different
557    // graph topologies, but the same index always gets the same one
558    // — useful when debugging recall / index size.
559    let seed = hash_str_to_seed(index_name);
560    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
561
562    // Snapshot the (rowid, vector) pairs into a side map so the
563    // get_vec closure below can serve them by id without re-borrowing
564    // the table (we're already holding `existing` — flatten it).
565    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
566        std::collections::HashMap::with_capacity(existing.len());
567    for (rowid, v) in existing {
568        match v {
569            Value::Vector(vec) => {
570                if vec.len() != dim {
571                    return Err(SQLRiteError::Internal(format!(
572                        "row {rowid} stores a {}-dim vector in column '{column_name}' \
573                         declared as VECTOR({dim}) — schema invariant violated",
574                        vec.len()
575                    )));
576                }
577                vec_map.insert(*rowid, vec.clone());
578            }
579            // Non-vector values (theoretical NULL, type coercion bug)
580            // get skipped — they wouldn't have a sensible graph
581            // position anyway.
582            _ => continue,
583        }
584    }
585
586    for (rowid, _) in existing {
587        if let Some(v) = vec_map.get(rowid) {
588            let v_clone = v.clone();
589            idx.insert(*rowid, &v_clone, |id| {
590                vec_map.get(&id).cloned().unwrap_or_default()
591            });
592        }
593    }
594
595    let table_mut = db.get_table_mut(table_name.to_string())?;
596    table_mut.hnsw_indexes.push(HnswIndexEntry {
597        name: index_name.to_string(),
598        column_name: column_name.to_string(),
599        index: idx,
600        // Freshly built — no DELETE/UPDATE has invalidated it yet.
601        needs_rebuild: false,
602    });
603    Ok(index_name.to_string())
604}
605
606/// Stable, deterministic hash of a string into a u64 RNG seed. FNV-1a;
607/// avoids pulling in `std::hash::DefaultHasher` (which is randomized
608/// per process).
609fn hash_str_to_seed(s: &str) -> u64 {
610    let mut h: u64 = 0xCBF29CE484222325;
611    for b in s.as_bytes() {
612        h ^= *b as u64;
613        h = h.wrapping_mul(0x100000001B3);
614    }
615    h
616}
617
618/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
619/// because the enum has no ergonomic reason to be cloneable elsewhere.
620fn clone_datatype(dt: &DataType) -> DataType {
621    match dt {
622        DataType::Integer => DataType::Integer,
623        DataType::Text => DataType::Text,
624        DataType::Real => DataType::Real,
625        DataType::Bool => DataType::Bool,
626        DataType::Vector(dim) => DataType::Vector(*dim),
627        DataType::Json => DataType::Json,
628        DataType::None => DataType::None,
629        DataType::Invalid => DataType::Invalid,
630    }
631}
632
633fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
634    if tables.len() != 1 {
635        return Err(SQLRiteError::NotImplemented(
636            "multi-table DELETE is not supported yet".to_string(),
637        ));
638    }
639    extract_table_name(&tables[0])
640}
641
642fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
643    if !twj.joins.is_empty() {
644        return Err(SQLRiteError::NotImplemented(
645            "JOIN is not supported yet".to_string(),
646        ));
647    }
648    match &twj.relation {
649        TableFactor::Table { name, .. } => Ok(name.to_string()),
650        _ => Err(SQLRiteError::NotImplemented(
651            "only plain table references are supported".to_string(),
652        )),
653    }
654}
655
656/// Tells the executor how to produce its candidate rowid list.
657enum RowidSource {
658    /// The WHERE was simple enough to probe a secondary index directly.
659    /// The `Vec` already contains exactly the rows the index matched;
660    /// no further WHERE evaluation is needed (the probe is precise).
661    IndexProbe(Vec<i64>),
662    /// No applicable index; caller falls back to walking `table.rowids()`
663    /// and evaluating the WHERE on each row.
664    FullScan,
665}
666
667/// Try to satisfy `WHERE` with an index probe. Currently supports the
668/// simplest shape: a single `col = literal` (or `literal = col`) where
669/// `col` is on a secondary index. AND/OR/range predicates fall back to
670/// full scan — those can be layered on later without changing the caller.
671fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
672    let Some(expr) = selection else {
673        return Ok(RowidSource::FullScan);
674    };
675    let Some((col, literal)) = try_extract_equality(expr) else {
676        return Ok(RowidSource::FullScan);
677    };
678    let Some(idx) = table.index_for_column(&col) else {
679        return Ok(RowidSource::FullScan);
680    };
681
682    // Convert the literal into a runtime Value. If the literal type doesn't
683    // match the column's index we still need correct semantics — evaluate
684    // the WHERE against every row. Fall back to full scan.
685    let literal_value = match convert_literal(&literal) {
686        Ok(v) => v,
687        Err(_) => return Ok(RowidSource::FullScan),
688    };
689
690    // Index lookup returns the full list of rowids matching this equality
691    // predicate. For unique indexes that's at most one; for non-unique it
692    // can be many.
693    let mut rowids = idx.lookup(&literal_value);
694    rowids.sort_unstable();
695    Ok(RowidSource::IndexProbe(rowids))
696}
697
698/// Recognizes `expr` as a simple equality on a column reference against a
699/// literal. Returns `(column_name, literal_value)` if the shape matches;
700/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
701fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
702    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
703    let peeled = match expr {
704        Expr::Nested(inner) => inner.as_ref(),
705        other => other,
706    };
707    let Expr::BinaryOp { left, op, right } = peeled else {
708        return None;
709    };
710    if !matches!(op, BinaryOperator::Eq) {
711        return None;
712    }
713    let col_from = |e: &Expr| -> Option<String> {
714        match e {
715            Expr::Identifier(ident) => Some(ident.value.clone()),
716            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
717            _ => None,
718        }
719    };
720    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
721        if let Expr::Value(v) = e {
722            Some(v.value.clone())
723        } else {
724            None
725        }
726    };
727    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
728        return Some((c, l));
729    }
730    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
731        return Some((c, l));
732    }
733    None
734}
735
736/// Recognizes the HNSW-probable query pattern and probes the graph
737/// if a matching index exists.
738///
739/// Looks for ORDER BY `vec_distance_l2(<col>, <bracket-array literal>)`
740/// where the table has an HNSW index attached to `<col>`. On a match,
741/// returns the top-k rowids straight from the graph (O(log N)). On
742/// any miss — different function name, no matching index, query
743/// dimension wrong, etc. — returns `None` and the caller falls through
744/// to the bounded-heap brute-force path (7c) or the full sort (7b),
745/// preserving correct results regardless of whether the HNSW pathway
746/// kicked in.
747///
748/// Phase 7d.2 caveats:
749/// - Only `vec_distance_l2` is recognized. Cosine and dot fall through
750///   to brute-force because we don't yet expose a per-index distance
751///   knob (deferred to Phase 7d.x — see `docs/phase-7-plan.md`).
752/// - Only ASCENDING order makes sense for "k nearest" — DESC ORDER BY
753///   `vec_distance_l2(...) LIMIT k` would mean "k farthest", which
754///   isn't what the index is built for. We don't bother to detect
755///   `ascending == false` here; the optimizer just skips and the
756///   fallback path handles it correctly (slower).
757fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
758    if k == 0 {
759        return None;
760    }
761
762    // Pattern-match: order expr must be a function call vec_distance_l2(a, b).
763    let func = match order_expr {
764        Expr::Function(f) => f,
765        _ => return None,
766    };
767    let fname = match func.name.0.as_slice() {
768        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
769        _ => return None,
770    };
771    if fname != "vec_distance_l2" {
772        return None;
773    }
774
775    // Extract the two args as raw Exprs.
776    let arg_list = match &func.args {
777        FunctionArguments::List(l) => &l.args,
778        _ => return None,
779    };
780    if arg_list.len() != 2 {
781        return None;
782    }
783    let exprs: Vec<&Expr> = arg_list
784        .iter()
785        .filter_map(|a| match a {
786            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
787            _ => None,
788        })
789        .collect();
790    if exprs.len() != 2 {
791        return None;
792    }
793
794    // One arg must be a column reference (the indexed col); the other
795    // must be a bracket-array literal (the query vector). Try both
796    // orderings — pgvector's idiom puts the column on the left, but
797    // SQL is commutative for distance.
798    let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
799        Some(v) => v,
800        None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
801            Some(v) => v,
802            None => return None,
803        },
804    };
805
806    // Find the HNSW index on this column.
807    let entry = table
808        .hnsw_indexes
809        .iter()
810        .find(|e| e.column_name == col_name)?;
811
812    // Dimension sanity check — the query vector must match the
813    // indexed column's declared dimension. If it doesn't, the brute-
814    // force fallback would also error at the vec_distance_l2 dim-check;
815    // returning None here lets that path produce the user-visible
816    // error message.
817    let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
818        Some(c) => match &c.datatype {
819            DataType::Vector(d) => *d,
820            _ => return None,
821        },
822        None => return None,
823    };
824    if query_vec.len() != declared_dim {
825        return None;
826    }
827
828    // Probe the graph. Vectors are looked up from the table's row
829    // storage — a closure rather than a `&Table` so the algorithm
830    // module stays decoupled from the SQL types.
831    let column_for_closure = col_name.clone();
832    let table_ref = table;
833    let result = entry.index.search(&query_vec, k, |id| {
834        match table_ref.get_value(&column_for_closure, id) {
835            Some(Value::Vector(v)) => v,
836            _ => Vec::new(),
837        }
838    });
839    Some(result)
840}
841
842/// Helper for `try_hnsw_probe`: given two function args, identify which
843/// one is a bare column identifier (the indexed column) and which is a
844/// bracket-array literal (the query vector). Returns
845/// `Some((column_name, query_vec))` on a match, `None` otherwise.
846fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
847    let col_name = match a {
848        Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
849        _ => return None,
850    };
851    let lit_str = match b {
852        Expr::Identifier(ident) if ident.quote_style == Some('[') => {
853            format!("[{}]", ident.value)
854        }
855        _ => return None,
856    };
857    let v = parse_vector_literal(&lit_str).ok()?;
858    Some((col_name, v))
859}
860
861/// One entry in the bounded-heap top-k path. Holds a pre-evaluated
862/// sort key + the rowid it came from. The `asc` flag inverts `Ord`
863/// so a single `BinaryHeap<HeapEntry>` works for both ASC and DESC
864/// without wrapping in `std::cmp::Reverse` at the call site:
865///
866///   - ASC LIMIT k = "k smallest": natural Ord. Max-heap top is the
867///     largest currently kept; new items smaller than top displace.
868///   - DESC LIMIT k = "k largest": Ord reversed. Max-heap top is now
869///     the smallest currently kept (under reversed Ord, smallest
870///     looks largest); new items larger than top displace.
871///
872/// In both cases the displacement test reduces to "new entry < heap top".
873struct HeapEntry {
874    key: Value,
875    rowid: i64,
876    asc: bool,
877}
878
879impl PartialEq for HeapEntry {
880    fn eq(&self, other: &Self) -> bool {
881        self.cmp(other) == Ordering::Equal
882    }
883}
884
885impl Eq for HeapEntry {}
886
887impl PartialOrd for HeapEntry {
888    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
889        Some(self.cmp(other))
890    }
891}
892
893impl Ord for HeapEntry {
894    fn cmp(&self, other: &Self) -> Ordering {
895        let raw = compare_values(Some(&self.key), Some(&other.key));
896        if self.asc { raw } else { raw.reverse() }
897    }
898}
899
900/// Bounded-heap top-k selection. Returns at most `k` rowids in the
901/// caller's desired order (ascending key for `order.ascending`,
902/// descending otherwise).
903///
904/// O(N log k) where N = `matching.len()`. Caller must check
905/// `k < matching.len()` for this to be a win — for k ≥ N the
906/// `sort_rowids` full-sort path is the same asymptotic cost without
907/// the heap overhead.
908fn select_topk(
909    matching: &[i64],
910    table: &Table,
911    order: &OrderByClause,
912    k: usize,
913) -> Result<Vec<i64>> {
914    use std::collections::BinaryHeap;
915
916    if k == 0 || matching.is_empty() {
917        return Ok(Vec::new());
918    }
919
920    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
921
922    for &rowid in matching {
923        let key = eval_expr(&order.expr, table, rowid)?;
924        let entry = HeapEntry {
925            key,
926            rowid,
927            asc: order.ascending,
928        };
929
930        if heap.len() < k {
931            heap.push(entry);
932        } else {
933            // peek() returns the largest under our direction-aware Ord
934            // — the worst entry currently kept. Displace it iff the
935            // new entry is "better" (i.e. compares Less).
936            if entry < *heap.peek().unwrap() {
937                heap.pop();
938                heap.push(entry);
939            }
940        }
941    }
942
943    // `into_sorted_vec` returns ascending under our direction-aware Ord:
944    //   ASC: ascending by raw key (what we want)
945    //   DESC: ascending under reversed Ord = descending by raw key (what
946    //         we want for an ORDER BY DESC LIMIT k result)
947    Ok(heap
948        .into_sorted_vec()
949        .into_iter()
950        .map(|e| e.rowid)
951        .collect())
952}
953
954fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
955    // Phase 7b: ORDER BY now accepts any expression (column ref,
956    // arithmetic, function call, …). Pre-compute the sort key for
957    // every rowid up front so the comparator is called O(N log N)
958    // times against pre-evaluated Values rather than re-evaluating
959    // the expression O(N log N) times. Not strictly necessary today,
960    // but vital once 7d's HNSW index lands and this same code path
961    // could be running tens of millions of distance computations.
962    let mut keys: Vec<(i64, Result<Value>)> = rowids
963        .iter()
964        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
965        .collect();
966
967    // Surface the FIRST evaluation error if any. We could be lazy
968    // and let sort_by encounter it, but `Ord::cmp` can't return a
969    // Result and we'd have to swallow errors silently.
970    for (_, k) in &keys {
971        if let Err(e) = k {
972            return Err(SQLRiteError::General(format!(
973                "ORDER BY expression failed: {e}"
974            )));
975        }
976    }
977
978    keys.sort_by(|(_, ka), (_, kb)| {
979        // Both unwrap()s are safe — we just verified above that
980        // every key Result is Ok.
981        let va = ka.as_ref().unwrap();
982        let vb = kb.as_ref().unwrap();
983        let ord = compare_values(Some(va), Some(vb));
984        if order.ascending { ord } else { ord.reverse() }
985    });
986
987    // Write the sorted rowids back into the caller's slice.
988    for (i, (rowid, _)) in keys.into_iter().enumerate() {
989        rowids[i] = rowid;
990    }
991    Ok(())
992}
993
994fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
995    match (a, b) {
996        (None, None) => Ordering::Equal,
997        (None, _) => Ordering::Less,
998        (_, None) => Ordering::Greater,
999        (Some(a), Some(b)) => match (a, b) {
1000            (Value::Null, Value::Null) => Ordering::Equal,
1001            (Value::Null, _) => Ordering::Less,
1002            (_, Value::Null) => Ordering::Greater,
1003            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1004            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1005            (Value::Integer(x), Value::Real(y)) => {
1006                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1007            }
1008            (Value::Real(x), Value::Integer(y)) => {
1009                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1010            }
1011            (Value::Text(x), Value::Text(y)) => x.cmp(y),
1012            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1013            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
1014            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1015        },
1016    }
1017}
1018
1019/// Returns `true` if the row at `rowid` matches the predicate expression.
1020pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1021    let v = eval_expr(expr, table, rowid)?;
1022    match v {
1023        Value::Bool(b) => Ok(b),
1024        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
1025        Value::Integer(i) => Ok(i != 0),
1026        other => Err(SQLRiteError::Internal(format!(
1027            "WHERE clause must evaluate to boolean, got {}",
1028            other.to_display_string()
1029        ))),
1030    }
1031}
1032
1033fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1034    match expr {
1035        Expr::Nested(inner) => eval_expr(inner, table, rowid),
1036
1037        Expr::Identifier(ident) => {
1038            // Phase 7b — sqlparser parses bracket-array literals like
1039            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
1040            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
1041            // in expression-evaluation position (SELECT projection, WHERE,
1042            // ORDER BY, function args), parse the bracketed content as a
1043            // vector literal so the rest of the executor can compare /
1044            // distance-compute against it. Same trick the INSERT parser
1045            // uses; the executor needed its own copy because expression
1046            // eval runs on a different code path.
1047            if ident.quote_style == Some('[') {
1048                let raw = format!("[{}]", ident.value);
1049                let v = parse_vector_literal(&raw)?;
1050                return Ok(Value::Vector(v));
1051            }
1052            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1053        }
1054
1055        Expr::CompoundIdentifier(parts) => {
1056            // Accept `table.col` — we only have one table in scope, so ignore the qualifier.
1057            let col = parts
1058                .last()
1059                .map(|i| i.value.as_str())
1060                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1061            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1062        }
1063
1064        Expr::Value(v) => convert_literal(&v.value),
1065
1066        Expr::UnaryOp { op, expr } => {
1067            let inner = eval_expr(expr, table, rowid)?;
1068            match op {
1069                UnaryOperator::Not => match inner {
1070                    Value::Bool(b) => Ok(Value::Bool(!b)),
1071                    Value::Null => Ok(Value::Null),
1072                    other => Err(SQLRiteError::Internal(format!(
1073                        "NOT applied to non-boolean value: {}",
1074                        other.to_display_string()
1075                    ))),
1076                },
1077                UnaryOperator::Minus => match inner {
1078                    Value::Integer(i) => Ok(Value::Integer(-i)),
1079                    Value::Real(f) => Ok(Value::Real(-f)),
1080                    Value::Null => Ok(Value::Null),
1081                    other => Err(SQLRiteError::Internal(format!(
1082                        "unary minus on non-numeric value: {}",
1083                        other.to_display_string()
1084                    ))),
1085                },
1086                UnaryOperator::Plus => Ok(inner),
1087                other => Err(SQLRiteError::NotImplemented(format!(
1088                    "unary operator {other:?} is not supported"
1089                ))),
1090            }
1091        }
1092
1093        Expr::BinaryOp { left, op, right } => match op {
1094            BinaryOperator::And => {
1095                let l = eval_expr(left, table, rowid)?;
1096                let r = eval_expr(right, table, rowid)?;
1097                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1098            }
1099            BinaryOperator::Or => {
1100                let l = eval_expr(left, table, rowid)?;
1101                let r = eval_expr(right, table, rowid)?;
1102                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1103            }
1104            cmp @ (BinaryOperator::Eq
1105            | BinaryOperator::NotEq
1106            | BinaryOperator::Lt
1107            | BinaryOperator::LtEq
1108            | BinaryOperator::Gt
1109            | BinaryOperator::GtEq) => {
1110                let l = eval_expr(left, table, rowid)?;
1111                let r = eval_expr(right, table, rowid)?;
1112                // Any comparison involving NULL is unknown → false in a WHERE.
1113                if matches!(l, Value::Null) || matches!(r, Value::Null) {
1114                    return Ok(Value::Bool(false));
1115                }
1116                let ord = compare_values(Some(&l), Some(&r));
1117                let result = match cmp {
1118                    BinaryOperator::Eq => ord == Ordering::Equal,
1119                    BinaryOperator::NotEq => ord != Ordering::Equal,
1120                    BinaryOperator::Lt => ord == Ordering::Less,
1121                    BinaryOperator::LtEq => ord != Ordering::Greater,
1122                    BinaryOperator::Gt => ord == Ordering::Greater,
1123                    BinaryOperator::GtEq => ord != Ordering::Less,
1124                    _ => unreachable!(),
1125                };
1126                Ok(Value::Bool(result))
1127            }
1128            arith @ (BinaryOperator::Plus
1129            | BinaryOperator::Minus
1130            | BinaryOperator::Multiply
1131            | BinaryOperator::Divide
1132            | BinaryOperator::Modulo) => {
1133                let l = eval_expr(left, table, rowid)?;
1134                let r = eval_expr(right, table, rowid)?;
1135                eval_arith(arith, &l, &r)
1136            }
1137            BinaryOperator::StringConcat => {
1138                let l = eval_expr(left, table, rowid)?;
1139                let r = eval_expr(right, table, rowid)?;
1140                if matches!(l, Value::Null) || matches!(r, Value::Null) {
1141                    return Ok(Value::Null);
1142                }
1143                Ok(Value::Text(format!(
1144                    "{}{}",
1145                    l.to_display_string(),
1146                    r.to_display_string()
1147                )))
1148            }
1149            other => Err(SQLRiteError::NotImplemented(format!(
1150                "binary operator {other:?} is not supported yet"
1151            ))),
1152        },
1153
1154        // Phase 7b — function-call dispatch. Currently only the three
1155        // vector-distance functions; this match arm becomes the single
1156        // place to register more SQL functions later (e.g. abs(),
1157        // length(), …) without re-touching the rest of the executor.
1158        //
1159        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
1160        // of three don't parse natively in sqlparser (we'd need a
1161        // string-preprocessing pass or a sqlparser fork). Deferred to
1162        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
1163        // corrections" note.
1164        Expr::Function(func) => eval_function(func, table, rowid),
1165
1166        other => Err(SQLRiteError::NotImplemented(format!(
1167            "unsupported expression in WHERE/projection: {other:?}"
1168        ))),
1169    }
1170}
1171
1172/// Dispatches an `Expr::Function` to its built-in implementation.
1173/// Currently only the three vec_distance_* functions; other functions
1174/// surface as `NotImplemented` errors with the function name in the
1175/// message so users see what they tried.
1176fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1177    // Function name lives in `name.0[0]` for unqualified calls. Anything
1178    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
1179    let name = match func.name.0.as_slice() {
1180        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1181        _ => {
1182            return Err(SQLRiteError::NotImplemented(format!(
1183                "qualified function names not supported: {:?}",
1184                func.name
1185            )));
1186        }
1187    };
1188
1189    match name.as_str() {
1190        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1191            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1192            let dist = match name.as_str() {
1193                "vec_distance_l2" => vec_distance_l2(&a, &b),
1194                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1195                "vec_distance_dot" => vec_distance_dot(&a, &b),
1196                _ => unreachable!(),
1197            };
1198            // Widen f32 → f64 for the runtime Value. Vectors are stored
1199            // as f32 (consistent with industry convention for embeddings),
1200            // but the executor's numeric type is f64 so distances slot
1201            // into Value::Real cleanly and can be compared / ordered with
1202            // other reals via the existing arithmetic + comparison paths.
1203            Ok(Value::Real(dist as f64))
1204        }
1205        // Phase 7e — JSON functions. All four parse the JSON text on
1206        // demand (we don't cache parsed values), then resolve a path
1207        // (default `$` = root). The path resolver handles `.key` for
1208        // object access and `[N]` for array index. SQLite-style.
1209        "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
1210        "json_type" => json_fn_type(&name, &func.args, table, rowid),
1211        "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
1212        "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
1213        other => Err(SQLRiteError::NotImplemented(format!(
1214            "unknown function: {other}(...)"
1215        ))),
1216    }
1217}
1218
1219// -----------------------------------------------------------------
1220// Phase 7e — JSON path-extraction functions
1221// -----------------------------------------------------------------
1222
1223/// Extracts the JSON-typed text + optional path string out of a
1224/// function call's args. Used by all four json_* functions.
1225///
1226/// Arity rules (matching SQLite JSON1):
1227///   - 1 arg  → JSON value, path defaults to `$` (root)
1228///   - 2 args → (JSON value, path text)
1229///
1230/// Returns `(json_text, path)` so caller can serde_json::from_str
1231/// + walk_json_path on it.
1232fn extract_json_and_path(
1233    fn_name: &str,
1234    args: &FunctionArguments,
1235    table: &Table,
1236    rowid: i64,
1237) -> Result<(String, String)> {
1238    let arg_list = match args {
1239        FunctionArguments::List(l) => &l.args,
1240        _ => {
1241            return Err(SQLRiteError::General(format!(
1242                "{fn_name}() expects 1 or 2 arguments"
1243            )));
1244        }
1245    };
1246    if !(arg_list.len() == 1 || arg_list.len() == 2) {
1247        return Err(SQLRiteError::General(format!(
1248            "{fn_name}() expects 1 or 2 arguments, got {}",
1249            arg_list.len()
1250        )));
1251    }
1252    // Evaluate first arg → must produce text.
1253    let first_expr = match &arg_list[0] {
1254        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1255        other => {
1256            return Err(SQLRiteError::NotImplemented(format!(
1257                "{fn_name}() argument 0 has unsupported shape: {other:?}"
1258            )));
1259        }
1260    };
1261    let json_text = match eval_expr(first_expr, table, rowid)? {
1262        Value::Text(s) => s,
1263        Value::Null => {
1264            return Err(SQLRiteError::General(format!(
1265                "{fn_name}() called on NULL — JSON column has no value for this row"
1266            )));
1267        }
1268        other => {
1269            return Err(SQLRiteError::General(format!(
1270                "{fn_name}() argument 0 is not JSON-typed: got {}",
1271                other.to_display_string()
1272            )));
1273        }
1274    };
1275
1276    // Path defaults to root `$` when omitted.
1277    let path = if arg_list.len() == 2 {
1278        let path_expr = match &arg_list[1] {
1279            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1280            other => {
1281                return Err(SQLRiteError::NotImplemented(format!(
1282                    "{fn_name}() argument 1 has unsupported shape: {other:?}"
1283                )));
1284            }
1285        };
1286        match eval_expr(path_expr, table, rowid)? {
1287            Value::Text(s) => s,
1288            other => {
1289                return Err(SQLRiteError::General(format!(
1290                    "{fn_name}() path argument must be a string literal, got {}",
1291                    other.to_display_string()
1292                )));
1293            }
1294        }
1295    } else {
1296        "$".to_string()
1297    };
1298
1299    Ok((json_text, path))
1300}
1301
1302/// Walks a `serde_json::Value` along a JSONPath subset:
1303///   - `$` is the root
1304///   - `.key` for object access (key may not contain `.` or `[`)
1305///   - `[N]` for array index (N a non-negative integer)
1306///   - chains arbitrarily: `$.foo.bar[0].baz`
1307///
1308/// Returns `Ok(None)` for "path didn't match anything" (NULL in SQL),
1309/// `Err` for malformed paths. Matches SQLite JSON1's semantic
1310/// distinction: missing-key = NULL, malformed-path = error.
1311fn walk_json_path<'a>(
1312    value: &'a serde_json::Value,
1313    path: &str,
1314) -> Result<Option<&'a serde_json::Value>> {
1315    let mut chars = path.chars().peekable();
1316    if chars.next() != Some('$') {
1317        return Err(SQLRiteError::General(format!(
1318            "JSON path must start with '$', got `{path}`"
1319        )));
1320    }
1321    let mut current = value;
1322    while let Some(&c) = chars.peek() {
1323        match c {
1324            '.' => {
1325                chars.next();
1326                let mut key = String::new();
1327                while let Some(&c) = chars.peek() {
1328                    if c == '.' || c == '[' {
1329                        break;
1330                    }
1331                    key.push(c);
1332                    chars.next();
1333                }
1334                if key.is_empty() {
1335                    return Err(SQLRiteError::General(format!(
1336                        "JSON path has empty key after '.' in `{path}`"
1337                    )));
1338                }
1339                match current.get(&key) {
1340                    Some(v) => current = v,
1341                    None => return Ok(None),
1342                }
1343            }
1344            '[' => {
1345                chars.next();
1346                let mut idx_str = String::new();
1347                while let Some(&c) = chars.peek() {
1348                    if c == ']' {
1349                        break;
1350                    }
1351                    idx_str.push(c);
1352                    chars.next();
1353                }
1354                if chars.next() != Some(']') {
1355                    return Err(SQLRiteError::General(format!(
1356                        "JSON path has unclosed `[` in `{path}`"
1357                    )));
1358                }
1359                let idx: usize = idx_str.trim().parse().map_err(|_| {
1360                    SQLRiteError::General(format!(
1361                        "JSON path has non-integer index `[{idx_str}]` in `{path}`"
1362                    ))
1363                })?;
1364                match current.get(idx) {
1365                    Some(v) => current = v,
1366                    None => return Ok(None),
1367                }
1368            }
1369            other => {
1370                return Err(SQLRiteError::General(format!(
1371                    "JSON path has unexpected character `{other}` in `{path}` \
1372                     (expected `.`, `[`, or end-of-path)"
1373                )));
1374            }
1375        }
1376    }
1377    Ok(Some(current))
1378}
1379
1380/// Converts a serde_json scalar to a SQLRite Value. For composite
1381/// types (object, array) returns the JSON-encoded text — callers
1382/// pattern-match on shape from the calling json_* function.
1383fn json_value_to_sql(v: &serde_json::Value) -> Value {
1384    match v {
1385        serde_json::Value::Null => Value::Null,
1386        serde_json::Value::Bool(b) => Value::Bool(*b),
1387        serde_json::Value::Number(n) => {
1388            // Match SQLite: integer if it fits an i64, else f64.
1389            if let Some(i) = n.as_i64() {
1390                Value::Integer(i)
1391            } else if let Some(f) = n.as_f64() {
1392                Value::Real(f)
1393            } else {
1394                Value::Null
1395            }
1396        }
1397        serde_json::Value::String(s) => Value::Text(s.clone()),
1398        // Objects + arrays come out as JSON-encoded text. Same as
1399        // SQLite's json_extract: composite results round-trip through
1400        // text rather than being modeled as a richer Value type.
1401        composite => Value::Text(composite.to_string()),
1402    }
1403}
1404
1405fn json_fn_extract(
1406    name: &str,
1407    args: &FunctionArguments,
1408    table: &Table,
1409    rowid: i64,
1410) -> Result<Value> {
1411    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1412    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1413        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1414    })?;
1415    match walk_json_path(&parsed, &path)? {
1416        Some(v) => Ok(json_value_to_sql(v)),
1417        None => Ok(Value::Null),
1418    }
1419}
1420
1421fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
1422    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1423    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1424        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1425    })?;
1426    let resolved = match walk_json_path(&parsed, &path)? {
1427        Some(v) => v,
1428        None => return Ok(Value::Null),
1429    };
1430    let ty = match resolved {
1431        serde_json::Value::Null => "null",
1432        serde_json::Value::Bool(true) => "true",
1433        serde_json::Value::Bool(false) => "false",
1434        serde_json::Value::Number(n) => {
1435            if n.is_i64() || n.is_u64() {
1436                "integer"
1437            } else {
1438                "real"
1439            }
1440        }
1441        serde_json::Value::String(_) => "text",
1442        serde_json::Value::Array(_) => "array",
1443        serde_json::Value::Object(_) => "object",
1444    };
1445    Ok(Value::Text(ty.to_string()))
1446}
1447
1448fn json_fn_array_length(
1449    name: &str,
1450    args: &FunctionArguments,
1451    table: &Table,
1452    rowid: i64,
1453) -> Result<Value> {
1454    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1455    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1456        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1457    })?;
1458    let resolved = match walk_json_path(&parsed, &path)? {
1459        Some(v) => v,
1460        None => return Ok(Value::Null),
1461    };
1462    match resolved.as_array() {
1463        Some(arr) => Ok(Value::Integer(arr.len() as i64)),
1464        None => Err(SQLRiteError::General(format!(
1465            "{name}() resolved to a non-array value at path `{path}`"
1466        ))),
1467    }
1468}
1469
1470fn json_fn_object_keys(
1471    name: &str,
1472    args: &FunctionArguments,
1473    table: &Table,
1474    rowid: i64,
1475) -> Result<Value> {
1476    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1477    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1478        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1479    })?;
1480    let resolved = match walk_json_path(&parsed, &path)? {
1481        Some(v) => v,
1482        None => return Ok(Value::Null),
1483    };
1484    let obj = resolved.as_object().ok_or_else(|| {
1485        SQLRiteError::General(format!(
1486            "{name}() resolved to a non-object value at path `{path}`"
1487        ))
1488    })?;
1489    // SQLite's json_object_keys is a table-valued function (one row
1490    // per key). Without set-returning function support we can't
1491    // reproduce that shape; instead return the keys as a JSON array
1492    // text. Caller can iterate via json_array_length + json_extract,
1493    // or just treat it as a serialized list. Document this divergence
1494    // in supported-sql.md.
1495    let keys: Vec<serde_json::Value> = obj
1496        .keys()
1497        .map(|k| serde_json::Value::String(k.clone()))
1498        .collect();
1499    Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
1500}
1501
1502/// Extracts exactly two `Vec<f32>` arguments from a function call,
1503/// validating arity and that both sides are Vector-typed with matching
1504/// dimensions. Used by all three vec_distance_* functions.
1505fn extract_two_vector_args(
1506    fn_name: &str,
1507    args: &FunctionArguments,
1508    table: &Table,
1509    rowid: i64,
1510) -> Result<(Vec<f32>, Vec<f32>)> {
1511    let arg_list = match args {
1512        FunctionArguments::List(l) => &l.args,
1513        _ => {
1514            return Err(SQLRiteError::General(format!(
1515                "{fn_name}() expects exactly two vector arguments"
1516            )));
1517        }
1518    };
1519    if arg_list.len() != 2 {
1520        return Err(SQLRiteError::General(format!(
1521            "{fn_name}() expects exactly 2 arguments, got {}",
1522            arg_list.len()
1523        )));
1524    }
1525    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
1526    for (i, arg) in arg_list.iter().enumerate() {
1527        let expr = match arg {
1528            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1529            other => {
1530                return Err(SQLRiteError::NotImplemented(format!(
1531                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
1532                )));
1533            }
1534        };
1535        let val = eval_expr(expr, table, rowid)?;
1536        match val {
1537            Value::Vector(v) => out.push(v),
1538            other => {
1539                return Err(SQLRiteError::General(format!(
1540                    "{fn_name}() argument {i} is not a vector: got {}",
1541                    other.to_display_string()
1542                )));
1543            }
1544        }
1545    }
1546    let b = out.pop().unwrap();
1547    let a = out.pop().unwrap();
1548    if a.len() != b.len() {
1549        return Err(SQLRiteError::General(format!(
1550            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
1551            a.len(),
1552            b.len()
1553        )));
1554    }
1555    Ok((a, b))
1556}
1557
/// Euclidean (L2) distance: √Σ(aᵢ − bᵢ)².
/// Smaller-is-closer; identical vectors return 0.0, and two empty slices
/// return +0.0.
///
/// Callers must pass equal-length slices (checked only in debug builds).
/// In release builds, any trailing elements of the longer slice are ignored.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Explicit left fold from +0.0 keeps the exact summation order of a
    // plain loop (bit-identical results) and avoids per-index bounds checks.
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| {
            let d = x - y;
            acc + d * d
        })
        .sqrt()
}
1569
1570/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
1571/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
1572/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
1573///
1574/// Errors if either vector has zero magnitude — cosine similarity is
1575/// undefined for the zero vector and silently returning NaN would
1576/// poison `ORDER BY` ranking. Callers who want the silent-NaN
1577/// behavior can compute `vec_distance_dot(a, b) / (norm(a) * norm(b))`
1578/// themselves.
1579pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
1580    debug_assert_eq!(a.len(), b.len());
1581    let mut dot = 0.0f32;
1582    let mut norm_a_sq = 0.0f32;
1583    let mut norm_b_sq = 0.0f32;
1584    for i in 0..a.len() {
1585        dot += a[i] * b[i];
1586        norm_a_sq += a[i] * a[i];
1587        norm_b_sq += b[i] * b[i];
1588    }
1589    let denom = (norm_a_sq * norm_b_sq).sqrt();
1590    if denom == 0.0 {
1591        return Err(SQLRiteError::General(
1592            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
1593        ));
1594    }
1595    Ok(1.0 - dot / denom)
1596}
1597
/// Negated dot product: −(a·b).
/// pgvector convention — negated so smaller-is-closer like L2 / cosine.
/// For unit-norm vectors `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`.
///
/// Callers must pass equal-length slices (checked only in debug builds).
/// In release builds, any trailing elements of the longer slice are ignored.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Left fold from +0.0: same accumulation order as an index loop, with
    // no per-element bounds checks.
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
1609
1610/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
1611/// Mixed Integer/Real promotes to Real. Divide/Modulo by zero → error.
1612fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
1613    if matches!(l, Value::Null) || matches!(r, Value::Null) {
1614        return Ok(Value::Null);
1615    }
1616    match (l, r) {
1617        (Value::Integer(a), Value::Integer(b)) => match op {
1618            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
1619            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
1620            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
1621            BinaryOperator::Divide => {
1622                if *b == 0 {
1623                    Err(SQLRiteError::General("division by zero".to_string()))
1624                } else {
1625                    Ok(Value::Integer(a / b))
1626                }
1627            }
1628            BinaryOperator::Modulo => {
1629                if *b == 0 {
1630                    Err(SQLRiteError::General("modulo by zero".to_string()))
1631                } else {
1632                    Ok(Value::Integer(a % b))
1633                }
1634            }
1635            _ => unreachable!(),
1636        },
1637        // Anything involving a Real promotes both sides to f64.
1638        (a, b) => {
1639            let af = as_number(a)?;
1640            let bf = as_number(b)?;
1641            match op {
1642                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1643                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1644                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1645                BinaryOperator::Divide => {
1646                    if bf == 0.0 {
1647                        Err(SQLRiteError::General("division by zero".to_string()))
1648                    } else {
1649                        Ok(Value::Real(af / bf))
1650                    }
1651                }
1652                BinaryOperator::Modulo => {
1653                    if bf == 0.0 {
1654                        Err(SQLRiteError::General("modulo by zero".to_string()))
1655                    } else {
1656                        Ok(Value::Real(af % bf))
1657                    }
1658                }
1659                _ => unreachable!(),
1660            }
1661        }
1662    }
1663}
1664
1665fn as_number(v: &Value) -> Result<f64> {
1666    match v {
1667        Value::Integer(i) => Ok(*i as f64),
1668        Value::Real(f) => Ok(*f),
1669        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1670        other => Err(SQLRiteError::General(format!(
1671            "arithmetic on non-numeric value '{}'",
1672            other.to_display_string()
1673        ))),
1674    }
1675}
1676
1677fn as_bool(v: &Value) -> Result<bool> {
1678    match v {
1679        Value::Bool(b) => Ok(*b),
1680        Value::Null => Ok(false),
1681        Value::Integer(i) => Ok(*i != 0),
1682        other => Err(SQLRiteError::Internal(format!(
1683            "expected boolean, got {}",
1684            other.to_display_string()
1685        ))),
1686    }
1687}
1688
1689fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1690    use sqlparser::ast::Value as AstValue;
1691    match v {
1692        AstValue::Number(n, _) => {
1693            if let Ok(i) = n.parse::<i64>() {
1694                Ok(Value::Integer(i))
1695            } else if let Ok(f) = n.parse::<f64>() {
1696                Ok(Value::Real(f))
1697            } else {
1698                Err(SQLRiteError::Internal(format!(
1699                    "could not parse numeric literal '{n}'"
1700                )))
1701            }
1702        }
1703        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1704        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1705        AstValue::Null => Ok(Value::Null),
1706        other => Err(SQLRiteError::NotImplemented(format!(
1707            "unsupported literal value: {other:?}"
1708        ))),
1709    }
1710}
1711
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------
    // Phase 7b — Vector distance function math
    // -----------------------------------------------------------------

    /// Float comparison helper. Distance results need a small epsilon
    /// because sums are accumulated across many f32 multiplies.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        // [1, 0] vs [0, 1]: distance = √((1-0)² + (0-1)²) = √2 ≈ 1.414
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // [0, 0, 0] vs [3, 4, 0]: √(9 + 16 + 0) = 5 (the classic 3-4-5 triangle).
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        // Two orthogonal unit vectors should have cosine distance = 1.0
        // (cosine similarity = 0 → distance = 1 - 0 = 1).
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        // a and -a have cosine similarity = -1 → distance = 1 - (-1) = 2.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // Cosine is undefined for the zero vector — error rather than NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // a·b = 1*4 + 2*5 + 3*6 = 32. Negated → -32.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        // Orthogonal vectors have dot product 0 → negated is also 0.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // For unit-norm vectors: dot(a,b) = cos(a,b)
        // → -dot(a,b) = -cos(a,b) = (1 - cos(a,b)) - 1 = vec_distance_cosine(a,b) - 1.
        // Useful sanity check that the two functions agree on unit vectors.
        let a = vec![0.6f32, 0.8]; // unit norm: √(0.36+0.64) = 1
        let b = vec![0.8f32, 0.6]; // unit norm too
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // -----------------------------------------------------------------
    // Phase 7c — bounded-heap top-k correctness + benchmark
    // -----------------------------------------------------------------

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Builds a `docs(id INTEGER PK, score REAL)` table with N rows of
    /// distinct non-negative scores, so top-k tests aren't sensitive to
    /// tie-breaking. The heap is unstable and the full sort is stable; with
    /// distinct scores both must agree without any question of equal-score
    /// row order.
    ///
    /// **Why no negative scores:** the INSERT parser doesn't currently
    /// handle `Expr::UnaryOp(Minus, …)` for negative number literals.
    /// `-3.14` parses as a unary expression, and the value extractor
    /// skips it. That's a pre-existing bug, out of scope for 7c. The Knuth
    /// multiplicative hash gives distinct, scrambled, non-negative values
    /// without running into the negative-literal limitation.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            // Knuth multiplicative hash mod 1_000_000. The multiplier is odd
            // and not a multiple of 5, so it is coprime to 10^6, and
            // i ↦ i·c mod 10^6 is a bijection on [0, 10^6). Scores are
            // therefore distinct for every n ≤ 1_000_000. In that range i·c
            // stays far below u64::MAX, so wrapping_mul never wraps.
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Helper: parses an SQL SELECT into a SelectQuery so tests can drive
    /// `select_topk` / `sort_rowids` / `execute_select_rows` directly,
    /// without the rest of the process_command pipeline.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        // Build N=200, top-k=10. Bounded heap output must equal
        // full-sort-then-truncate output (both produce ASC order).
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Full-sort path
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        // Bounded-heap path
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        // Same with DESC — verifies the direction-aware Ord wrapper.
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        // The executor branches off to the full-sort path when k >= N.
        // A caller that invokes select_topk directly with k > N should
        // still get all rows, fully sorted: there are fewer than k items,
        // so nothing is truncated.
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        // All scores in ascending order.
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // Integration check: a KNN-shaped query through the executor
        // (which the executor is expected to route to the bounded-heap path)
        // must produce the correct top-k, in order.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        // Five rows with distinct distances from probe [1.0, 0.0]:
        //   id=1 [1.0, 0.0]   distance=0
        //   id=2 [2.0, 0.0]   distance=1
        //   id=3 [0.0, 3.0]   distance=√(1+9) = √10 ≈ 3.16
        //   id=4 [1.0, 4.0]   distance=4
        //   id=5 [10.0, 10.0] distance=√(81+100) ≈ 13.45
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let knn_sql = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;";
        let resp = crate::sql::process_command(knn_sql, &mut db).unwrap();
        // The status message tells us how many rows came back.
        assert!(resp.contains("3 rows returned"), "got: {resp}");

        // The row count alone doesn't prove the ranking. Run the same query
        // through the structured executor and check the ids, in order:
        // the top-3 closest to [1.0, 0.0] are id=1, id=2, id=3.
        let result = execute_select_rows(parse_select(knn_sql), &db).unwrap();
        let ids: Vec<i64> = result
            .rows
            .iter()
            .map(|row| match row.first() {
                Some(Value::Integer(id)) => *id,
                _ => panic!("expected an INTEGER id in the first column"),
            })
            .collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    /// Manual benchmark — not run by default. Recommended invocation:
    ///
    ///     cargo test -p sqlrite-engine --lib topk_benchmark --release \
    ///         -- --ignored --nocapture
    ///
    /// (`--release` matters: Rust's sort gets very fast under
    /// optimization, so the heap's relative advantage is best measured
    /// against a sort that has also been optimized.)
    ///
    /// Measured numbers on an Apple Silicon laptop with N=10_000 + k=10:
    ///   - bounded heap:    ~820µs
    ///   - full sort+trunc: ~1.5ms
    ///   - ratio:           ~1.8×
    ///
    /// The advantage is real but moderate at this size. The sort key here
    /// is a single REAL column read (cheap), and Rust's sort_by has a very
    /// low constant factor. The asymptotic O(N log k) vs O(N log N)
    /// advantage grows with N and with per-row work. KNN queries whose
    /// sort key is `vec_distance_l2(col, [...])` are where this path really
    /// pays off: each key evaluation is itself O(dim), and the heap path
    /// skips the per-row evaluation in the comparator (see `sort_rowids`
    /// for the contrast).
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Time bounded heap.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Time full sort + truncate.
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:   {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:  {ratio:.2}×");

        // Soft assertion. The measured ratio for this cheap-key benchmark
        // is around 1.8×. A floor set too close to that would make CI flaky
        // on slower runners, so it is 1.4×. That still catches a real
        // regression, e.g. select_topk becoming O(N²) or no longer using
        // the heap.
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}