1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, Statement, TableFactor,
9 TableWithJoins, UnaryOperator, Update,
10};
11
12use crate::error::{Result, SQLRiteError};
13use crate::sql::db::database::Database;
14use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
15use crate::sql::db::table::{DataType, Table, Value};
16use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
17
18pub struct SelectResult {
27 pub columns: Vec<String>,
28 pub rows: Vec<Vec<Value>>,
29}
30
31pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
35 let table = db
36 .get_table(query.table_name.clone())
37 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
38
39 let projected_cols: Vec<String> = match &query.projection {
41 Projection::All => table.column_names(),
42 Projection::Columns(cols) => {
43 for c in cols {
44 if !table.contains_column(c.to_string()) {
45 return Err(SQLRiteError::Internal(format!(
46 "Column '{c}' does not exist on table '{}'",
47 query.table_name
48 )));
49 }
50 }
51 cols.clone()
52 }
53 };
54
55 let matching = match select_rowids(table, query.selection.as_ref())? {
59 RowidSource::IndexProbe(rowids) => rowids,
60 RowidSource::FullScan => {
61 let mut out = Vec::new();
62 for rowid in table.rowids() {
63 if let Some(expr) = &query.selection {
64 if !eval_predicate(expr, table, rowid)? {
65 continue;
66 }
67 }
68 out.push(rowid);
69 }
70 out
71 }
72 };
73 let mut matching = matching;
74
75 if let Some(order) = &query.order_by {
77 sort_rowids(&mut matching, table, order)?;
78 }
79
80 if let Some(n) = query.limit {
81 matching.truncate(n);
82 }
83
84 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
88 for rowid in &matching {
89 let row: Vec<Value> = projected_cols
90 .iter()
91 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
92 .collect();
93 rows.push(row);
94 }
95
96 Ok(SelectResult {
97 columns: projected_cols,
98 rows,
99 })
100}
101
102pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
107 let result = execute_select_rows(query, db)?;
108 let row_count = result.rows.len();
109
110 let mut print_table = PrintTable::new();
111 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
112 print_table.add_row(PrintRow::new(header_cells));
113
114 for row in &result.rows {
115 let cells: Vec<PrintCell> = row
116 .iter()
117 .map(|v| PrintCell::new(&v.to_display_string()))
118 .collect();
119 print_table.add_row(PrintRow::new(cells));
120 }
121
122 Ok((print_table.to_string(), row_count))
123}
124
125pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
127 let Statement::Delete(Delete {
128 from, selection, ..
129 }) = stmt
130 else {
131 return Err(SQLRiteError::Internal(
132 "execute_delete called on a non-DELETE statement".to_string(),
133 ));
134 };
135
136 let tables = match from {
137 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
138 };
139 let table_name = extract_single_table_name(tables)?;
140
141 let matching: Vec<i64> = {
143 let table = db
144 .get_table(table_name.clone())
145 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
146 match select_rowids(table, selection.as_ref())? {
147 RowidSource::IndexProbe(rowids) => rowids,
148 RowidSource::FullScan => {
149 let mut out = Vec::new();
150 for rowid in table.rowids() {
151 if let Some(expr) = selection {
152 if !eval_predicate(expr, table, rowid)? {
153 continue;
154 }
155 }
156 out.push(rowid);
157 }
158 out
159 }
160 }
161 };
162
163 let table = db.get_table_mut(table_name)?;
164 for rowid in &matching {
165 table.delete_row(*rowid);
166 }
167 Ok(matching.len())
168}
169
170pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
172 let Statement::Update(Update {
173 table,
174 assignments,
175 from,
176 selection,
177 ..
178 }) = stmt
179 else {
180 return Err(SQLRiteError::Internal(
181 "execute_update called on a non-UPDATE statement".to_string(),
182 ));
183 };
184
185 if from.is_some() {
186 return Err(SQLRiteError::NotImplemented(
187 "UPDATE ... FROM is not supported yet".to_string(),
188 ));
189 }
190
191 let table_name = extract_table_name(table)?;
192
193 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
195 {
196 let tbl = db
197 .get_table(table_name.clone())
198 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
199 for a in assignments {
200 let col = match &a.target {
201 AssignmentTarget::ColumnName(name) => name
202 .0
203 .last()
204 .map(|p| p.to_string())
205 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
206 AssignmentTarget::Tuple(_) => {
207 return Err(SQLRiteError::NotImplemented(
208 "tuple assignment targets are not supported".to_string(),
209 ));
210 }
211 };
212 if !tbl.contains_column(col.clone()) {
213 return Err(SQLRiteError::Internal(format!(
214 "UPDATE references unknown column '{col}'"
215 )));
216 }
217 parsed_assignments.push((col, a.value.clone()));
218 }
219 }
220
221 let work: Vec<(i64, Vec<(String, Value)>)> = {
225 let tbl = db.get_table(table_name.clone())?;
226 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
227 RowidSource::IndexProbe(rowids) => rowids,
228 RowidSource::FullScan => {
229 let mut out = Vec::new();
230 for rowid in tbl.rowids() {
231 if let Some(expr) = selection {
232 if !eval_predicate(expr, tbl, rowid)? {
233 continue;
234 }
235 }
236 out.push(rowid);
237 }
238 out
239 }
240 };
241 let mut rows_to_update = Vec::new();
242 for rowid in matched_rowids {
243 let mut values = Vec::with_capacity(parsed_assignments.len());
244 for (col, expr) in &parsed_assignments {
245 let v = eval_expr(expr, tbl, rowid)?;
248 values.push((col.clone(), v));
249 }
250 rows_to_update.push((rowid, values));
251 }
252 rows_to_update
253 };
254
255 let tbl = db.get_table_mut(table_name)?;
256 for (rowid, values) in &work {
257 for (col, v) in values {
258 tbl.set_value(col, *rowid, v.clone())?;
259 }
260 }
261 Ok(work.len())
262}
263
264pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
268 let Statement::CreateIndex(CreateIndex {
269 name,
270 table_name,
271 columns,
272 unique,
273 if_not_exists,
274 predicate,
275 ..
276 }) = stmt
277 else {
278 return Err(SQLRiteError::Internal(
279 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
280 ));
281 };
282
283 if predicate.is_some() {
284 return Err(SQLRiteError::NotImplemented(
285 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
286 ));
287 }
288
289 if columns.len() != 1 {
290 return Err(SQLRiteError::NotImplemented(format!(
291 "multi-column indexes are not supported yet ({} columns given)",
292 columns.len()
293 )));
294 }
295
296 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
297 SQLRiteError::NotImplemented(
298 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
299 )
300 })?;
301
302 let table_name_str = table_name.to_string();
303 let column_name = match &columns[0].column.expr {
304 Expr::Identifier(ident) => ident.value.clone(),
305 Expr::CompoundIdentifier(parts) => parts
306 .last()
307 .map(|p| p.value.clone())
308 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
309 other => {
310 return Err(SQLRiteError::NotImplemented(format!(
311 "CREATE INDEX only supports simple column references, got {other:?}"
312 )));
313 }
314 };
315
316 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
318 let table = db.get_table(table_name_str.clone()).map_err(|_| {
319 SQLRiteError::General(format!(
320 "CREATE INDEX references unknown table '{table_name_str}'"
321 ))
322 })?;
323 if !table.contains_column(column_name.clone()) {
324 return Err(SQLRiteError::General(format!(
325 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
326 )));
327 }
328 let col = table
329 .columns
330 .iter()
331 .find(|c| c.column_name == column_name)
332 .expect("we just verified the column exists");
333 if table.index_by_name(&index_name).is_some() {
334 if *if_not_exists {
335 return Ok(index_name);
336 }
337 return Err(SQLRiteError::General(format!(
338 "index '{index_name}' already exists"
339 )));
340 }
341 let datatype = clone_datatype(&col.datatype);
342
343 let mut pairs = Vec::new();
347 for rowid in table.rowids() {
348 if let Some(v) = table.get_value(&column_name, rowid) {
349 pairs.push((rowid, v));
350 }
351 }
352 (datatype, pairs)
353 };
354
355 let mut idx = SecondaryIndex::new(
357 index_name.clone(),
358 table_name_str.clone(),
359 column_name.clone(),
360 &datatype,
361 *unique,
362 IndexOrigin::Explicit,
363 )?;
364
365 for (rowid, v) in &existing_rowids_and_values {
369 if *unique && idx.would_violate_unique(v) {
370 return Err(SQLRiteError::General(format!(
371 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
372 already contains the duplicate value {}",
373 v.to_display_string()
374 )));
375 }
376 idx.insert(v, *rowid)?;
377 }
378
379 let table_mut = db.get_table_mut(table_name_str)?;
381 table_mut.secondary_indexes.push(idx);
382 Ok(index_name)
383}
384
385fn clone_datatype(dt: &DataType) -> DataType {
388 match dt {
389 DataType::Integer => DataType::Integer,
390 DataType::Text => DataType::Text,
391 DataType::Real => DataType::Real,
392 DataType::Bool => DataType::Bool,
393 DataType::None => DataType::None,
394 DataType::Invalid => DataType::Invalid,
395 }
396}
397
398fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
399 if tables.len() != 1 {
400 return Err(SQLRiteError::NotImplemented(
401 "multi-table DELETE is not supported yet".to_string(),
402 ));
403 }
404 extract_table_name(&tables[0])
405}
406
407fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
408 if !twj.joins.is_empty() {
409 return Err(SQLRiteError::NotImplemented(
410 "JOIN is not supported yet".to_string(),
411 ));
412 }
413 match &twj.relation {
414 TableFactor::Table { name, .. } => Ok(name.to_string()),
415 _ => Err(SQLRiteError::NotImplemented(
416 "only plain table references are supported".to_string(),
417 )),
418 }
419}
420
/// How a statement should obtain the rowids that satisfy its `WHERE` clause.
enum RowidSource {
    /// The whole predicate was answered by a secondary-index lookup. Callers use
    /// these rowids (sorted ascending) directly, without re-evaluating the
    /// predicate.
    IndexProbe(Vec<i64>),
    /// No index applies. Callers must visit every row and evaluate the
    /// predicate themselves.
    FullScan,
}
431
432fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
437 let Some(expr) = selection else {
438 return Ok(RowidSource::FullScan);
439 };
440 let Some((col, literal)) = try_extract_equality(expr) else {
441 return Ok(RowidSource::FullScan);
442 };
443 let Some(idx) = table.index_for_column(&col) else {
444 return Ok(RowidSource::FullScan);
445 };
446
447 let literal_value = match convert_literal(&literal) {
451 Ok(v) => v,
452 Err(_) => return Ok(RowidSource::FullScan),
453 };
454
455 let mut rowids = idx.lookup(&literal_value);
459 rowids.sort_unstable();
460 Ok(RowidSource::IndexProbe(rowids))
461}
462
463fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
467 let peeled = match expr {
469 Expr::Nested(inner) => inner.as_ref(),
470 other => other,
471 };
472 let Expr::BinaryOp { left, op, right } = peeled else {
473 return None;
474 };
475 if !matches!(op, BinaryOperator::Eq) {
476 return None;
477 }
478 let col_from = |e: &Expr| -> Option<String> {
479 match e {
480 Expr::Identifier(ident) => Some(ident.value.clone()),
481 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
482 _ => None,
483 }
484 };
485 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
486 if let Expr::Value(v) = e {
487 Some(v.value.clone())
488 } else {
489 None
490 }
491 };
492 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
493 return Some((c, l));
494 }
495 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
496 return Some((c, l));
497 }
498 None
499}
500
501fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
502 if !table.contains_column(order.column.clone()) {
503 return Err(SQLRiteError::Internal(format!(
504 "ORDER BY references unknown column '{}'",
505 order.column
506 )));
507 }
508 rowids.sort_by(|a, b| {
509 let va = table.get_value(&order.column, *a);
510 let vb = table.get_value(&order.column, *b);
511 let ord = compare_values(va.as_ref(), vb.as_ref());
512 if order.ascending { ord } else { ord.reverse() }
513 });
514 Ok(())
515}
516
517fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
518 match (a, b) {
519 (None, None) => Ordering::Equal,
520 (None, _) => Ordering::Less,
521 (_, None) => Ordering::Greater,
522 (Some(a), Some(b)) => match (a, b) {
523 (Value::Null, Value::Null) => Ordering::Equal,
524 (Value::Null, _) => Ordering::Less,
525 (_, Value::Null) => Ordering::Greater,
526 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
527 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
528 (Value::Integer(x), Value::Real(y)) => {
529 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
530 }
531 (Value::Real(x), Value::Integer(y)) => {
532 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
533 }
534 (Value::Text(x), Value::Text(y)) => x.cmp(y),
535 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
536 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
538 },
539 }
540}
541
542pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
544 let v = eval_expr(expr, table, rowid)?;
545 match v {
546 Value::Bool(b) => Ok(b),
547 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
549 other => Err(SQLRiteError::Internal(format!(
550 "WHERE clause must evaluate to boolean, got {}",
551 other.to_display_string()
552 ))),
553 }
554}
555
556fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
557 match expr {
558 Expr::Nested(inner) => eval_expr(inner, table, rowid),
559
560 Expr::Identifier(ident) => Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null)),
561
562 Expr::CompoundIdentifier(parts) => {
563 let col = parts
565 .last()
566 .map(|i| i.value.as_str())
567 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
568 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
569 }
570
571 Expr::Value(v) => convert_literal(&v.value),
572
573 Expr::UnaryOp { op, expr } => {
574 let inner = eval_expr(expr, table, rowid)?;
575 match op {
576 UnaryOperator::Not => match inner {
577 Value::Bool(b) => Ok(Value::Bool(!b)),
578 Value::Null => Ok(Value::Null),
579 other => Err(SQLRiteError::Internal(format!(
580 "NOT applied to non-boolean value: {}",
581 other.to_display_string()
582 ))),
583 },
584 UnaryOperator::Minus => match inner {
585 Value::Integer(i) => Ok(Value::Integer(-i)),
586 Value::Real(f) => Ok(Value::Real(-f)),
587 Value::Null => Ok(Value::Null),
588 other => Err(SQLRiteError::Internal(format!(
589 "unary minus on non-numeric value: {}",
590 other.to_display_string()
591 ))),
592 },
593 UnaryOperator::Plus => Ok(inner),
594 other => Err(SQLRiteError::NotImplemented(format!(
595 "unary operator {other:?} is not supported"
596 ))),
597 }
598 }
599
600 Expr::BinaryOp { left, op, right } => match op {
601 BinaryOperator::And => {
602 let l = eval_expr(left, table, rowid)?;
603 let r = eval_expr(right, table, rowid)?;
604 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
605 }
606 BinaryOperator::Or => {
607 let l = eval_expr(left, table, rowid)?;
608 let r = eval_expr(right, table, rowid)?;
609 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
610 }
611 cmp @ (BinaryOperator::Eq
612 | BinaryOperator::NotEq
613 | BinaryOperator::Lt
614 | BinaryOperator::LtEq
615 | BinaryOperator::Gt
616 | BinaryOperator::GtEq) => {
617 let l = eval_expr(left, table, rowid)?;
618 let r = eval_expr(right, table, rowid)?;
619 if matches!(l, Value::Null) || matches!(r, Value::Null) {
621 return Ok(Value::Bool(false));
622 }
623 let ord = compare_values(Some(&l), Some(&r));
624 let result = match cmp {
625 BinaryOperator::Eq => ord == Ordering::Equal,
626 BinaryOperator::NotEq => ord != Ordering::Equal,
627 BinaryOperator::Lt => ord == Ordering::Less,
628 BinaryOperator::LtEq => ord != Ordering::Greater,
629 BinaryOperator::Gt => ord == Ordering::Greater,
630 BinaryOperator::GtEq => ord != Ordering::Less,
631 _ => unreachable!(),
632 };
633 Ok(Value::Bool(result))
634 }
635 arith @ (BinaryOperator::Plus
636 | BinaryOperator::Minus
637 | BinaryOperator::Multiply
638 | BinaryOperator::Divide
639 | BinaryOperator::Modulo) => {
640 let l = eval_expr(left, table, rowid)?;
641 let r = eval_expr(right, table, rowid)?;
642 eval_arith(arith, &l, &r)
643 }
644 BinaryOperator::StringConcat => {
645 let l = eval_expr(left, table, rowid)?;
646 let r = eval_expr(right, table, rowid)?;
647 if matches!(l, Value::Null) || matches!(r, Value::Null) {
648 return Ok(Value::Null);
649 }
650 Ok(Value::Text(format!(
651 "{}{}",
652 l.to_display_string(),
653 r.to_display_string()
654 )))
655 }
656 other => Err(SQLRiteError::NotImplemented(format!(
657 "binary operator {other:?} is not supported yet"
658 ))),
659 },
660
661 other => Err(SQLRiteError::NotImplemented(format!(
662 "unsupported expression in WHERE/projection: {other:?}"
663 ))),
664 }
665}
666
667fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
670 if matches!(l, Value::Null) || matches!(r, Value::Null) {
671 return Ok(Value::Null);
672 }
673 match (l, r) {
674 (Value::Integer(a), Value::Integer(b)) => match op {
675 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
676 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
677 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
678 BinaryOperator::Divide => {
679 if *b == 0 {
680 Err(SQLRiteError::General("division by zero".to_string()))
681 } else {
682 Ok(Value::Integer(a / b))
683 }
684 }
685 BinaryOperator::Modulo => {
686 if *b == 0 {
687 Err(SQLRiteError::General("modulo by zero".to_string()))
688 } else {
689 Ok(Value::Integer(a % b))
690 }
691 }
692 _ => unreachable!(),
693 },
694 (a, b) => {
696 let af = as_number(a)?;
697 let bf = as_number(b)?;
698 match op {
699 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
700 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
701 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
702 BinaryOperator::Divide => {
703 if bf == 0.0 {
704 Err(SQLRiteError::General("division by zero".to_string()))
705 } else {
706 Ok(Value::Real(af / bf))
707 }
708 }
709 BinaryOperator::Modulo => {
710 if bf == 0.0 {
711 Err(SQLRiteError::General("modulo by zero".to_string()))
712 } else {
713 Ok(Value::Real(af % bf))
714 }
715 }
716 _ => unreachable!(),
717 }
718 }
719 }
720}
721
722fn as_number(v: &Value) -> Result<f64> {
723 match v {
724 Value::Integer(i) => Ok(*i as f64),
725 Value::Real(f) => Ok(*f),
726 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
727 other => Err(SQLRiteError::General(format!(
728 "arithmetic on non-numeric value '{}'",
729 other.to_display_string()
730 ))),
731 }
732}
733
734fn as_bool(v: &Value) -> Result<bool> {
735 match v {
736 Value::Bool(b) => Ok(*b),
737 Value::Null => Ok(false),
738 Value::Integer(i) => Ok(*i != 0),
739 other => Err(SQLRiteError::Internal(format!(
740 "expected boolean, got {}",
741 other.to_display_string()
742 ))),
743 }
744}
745
746fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
747 use sqlparser::ast::Value as AstValue;
748 match v {
749 AstValue::Number(n, _) => {
750 if let Ok(i) = n.parse::<i64>() {
751 Ok(Value::Integer(i))
752 } else if let Ok(f) = n.parse::<f64>() {
753 Ok(Value::Real(f))
754 } else {
755 Err(SQLRiteError::Internal(format!(
756 "could not parse numeric literal '{n}'"
757 )))
758 }
759 }
760 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
761 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
762 AstValue::Null => Ok(Value::Null),
763 other => Err(SQLRiteError::NotImplemented(format!(
764 "unsupported literal value: {other:?}"
765 ))),
766 }
767}