1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9 FunctionArgExpr, FunctionArguments, ObjectNamePart, Statement, TableFactor, TableWithJoins,
10 UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, Table, Value, parse_vector_literal};
17use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
18
/// Materialized result of a SELECT: the projected column names plus the
/// values of every returned row.
pub struct SelectResult {
    /// Names of the projected columns, in output order.
    pub columns: Vec<String>,
    /// One entry per returned row; each row holds one value per entry of
    /// `columns`, in the same order.
    pub rows: Vec<Vec<Value>>,
}
31
32pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
36 let table = db
37 .get_table(query.table_name.clone())
38 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
39
40 let projected_cols: Vec<String> = match &query.projection {
42 Projection::All => table.column_names(),
43 Projection::Columns(cols) => {
44 for c in cols {
45 if !table.contains_column(c.to_string()) {
46 return Err(SQLRiteError::Internal(format!(
47 "Column '{c}' does not exist on table '{}'",
48 query.table_name
49 )));
50 }
51 }
52 cols.clone()
53 }
54 };
55
56 let matching = match select_rowids(table, query.selection.as_ref())? {
60 RowidSource::IndexProbe(rowids) => rowids,
61 RowidSource::FullScan => {
62 let mut out = Vec::new();
63 for rowid in table.rowids() {
64 if let Some(expr) = &query.selection {
65 if !eval_predicate(expr, table, rowid)? {
66 continue;
67 }
68 }
69 out.push(rowid);
70 }
71 out
72 }
73 };
74 let mut matching = matching;
75
76 if let Some(order) = &query.order_by {
78 sort_rowids(&mut matching, table, order)?;
79 }
80
81 if let Some(n) = query.limit {
82 matching.truncate(n);
83 }
84
85 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
89 for rowid in &matching {
90 let row: Vec<Value> = projected_cols
91 .iter()
92 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
93 .collect();
94 rows.push(row);
95 }
96
97 Ok(SelectResult {
98 columns: projected_cols,
99 rows,
100 })
101}
102
103pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
108 let result = execute_select_rows(query, db)?;
109 let row_count = result.rows.len();
110
111 let mut print_table = PrintTable::new();
112 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
113 print_table.add_row(PrintRow::new(header_cells));
114
115 for row in &result.rows {
116 let cells: Vec<PrintCell> = row
117 .iter()
118 .map(|v| PrintCell::new(&v.to_display_string()))
119 .collect();
120 print_table.add_row(PrintRow::new(cells));
121 }
122
123 Ok((print_table.to_string(), row_count))
124}
125
126pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
128 let Statement::Delete(Delete {
129 from, selection, ..
130 }) = stmt
131 else {
132 return Err(SQLRiteError::Internal(
133 "execute_delete called on a non-DELETE statement".to_string(),
134 ));
135 };
136
137 let tables = match from {
138 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
139 };
140 let table_name = extract_single_table_name(tables)?;
141
142 let matching: Vec<i64> = {
144 let table = db
145 .get_table(table_name.clone())
146 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
147 match select_rowids(table, selection.as_ref())? {
148 RowidSource::IndexProbe(rowids) => rowids,
149 RowidSource::FullScan => {
150 let mut out = Vec::new();
151 for rowid in table.rowids() {
152 if let Some(expr) = selection {
153 if !eval_predicate(expr, table, rowid)? {
154 continue;
155 }
156 }
157 out.push(rowid);
158 }
159 out
160 }
161 }
162 };
163
164 let table = db.get_table_mut(table_name)?;
165 for rowid in &matching {
166 table.delete_row(*rowid);
167 }
168 Ok(matching.len())
169}
170
171pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
173 let Statement::Update(Update {
174 table,
175 assignments,
176 from,
177 selection,
178 ..
179 }) = stmt
180 else {
181 return Err(SQLRiteError::Internal(
182 "execute_update called on a non-UPDATE statement".to_string(),
183 ));
184 };
185
186 if from.is_some() {
187 return Err(SQLRiteError::NotImplemented(
188 "UPDATE ... FROM is not supported yet".to_string(),
189 ));
190 }
191
192 let table_name = extract_table_name(table)?;
193
194 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
196 {
197 let tbl = db
198 .get_table(table_name.clone())
199 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
200 for a in assignments {
201 let col = match &a.target {
202 AssignmentTarget::ColumnName(name) => name
203 .0
204 .last()
205 .map(|p| p.to_string())
206 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
207 AssignmentTarget::Tuple(_) => {
208 return Err(SQLRiteError::NotImplemented(
209 "tuple assignment targets are not supported".to_string(),
210 ));
211 }
212 };
213 if !tbl.contains_column(col.clone()) {
214 return Err(SQLRiteError::Internal(format!(
215 "UPDATE references unknown column '{col}'"
216 )));
217 }
218 parsed_assignments.push((col, a.value.clone()));
219 }
220 }
221
222 let work: Vec<(i64, Vec<(String, Value)>)> = {
226 let tbl = db.get_table(table_name.clone())?;
227 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
228 RowidSource::IndexProbe(rowids) => rowids,
229 RowidSource::FullScan => {
230 let mut out = Vec::new();
231 for rowid in tbl.rowids() {
232 if let Some(expr) = selection {
233 if !eval_predicate(expr, tbl, rowid)? {
234 continue;
235 }
236 }
237 out.push(rowid);
238 }
239 out
240 }
241 };
242 let mut rows_to_update = Vec::new();
243 for rowid in matched_rowids {
244 let mut values = Vec::with_capacity(parsed_assignments.len());
245 for (col, expr) in &parsed_assignments {
246 let v = eval_expr(expr, tbl, rowid)?;
249 values.push((col.clone(), v));
250 }
251 rows_to_update.push((rowid, values));
252 }
253 rows_to_update
254 };
255
256 let tbl = db.get_table_mut(table_name)?;
257 for (rowid, values) in &work {
258 for (col, v) in values {
259 tbl.set_value(col, *rowid, v.clone())?;
260 }
261 }
262 Ok(work.len())
263}
264
265pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
269 let Statement::CreateIndex(CreateIndex {
270 name,
271 table_name,
272 columns,
273 unique,
274 if_not_exists,
275 predicate,
276 ..
277 }) = stmt
278 else {
279 return Err(SQLRiteError::Internal(
280 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
281 ));
282 };
283
284 if predicate.is_some() {
285 return Err(SQLRiteError::NotImplemented(
286 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
287 ));
288 }
289
290 if columns.len() != 1 {
291 return Err(SQLRiteError::NotImplemented(format!(
292 "multi-column indexes are not supported yet ({} columns given)",
293 columns.len()
294 )));
295 }
296
297 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
298 SQLRiteError::NotImplemented(
299 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
300 )
301 })?;
302
303 let table_name_str = table_name.to_string();
304 let column_name = match &columns[0].column.expr {
305 Expr::Identifier(ident) => ident.value.clone(),
306 Expr::CompoundIdentifier(parts) => parts
307 .last()
308 .map(|p| p.value.clone())
309 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
310 other => {
311 return Err(SQLRiteError::NotImplemented(format!(
312 "CREATE INDEX only supports simple column references, got {other:?}"
313 )));
314 }
315 };
316
317 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
319 let table = db.get_table(table_name_str.clone()).map_err(|_| {
320 SQLRiteError::General(format!(
321 "CREATE INDEX references unknown table '{table_name_str}'"
322 ))
323 })?;
324 if !table.contains_column(column_name.clone()) {
325 return Err(SQLRiteError::General(format!(
326 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
327 )));
328 }
329 let col = table
330 .columns
331 .iter()
332 .find(|c| c.column_name == column_name)
333 .expect("we just verified the column exists");
334 if table.index_by_name(&index_name).is_some() {
335 if *if_not_exists {
336 return Ok(index_name);
337 }
338 return Err(SQLRiteError::General(format!(
339 "index '{index_name}' already exists"
340 )));
341 }
342 let datatype = clone_datatype(&col.datatype);
343
344 let mut pairs = Vec::new();
348 for rowid in table.rowids() {
349 if let Some(v) = table.get_value(&column_name, rowid) {
350 pairs.push((rowid, v));
351 }
352 }
353 (datatype, pairs)
354 };
355
356 let mut idx = SecondaryIndex::new(
358 index_name.clone(),
359 table_name_str.clone(),
360 column_name.clone(),
361 &datatype,
362 *unique,
363 IndexOrigin::Explicit,
364 )?;
365
366 for (rowid, v) in &existing_rowids_and_values {
370 if *unique && idx.would_violate_unique(v) {
371 return Err(SQLRiteError::General(format!(
372 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
373 already contains the duplicate value {}",
374 v.to_display_string()
375 )));
376 }
377 idx.insert(v, *rowid)?;
378 }
379
380 let table_mut = db.get_table_mut(table_name_str)?;
382 table_mut.secondary_indexes.push(idx);
383 Ok(index_name)
384}
385
386fn clone_datatype(dt: &DataType) -> DataType {
389 match dt {
390 DataType::Integer => DataType::Integer,
391 DataType::Text => DataType::Text,
392 DataType::Real => DataType::Real,
393 DataType::Bool => DataType::Bool,
394 DataType::Vector(dim) => DataType::Vector(*dim),
395 DataType::None => DataType::None,
396 DataType::Invalid => DataType::Invalid,
397 }
398}
399
400fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
401 if tables.len() != 1 {
402 return Err(SQLRiteError::NotImplemented(
403 "multi-table DELETE is not supported yet".to_string(),
404 ));
405 }
406 extract_table_name(&tables[0])
407}
408
409fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
410 if !twj.joins.is_empty() {
411 return Err(SQLRiteError::NotImplemented(
412 "JOIN is not supported yet".to_string(),
413 ));
414 }
415 match &twj.relation {
416 TableFactor::Table { name, .. } => Ok(name.to_string()),
417 _ => Err(SQLRiteError::NotImplemented(
418 "only plain table references are supported".to_string(),
419 )),
420 }
421}
422
/// How the rowids matching a WHERE clause are to be obtained, as decided by
/// `select_rowids`.
enum RowidSource {
    /// The whole predicate was answered by a secondary-index lookup; holds
    /// the matching rowids, which `select_rowids` sorts ascending.
    IndexProbe(Vec<i64>),
    /// No index shortcut applies: the caller must visit every rowid and
    /// evaluate the predicate (if any) itself.
    FullScan,
}
433
434fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
439 let Some(expr) = selection else {
440 return Ok(RowidSource::FullScan);
441 };
442 let Some((col, literal)) = try_extract_equality(expr) else {
443 return Ok(RowidSource::FullScan);
444 };
445 let Some(idx) = table.index_for_column(&col) else {
446 return Ok(RowidSource::FullScan);
447 };
448
449 let literal_value = match convert_literal(&literal) {
453 Ok(v) => v,
454 Err(_) => return Ok(RowidSource::FullScan),
455 };
456
457 let mut rowids = idx.lookup(&literal_value);
461 rowids.sort_unstable();
462 Ok(RowidSource::IndexProbe(rowids))
463}
464
465fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
469 let peeled = match expr {
471 Expr::Nested(inner) => inner.as_ref(),
472 other => other,
473 };
474 let Expr::BinaryOp { left, op, right } = peeled else {
475 return None;
476 };
477 if !matches!(op, BinaryOperator::Eq) {
478 return None;
479 }
480 let col_from = |e: &Expr| -> Option<String> {
481 match e {
482 Expr::Identifier(ident) => Some(ident.value.clone()),
483 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
484 _ => None,
485 }
486 };
487 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
488 if let Expr::Value(v) = e {
489 Some(v.value.clone())
490 } else {
491 None
492 }
493 };
494 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
495 return Some((c, l));
496 }
497 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
498 return Some((c, l));
499 }
500 None
501}
502
503fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
504 let mut keys: Vec<(i64, Result<Value>)> = rowids
512 .iter()
513 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
514 .collect();
515
516 for (_, k) in &keys {
520 if let Err(e) = k {
521 return Err(SQLRiteError::General(format!(
522 "ORDER BY expression failed: {e}"
523 )));
524 }
525 }
526
527 keys.sort_by(|(_, ka), (_, kb)| {
528 let va = ka.as_ref().unwrap();
531 let vb = kb.as_ref().unwrap();
532 let ord = compare_values(Some(va), Some(vb));
533 if order.ascending { ord } else { ord.reverse() }
534 });
535
536 for (i, (rowid, _)) in keys.into_iter().enumerate() {
538 rowids[i] = rowid;
539 }
540 Ok(())
541}
542
543fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
544 match (a, b) {
545 (None, None) => Ordering::Equal,
546 (None, _) => Ordering::Less,
547 (_, None) => Ordering::Greater,
548 (Some(a), Some(b)) => match (a, b) {
549 (Value::Null, Value::Null) => Ordering::Equal,
550 (Value::Null, _) => Ordering::Less,
551 (_, Value::Null) => Ordering::Greater,
552 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
553 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
554 (Value::Integer(x), Value::Real(y)) => {
555 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
556 }
557 (Value::Real(x), Value::Integer(y)) => {
558 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
559 }
560 (Value::Text(x), Value::Text(y)) => x.cmp(y),
561 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
562 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
564 },
565 }
566}
567
568pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
570 let v = eval_expr(expr, table, rowid)?;
571 match v {
572 Value::Bool(b) => Ok(b),
573 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
575 other => Err(SQLRiteError::Internal(format!(
576 "WHERE clause must evaluate to boolean, got {}",
577 other.to_display_string()
578 ))),
579 }
580}
581
582fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
583 match expr {
584 Expr::Nested(inner) => eval_expr(inner, table, rowid),
585
586 Expr::Identifier(ident) => {
587 if ident.quote_style == Some('[') {
597 let raw = format!("[{}]", ident.value);
598 let v = parse_vector_literal(&raw)?;
599 return Ok(Value::Vector(v));
600 }
601 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
602 }
603
604 Expr::CompoundIdentifier(parts) => {
605 let col = parts
607 .last()
608 .map(|i| i.value.as_str())
609 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
610 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
611 }
612
613 Expr::Value(v) => convert_literal(&v.value),
614
615 Expr::UnaryOp { op, expr } => {
616 let inner = eval_expr(expr, table, rowid)?;
617 match op {
618 UnaryOperator::Not => match inner {
619 Value::Bool(b) => Ok(Value::Bool(!b)),
620 Value::Null => Ok(Value::Null),
621 other => Err(SQLRiteError::Internal(format!(
622 "NOT applied to non-boolean value: {}",
623 other.to_display_string()
624 ))),
625 },
626 UnaryOperator::Minus => match inner {
627 Value::Integer(i) => Ok(Value::Integer(-i)),
628 Value::Real(f) => Ok(Value::Real(-f)),
629 Value::Null => Ok(Value::Null),
630 other => Err(SQLRiteError::Internal(format!(
631 "unary minus on non-numeric value: {}",
632 other.to_display_string()
633 ))),
634 },
635 UnaryOperator::Plus => Ok(inner),
636 other => Err(SQLRiteError::NotImplemented(format!(
637 "unary operator {other:?} is not supported"
638 ))),
639 }
640 }
641
642 Expr::BinaryOp { left, op, right } => match op {
643 BinaryOperator::And => {
644 let l = eval_expr(left, table, rowid)?;
645 let r = eval_expr(right, table, rowid)?;
646 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
647 }
648 BinaryOperator::Or => {
649 let l = eval_expr(left, table, rowid)?;
650 let r = eval_expr(right, table, rowid)?;
651 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
652 }
653 cmp @ (BinaryOperator::Eq
654 | BinaryOperator::NotEq
655 | BinaryOperator::Lt
656 | BinaryOperator::LtEq
657 | BinaryOperator::Gt
658 | BinaryOperator::GtEq) => {
659 let l = eval_expr(left, table, rowid)?;
660 let r = eval_expr(right, table, rowid)?;
661 if matches!(l, Value::Null) || matches!(r, Value::Null) {
663 return Ok(Value::Bool(false));
664 }
665 let ord = compare_values(Some(&l), Some(&r));
666 let result = match cmp {
667 BinaryOperator::Eq => ord == Ordering::Equal,
668 BinaryOperator::NotEq => ord != Ordering::Equal,
669 BinaryOperator::Lt => ord == Ordering::Less,
670 BinaryOperator::LtEq => ord != Ordering::Greater,
671 BinaryOperator::Gt => ord == Ordering::Greater,
672 BinaryOperator::GtEq => ord != Ordering::Less,
673 _ => unreachable!(),
674 };
675 Ok(Value::Bool(result))
676 }
677 arith @ (BinaryOperator::Plus
678 | BinaryOperator::Minus
679 | BinaryOperator::Multiply
680 | BinaryOperator::Divide
681 | BinaryOperator::Modulo) => {
682 let l = eval_expr(left, table, rowid)?;
683 let r = eval_expr(right, table, rowid)?;
684 eval_arith(arith, &l, &r)
685 }
686 BinaryOperator::StringConcat => {
687 let l = eval_expr(left, table, rowid)?;
688 let r = eval_expr(right, table, rowid)?;
689 if matches!(l, Value::Null) || matches!(r, Value::Null) {
690 return Ok(Value::Null);
691 }
692 Ok(Value::Text(format!(
693 "{}{}",
694 l.to_display_string(),
695 r.to_display_string()
696 )))
697 }
698 other => Err(SQLRiteError::NotImplemented(format!(
699 "binary operator {other:?} is not supported yet"
700 ))),
701 },
702
703 Expr::Function(func) => eval_function(func, table, rowid),
714
715 other => Err(SQLRiteError::NotImplemented(format!(
716 "unsupported expression in WHERE/projection: {other:?}"
717 ))),
718 }
719}
720
721fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
726 let name = match func.name.0.as_slice() {
729 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
730 _ => {
731 return Err(SQLRiteError::NotImplemented(format!(
732 "qualified function names not supported: {:?}",
733 func.name
734 )));
735 }
736 };
737
738 match name.as_str() {
739 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
740 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
741 let dist = match name.as_str() {
742 "vec_distance_l2" => vec_distance_l2(&a, &b),
743 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
744 "vec_distance_dot" => vec_distance_dot(&a, &b),
745 _ => unreachable!(),
746 };
747 Ok(Value::Real(dist as f64))
753 }
754 other => Err(SQLRiteError::NotImplemented(format!(
755 "unknown function: {other}(...)"
756 ))),
757 }
758}
759
760fn extract_two_vector_args(
764 fn_name: &str,
765 args: &FunctionArguments,
766 table: &Table,
767 rowid: i64,
768) -> Result<(Vec<f32>, Vec<f32>)> {
769 let arg_list = match args {
770 FunctionArguments::List(l) => &l.args,
771 _ => {
772 return Err(SQLRiteError::General(format!(
773 "{fn_name}() expects exactly two vector arguments"
774 )));
775 }
776 };
777 if arg_list.len() != 2 {
778 return Err(SQLRiteError::General(format!(
779 "{fn_name}() expects exactly 2 arguments, got {}",
780 arg_list.len()
781 )));
782 }
783 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
784 for (i, arg) in arg_list.iter().enumerate() {
785 let expr = match arg {
786 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
787 other => {
788 return Err(SQLRiteError::NotImplemented(format!(
789 "{fn_name}() argument {i} has unsupported shape: {other:?}"
790 )));
791 }
792 };
793 let val = eval_expr(expr, table, rowid)?;
794 match val {
795 Value::Vector(v) => out.push(v),
796 other => {
797 return Err(SQLRiteError::General(format!(
798 "{fn_name}() argument {i} is not a vector: got {}",
799 other.to_display_string()
800 )));
801 }
802 }
803 }
804 let b = out.pop().unwrap();
805 let a = out.pop().unwrap();
806 if a.len() != b.len() {
807 return Err(SQLRiteError::General(format!(
808 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
809 a.len(),
810 b.len()
811 )));
812 }
813 Ok((a, b))
814}
815
/// Euclidean (L2) distance between two vectors: `sqrt(Σ (aᵢ - bᵢ)²)`.
///
/// Callers must pass slices of equal length; this is checked only in debug
/// builds. Two empty slices give `0.0`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Explicit fold keeps the left-to-right accumulation starting from +0.0.
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
827
828pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
838 debug_assert_eq!(a.len(), b.len());
839 let mut dot = 0.0f32;
840 let mut norm_a_sq = 0.0f32;
841 let mut norm_b_sq = 0.0f32;
842 for i in 0..a.len() {
843 dot += a[i] * b[i];
844 norm_a_sq += a[i] * a[i];
845 norm_b_sq += b[i] * b[i];
846 }
847 let denom = (norm_a_sq * norm_b_sq).sqrt();
848 if denom == 0.0 {
849 return Err(SQLRiteError::General(
850 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
851 ));
852 }
853 Ok(1.0 - dot / denom)
854}
855
/// Negated dot product of two vectors: `-(a·b)`.
///
/// The sign is flipped so that a larger dot product yields a smaller
/// "distance". Callers must pass slices of equal length; this is checked
/// only in debug builds.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let dot = a
        .iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
867
868fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
871 if matches!(l, Value::Null) || matches!(r, Value::Null) {
872 return Ok(Value::Null);
873 }
874 match (l, r) {
875 (Value::Integer(a), Value::Integer(b)) => match op {
876 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
877 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
878 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
879 BinaryOperator::Divide => {
880 if *b == 0 {
881 Err(SQLRiteError::General("division by zero".to_string()))
882 } else {
883 Ok(Value::Integer(a / b))
884 }
885 }
886 BinaryOperator::Modulo => {
887 if *b == 0 {
888 Err(SQLRiteError::General("modulo by zero".to_string()))
889 } else {
890 Ok(Value::Integer(a % b))
891 }
892 }
893 _ => unreachable!(),
894 },
895 (a, b) => {
897 let af = as_number(a)?;
898 let bf = as_number(b)?;
899 match op {
900 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
901 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
902 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
903 BinaryOperator::Divide => {
904 if bf == 0.0 {
905 Err(SQLRiteError::General("division by zero".to_string()))
906 } else {
907 Ok(Value::Real(af / bf))
908 }
909 }
910 BinaryOperator::Modulo => {
911 if bf == 0.0 {
912 Err(SQLRiteError::General("modulo by zero".to_string()))
913 } else {
914 Ok(Value::Real(af % bf))
915 }
916 }
917 _ => unreachable!(),
918 }
919 }
920 }
921}
922
923fn as_number(v: &Value) -> Result<f64> {
924 match v {
925 Value::Integer(i) => Ok(*i as f64),
926 Value::Real(f) => Ok(*f),
927 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
928 other => Err(SQLRiteError::General(format!(
929 "arithmetic on non-numeric value '{}'",
930 other.to_display_string()
931 ))),
932 }
933}
934
935fn as_bool(v: &Value) -> Result<bool> {
936 match v {
937 Value::Bool(b) => Ok(*b),
938 Value::Null => Ok(false),
939 Value::Integer(i) => Ok(*i != 0),
940 other => Err(SQLRiteError::Internal(format!(
941 "expected boolean, got {}",
942 other.to_display_string()
943 ))),
944 }
945}
946
947fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
948 use sqlparser::ast::Value as AstValue;
949 match v {
950 AstValue::Number(n, _) => {
951 if let Ok(i) = n.parse::<i64>() {
952 Ok(Value::Integer(i))
953 } else if let Ok(f) = n.parse::<f64>() {
954 Ok(Value::Real(f))
955 } else {
956 Err(SQLRiteError::Internal(format!(
957 "could not parse numeric literal '{n}'"
958 )))
959 }
960 }
961 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
962 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
963 AstValue::Null => Ok(Value::Null),
964 other => Err(SQLRiteError::NotImplemented(format!(
965 "unsupported literal value: {other:?}"
966 ))),
967 }
968}
969
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let point = [0.1f32, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&point, &point), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];
        let expected = 2.0_f32.sqrt();
        assert!(approx_eq(vec_distance_l2(&x_axis, &y_axis), expected, 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // Classic 3-4-5 right triangle.
        let origin = [0.0f32, 0.0, 0.0];
        let target = [3.0f32, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&origin, &target), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let point = [0.1f32, 0.2, 0.3];
        let d = vec_distance_cosine(&point, &point).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];
        let d = vec_distance_cosine(&x_axis, &y_axis).unwrap();
        assert!(approx_eq(d, 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let forward = [1.0f32, 0.0, 0.0];
        let backward = [-1.0f32, 0.0, 0.0];
        let d = vec_distance_cosine(&forward, &backward).unwrap();
        assert!(approx_eq(d, 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let zero = [0.0f32, 0.0];
        let unit = [1.0f32, 0.0];
        let err = vec_distance_cosine(&zero, &unit).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32, reported negated.
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&lhs, &rhs), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];
        assert_eq!(vec_distance_dot(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both vectors have unit length, so -(a·b) equals cosine distance - 1.
        let a = [0.6f32, 0.8];
        let b = [0.8f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }
}