1pub mod ast;
31pub mod bridge;
32pub mod compatibility;
33pub mod error;
34pub mod lexer;
35pub mod parser;
36pub mod token;
37
38pub use ast::*;
39pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
40pub use compatibility::{CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support};
41pub use error::{SqlError, SqlResult};
42pub use lexer::{LexError, Lexer};
43pub use parser::{ParseError, Parser};
44pub use token::{Span, Token, TokenKind};
45
46use std::collections::HashMap;
47use sochdb_core::SochValue;
48
/// Outcome of executing a single SQL statement.
#[derive(Debug, Clone)]
pub enum ExecutionResult {
    /// A result set, as produced by `SELECT`.
    Rows {
        /// Names of the output columns.
        columns: Vec<String>,
        /// One map per result row, keyed by output column name.
        rows: Vec<HashMap<String, SochValue>>,
    },
    /// Number of rows touched by `INSERT`, `UPDATE` or `DELETE`.
    RowsAffected(usize),
    /// The statement succeeded and has nothing further to report
    /// (DDL and transaction-control statements).
    Ok,
}
62
63impl ExecutionResult {
64 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
66 match self {
67 ExecutionResult::Rows { rows, .. } => Some(rows),
68 _ => None,
69 }
70 }
71
72 pub fn columns(&self) -> Option<&Vec<String>> {
74 match self {
75 ExecutionResult::Rows { columns, .. } => Some(columns),
76 _ => None,
77 }
78 }
79
80 pub fn rows_affected(&self) -> usize {
82 match self {
83 ExecutionResult::RowsAffected(n) => *n,
84 ExecutionResult::Rows { rows, .. } => rows.len(),
85 ExecutionResult::Ok => 0,
86 }
87 }
88}
89
/// SQL executor that keeps all of its tables in memory.
pub struct SqlExecutor {
    /// Table storage, keyed by table name.
    tables: HashMap<String, TableData>,
}
97
/// Column layout and row storage of one table held by `SqlExecutor`.
#[derive(Debug, Clone)]
pub struct TableData {
    /// Column names, in declaration order.
    pub columns: Vec<String>,
    /// Declared data type of each column, parallel to `columns`.
    pub column_types: Vec<DataType>,
    /// Stored rows; values are positional and paired with `columns` by index.
    pub rows: Vec<Vec<SochValue>>,
}
105
106impl Default for SqlExecutor {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl SqlExecutor {
113 pub fn new() -> Self {
115 Self {
116 tables: HashMap::new(),
117 }
118 }
119
120 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
122 self.execute_with_params(sql, &[])
123 }
124
125 pub fn execute_with_params(
127 &mut self,
128 sql: &str,
129 params: &[SochValue],
130 ) -> SqlResult<ExecutionResult> {
131 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
132 self.execute_statement(&stmt, params)
133 }
134
135 pub fn execute_statement(
137 &mut self,
138 stmt: &Statement,
139 params: &[SochValue],
140 ) -> SqlResult<ExecutionResult> {
141 match stmt {
142 Statement::Select(select) => self.execute_select(select, params),
143 Statement::Insert(insert) => self.execute_insert(insert, params),
144 Statement::Update(update) => self.execute_update(update, params),
145 Statement::Delete(delete) => self.execute_delete(delete, params),
146 Statement::CreateTable(create) => self.execute_create_table(create),
147 Statement::DropTable(drop) => self.execute_drop_table(drop),
148 Statement::Begin(_) => Ok(ExecutionResult::Ok),
149 Statement::Commit => Ok(ExecutionResult::Ok),
150 Statement::Rollback(_) => Ok(ExecutionResult::Ok),
151 _ => Err(SqlError::NotImplemented(
152 "Statement type not yet supported".into(),
153 )),
154 }
155 }
156
157 fn execute_select(
158 &self,
159 select: &SelectStmt,
160 params: &[SochValue],
161 ) -> SqlResult<ExecutionResult> {
162 let from = select
164 .from
165 .as_ref()
166 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
167
168 if from.tables.len() != 1 {
169 return Err(SqlError::NotImplemented(
170 "Multi-table queries not yet supported".into(),
171 ));
172 }
173
174 let table_name = match &from.tables[0] {
175 TableRef::Table { name, .. } => name.name().to_string(),
176 _ => {
177 return Err(SqlError::NotImplemented(
178 "Complex table references not yet supported".into(),
179 ));
180 }
181 };
182
183 let table = self
184 .tables
185 .get(&table_name)
186 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
187
188 let mut source_rows = Vec::new();
190
191 for row in &table.rows {
192 let row_map: HashMap<String, SochValue> = table
194 .columns
195 .iter()
196 .zip(row.iter())
197 .map(|(col, val)| (col.clone(), val.clone()))
198 .collect();
199
200 if let Some(where_clause) = &select.where_clause
202 && !self.evaluate_where(where_clause, &row_map, params)?
203 {
204 continue;
205 }
206
207 source_rows.push(row_map);
208 }
209
210 if !select.order_by.is_empty() {
212 source_rows.sort_by(|a, b| {
213 for order_item in &select.order_by {
214 if let Expr::Column(col_ref) = &order_item.expr {
215 let a_val = a.get(&col_ref.column);
216 let b_val = b.get(&col_ref.column);
217
218 let cmp = self.compare_values(a_val, b_val);
219 if cmp != std::cmp::Ordering::Equal {
220 return if order_item.asc { cmp } else { cmp.reverse() };
221 }
222 }
223 }
224 std::cmp::Ordering::Equal
225 });
226 }
227
228 if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
230 let n = *n as usize;
231 if n < source_rows.len() {
232 source_rows = source_rows.into_iter().skip(n).collect();
233 } else {
234 source_rows.clear();
235 }
236 }
237
238 if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
240 source_rows.truncate(*n as usize);
241 }
242
243 let mut output_columns: Vec<String> = Vec::new();
245 let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
246
247 let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
249
250 if is_wildcard {
251 output_columns = table.columns.clone();
252 result_rows = source_rows;
253 } else {
254 for item in &select.columns {
256 match item {
257 SelectItem::Wildcard => output_columns.push("*".to_string()),
258 SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
259 SelectItem::Expr { expr, alias } => {
260 let col_name = alias.clone().unwrap_or_else(|| match expr {
261 Expr::Column(col) => col.column.clone(),
262 Expr::Function(func) => format!("{}()", func.name.name()),
263 _ => "?column?".to_string(),
264 });
265 output_columns.push(col_name);
266 }
267 }
268 }
269
270 for source_row in &source_rows {
272 let mut result_row = HashMap::new();
273
274 for (idx, item) in select.columns.iter().enumerate() {
275 let col_name = &output_columns[idx];
276
277 match item {
278 SelectItem::Wildcard => {
279 for (k, v) in source_row {
281 result_row.insert(k.clone(), v.clone());
282 }
283 }
284 SelectItem::QualifiedWildcard(_) => {
285 for (k, v) in source_row {
287 result_row.insert(k.clone(), v.clone());
288 }
289 }
290 SelectItem::Expr { expr, .. } => {
291 let value = self.evaluate_expr(expr, source_row, params)?;
292 result_row.insert(col_name.clone(), value);
293 }
294 }
295 }
296
297 result_rows.push(result_row);
298 }
299 }
300
301 Ok(ExecutionResult::Rows {
302 columns: output_columns,
303 rows: result_rows,
304 })
305 }
306
307 fn execute_insert(
308 &mut self,
309 insert: &InsertStmt,
310 params: &[SochValue],
311 ) -> SqlResult<ExecutionResult> {
312 let table_name = insert.table.name().to_string();
313
314 let table_columns = {
316 let table = self
317 .tables
318 .get(&table_name)
319 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
320 table.columns.clone()
321 };
322
323 let mut rows_affected = 0;
324 let mut new_rows = Vec::new();
325
326 match &insert.source {
327 InsertSource::Values(rows) => {
328 for value_exprs in rows {
329 let mut row_values = Vec::new();
330
331 if let Some(columns) = &insert.columns {
332 if columns.len() != value_exprs.len() {
333 return Err(SqlError::InvalidArgument(format!(
334 "Column count ({}) doesn't match value count ({})",
335 columns.len(),
336 value_exprs.len()
337 )));
338 }
339
340 for table_col in &table_columns {
342 if let Some(pos) = columns.iter().position(|c| c == table_col) {
343 let value =
344 self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
345 row_values.push(value);
346 } else {
347 row_values.push(SochValue::Null);
348 }
349 }
350 } else {
351 for expr in value_exprs {
353 let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
354 row_values.push(value);
355 }
356 }
357
358 new_rows.push(row_values);
359 rows_affected += 1;
360 }
361 }
362 InsertSource::Query(_) => {
363 return Err(SqlError::NotImplemented(
364 "INSERT ... SELECT not yet supported".into(),
365 ));
366 }
367 InsertSource::Default => {
368 return Err(SqlError::NotImplemented(
369 "INSERT DEFAULT VALUES not yet supported".into(),
370 ));
371 }
372 }
373
374 let table = self.tables.get_mut(&table_name).unwrap();
376 for row in new_rows {
377 table.rows.push(row);
378 }
379
380 Ok(ExecutionResult::RowsAffected(rows_affected))
381 }
382
383 fn execute_update(
384 &mut self,
385 update: &UpdateStmt,
386 params: &[SochValue],
387 ) -> SqlResult<ExecutionResult> {
388 let table_name = update.table.name().to_string();
389
390 let (_table_columns, updates_to_apply) = {
392 let table = self
393 .tables
394 .get(&table_name)
395 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
396
397 let mut updates = Vec::new();
398
399 for row_idx in 0..table.rows.len() {
400 let row_map: HashMap<String, SochValue> = table
402 .columns
403 .iter()
404 .zip(table.rows[row_idx].iter())
405 .map(|(col, val)| (col.clone(), val.clone()))
406 .collect();
407
408 let matches = if let Some(where_clause) = &update.where_clause {
410 self.evaluate_where(where_clause, &row_map, params)?
411 } else {
412 true
413 };
414
415 if matches {
416 let mut row_updates = Vec::new();
418 for assignment in &update.assignments {
419 if let Some(col_idx) =
420 table.columns.iter().position(|c| c == &assignment.column)
421 {
422 let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
423 row_updates.push((col_idx, value));
424 }
425 }
426 updates.push((row_idx, row_updates));
427 }
428 }
429
430 (table.columns.clone(), updates)
431 };
432
433 let rows_affected = updates_to_apply.len();
434
435 let table = self.tables.get_mut(&table_name).unwrap();
437 for (row_idx, row_updates) in updates_to_apply {
438 for (col_idx, value) in row_updates {
439 table.rows[row_idx][col_idx] = value;
440 }
441 }
442
443 Ok(ExecutionResult::RowsAffected(rows_affected))
444 }
445
446 fn execute_delete(
447 &mut self,
448 delete: &DeleteStmt,
449 params: &[SochValue],
450 ) -> SqlResult<ExecutionResult> {
451 let table_name = delete.table.name().to_string();
452
453 let indices_to_remove = {
455 let table = self
456 .tables
457 .get(&table_name)
458 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
459
460 let mut indices = Vec::new();
461
462 for (row_idx, row) in table.rows.iter().enumerate() {
463 let row_map: HashMap<String, SochValue> = table
465 .columns
466 .iter()
467 .zip(row.iter())
468 .map(|(col, val)| (col.clone(), val.clone()))
469 .collect();
470
471 let matches = if let Some(where_clause) = &delete.where_clause {
473 self.evaluate_where(where_clause, &row_map, params)?
474 } else {
475 true
476 };
477
478 if matches {
479 indices.push(row_idx);
480 }
481 }
482
483 indices
484 };
485
486 let rows_affected = indices_to_remove.len();
487
488 let table = self.tables.get_mut(&table_name).unwrap();
490 for idx in indices_to_remove.into_iter().rev() {
492 table.rows.remove(idx);
493 }
494
495 Ok(ExecutionResult::RowsAffected(rows_affected))
496 }
497
498 fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
499 let table_name = create.name.name().to_string();
500
501 if self.tables.contains_key(&table_name) {
502 if create.if_not_exists {
503 return Ok(ExecutionResult::Ok);
504 }
505 return Err(SqlError::ConstraintViolation(format!(
506 "Table '{}' already exists",
507 table_name
508 )));
509 }
510
511 let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
512 let column_types: Vec<DataType> =
513 create.columns.iter().map(|c| c.data_type.clone()).collect();
514
515 self.tables.insert(
516 table_name,
517 TableData {
518 columns,
519 column_types,
520 rows: Vec::new(),
521 },
522 );
523
524 Ok(ExecutionResult::Ok)
525 }
526
527 fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
528 for name in &drop.names {
529 let table_name = name.name().to_string();
530 if self.tables.remove(&table_name).is_none() && !drop.if_exists {
531 return Err(SqlError::TableNotFound(table_name));
532 }
533 }
534
535 Ok(ExecutionResult::Ok)
536 }
537
538 fn evaluate_where(
541 &self,
542 expr: &Expr,
543 row: &HashMap<String, SochValue>,
544 params: &[SochValue],
545 ) -> SqlResult<bool> {
546 let value = self.evaluate_expr(expr, row, params)?;
547 match value {
548 SochValue::Bool(b) => Ok(b),
549 SochValue::Null => Ok(false),
550 _ => Err(SqlError::TypeError(
551 "WHERE clause must evaluate to boolean".into(),
552 )),
553 }
554 }
555
556 fn evaluate_expr(
557 &self,
558 expr: &Expr,
559 row: &HashMap<String, SochValue>,
560 params: &[SochValue],
561 ) -> SqlResult<SochValue> {
562 match expr {
563 Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
564
565 Expr::Column(col_ref) => row
566 .get(&col_ref.column)
567 .cloned()
568 .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
569
570 Expr::Placeholder(n) => params
571 .get((*n as usize).saturating_sub(1))
572 .cloned()
573 .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
574
575 Expr::BinaryOp { left, op, right } => {
576 let left_val = self.evaluate_expr(left, row, params)?;
577 let right_val = self.evaluate_expr(right, row, params)?;
578 self.evaluate_binary_op(&left_val, op, &right_val)
579 }
580
581 Expr::UnaryOp { op, expr } => {
582 let val = self.evaluate_expr(expr, row, params)?;
583 self.evaluate_unary_op(op, &val)
584 }
585
586 Expr::IsNull { expr, negated } => {
587 let val = self.evaluate_expr(expr, row, params)?;
588 let is_null = matches!(val, SochValue::Null);
589 Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
590 }
591
592 Expr::InList {
593 expr,
594 list,
595 negated,
596 } => {
597 let val = self.evaluate_expr(expr, row, params)?;
598 let mut found = false;
599 for item in list {
600 let item_val = self.evaluate_expr(item, row, params)?;
601 if self.values_equal(&val, &item_val) {
602 found = true;
603 break;
604 }
605 }
606 Ok(SochValue::Bool(if *negated { !found } else { found }))
607 }
608
609 Expr::Between {
610 expr,
611 low,
612 high,
613 negated,
614 } => {
615 let val = self.evaluate_expr(expr, row, params)?;
616 let low_val = self.evaluate_expr(low, row, params)?;
617 let high_val = self.evaluate_expr(high, row, params)?;
618
619 let cmp_low = self.compare_values(Some(&val), Some(&low_val));
620 let cmp_high = self.compare_values(Some(&val), Some(&high_val));
621
622 let in_range =
623 cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
624
625 Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
626 }
627
628 Expr::Like {
629 expr,
630 pattern,
631 negated,
632 ..
633 } => {
634 let val = self.evaluate_expr(expr, row, params)?;
635 let pattern_val = self.evaluate_expr(pattern, row, params)?;
636
637 match (&val, &pattern_val) {
638 (SochValue::Text(s), SochValue::Text(p)) => {
639 let regex_pattern = p.replace('%', ".*").replace('_', ".");
640 let matches = regex::Regex::new(&format!("^{}$", regex_pattern))
641 .map(|re| re.is_match(s))
642 .unwrap_or(false);
643 Ok(SochValue::Bool(if *negated { !matches } else { matches }))
644 }
645 _ => Ok(SochValue::Bool(false)),
646 }
647 }
648
649 Expr::Function(func) => self.evaluate_function(func, row, params),
650
651 Expr::Case {
652 operand,
653 conditions,
654 else_result,
655 } => {
656 if let Some(op) = operand {
657 let op_val = self.evaluate_expr(op, row, params)?;
659 for (when_expr, then_expr) in conditions {
660 let when_val = self.evaluate_expr(when_expr, row, params)?;
661 if self.values_equal(&op_val, &when_val) {
662 return self.evaluate_expr(then_expr, row, params);
663 }
664 }
665 } else {
666 for (when_expr, then_expr) in conditions {
668 let when_val = self.evaluate_expr(when_expr, row, params)?;
669 if matches!(when_val, SochValue::Bool(true)) {
670 return self.evaluate_expr(then_expr, row, params);
671 }
672 }
673 }
674
675 if let Some(else_expr) = else_result {
676 self.evaluate_expr(else_expr, row, params)
677 } else {
678 Ok(SochValue::Null)
679 }
680 }
681
682 _ => Err(SqlError::NotImplemented(format!(
683 "Expression type {:?} not yet supported",
684 expr
685 ))),
686 }
687 }
688
689 fn literal_to_value(&self, lit: &Literal) -> SochValue {
690 match lit {
691 Literal::Null => SochValue::Null,
692 Literal::Boolean(b) => SochValue::Bool(*b),
693 Literal::Integer(n) => SochValue::Int(*n),
694 Literal::Float(f) => SochValue::Float(*f),
695 Literal::String(s) => SochValue::Text(s.clone()),
696 Literal::Blob(b) => SochValue::Binary(b.clone()),
697 }
698 }
699
700 fn evaluate_binary_op(
701 &self,
702 left: &SochValue,
703 op: &BinaryOperator,
704 right: &SochValue,
705 ) -> SqlResult<SochValue> {
706 match op {
707 BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
708 BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
709 BinaryOperator::Lt => Ok(SochValue::Bool(
710 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
711 )),
712 BinaryOperator::Le => Ok(SochValue::Bool(
713 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
714 )),
715 BinaryOperator::Gt => Ok(SochValue::Bool(
716 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
717 )),
718 BinaryOperator::Ge => Ok(SochValue::Bool(
719 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
720 )),
721
722 BinaryOperator::And => match (left, right) {
723 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
724 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
725 _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
726 },
727
728 BinaryOperator::Or => match (left, right) {
729 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
730 (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
731 Ok(SochValue::Bool(true))
732 }
733 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
734 _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
735 },
736
737 BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
738 BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
739 BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
740 BinaryOperator::Divide => self.arithmetic_op(
741 left,
742 right,
743 |a, b| if b != 0 { a / b } else { 0 },
744 |a, b| a / b,
745 ),
746 BinaryOperator::Modulo => self.arithmetic_op(
747 left,
748 right,
749 |a, b| if b != 0 { a % b } else { 0 },
750 |a, b| a % b,
751 ),
752
753 BinaryOperator::Concat => match (left, right) {
754 (SochValue::Text(l), SochValue::Text(r)) => {
755 Ok(SochValue::Text(format!("{}{}", l, r)))
756 }
757 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
758 _ => Err(SqlError::TypeError("|| requires string operands".into())),
759 },
760
761 _ => Err(SqlError::NotImplemented(format!(
762 "Operator {:?} not implemented",
763 op
764 ))),
765 }
766 }
767
768 fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
769 match op {
770 UnaryOperator::Not => match val {
771 SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
772 SochValue::Null => Ok(SochValue::Null),
773 _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
774 },
775 UnaryOperator::Minus => match val {
776 SochValue::Int(n) => Ok(SochValue::Int(-n)),
777 SochValue::Float(f) => Ok(SochValue::Float(-f)),
778 SochValue::Null => Ok(SochValue::Null),
779 _ => Err(SqlError::TypeError(
780 "Unary minus requires numeric operand".into(),
781 )),
782 },
783 UnaryOperator::Plus => Ok(val.clone()),
784 UnaryOperator::BitNot => match val {
785 SochValue::Int(n) => Ok(SochValue::Int(!n)),
786 _ => Err(SqlError::TypeError("~ requires integer operand".into())),
787 },
788 }
789 }
790
791 fn evaluate_function(
792 &self,
793 func: &FunctionCall,
794 row: &HashMap<String, SochValue>,
795 params: &[SochValue],
796 ) -> SqlResult<SochValue> {
797 let func_name = func.name.name().to_uppercase();
798
799 match func_name.as_str() {
800 "COALESCE" => {
801 for arg in &func.args {
802 let val = self.evaluate_expr(arg, row, params)?;
803 if !matches!(val, SochValue::Null) {
804 return Ok(val);
805 }
806 }
807 Ok(SochValue::Null)
808 }
809
810 "NULLIF" => {
811 if func.args.len() != 2 {
812 return Err(SqlError::InvalidArgument(
813 "NULLIF requires 2 arguments".into(),
814 ));
815 }
816 let val1 = self.evaluate_expr(&func.args[0], row, params)?;
817 let val2 = self.evaluate_expr(&func.args[1], row, params)?;
818 if self.values_equal(&val1, &val2) {
819 Ok(SochValue::Null)
820 } else {
821 Ok(val1)
822 }
823 }
824
825 "ABS" => {
826 if func.args.len() != 1 {
827 return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
828 }
829 let val = self.evaluate_expr(&func.args[0], row, params)?;
830 match val {
831 SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
832 SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
833 SochValue::Null => Ok(SochValue::Null),
834 _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
835 }
836 }
837
838 "LENGTH" | "LEN" => {
839 if func.args.len() != 1 {
840 return Err(SqlError::InvalidArgument(
841 "LENGTH requires 1 argument".into(),
842 ));
843 }
844 let val = self.evaluate_expr(&func.args[0], row, params)?;
845 match val {
846 SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
847 SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
848 SochValue::Null => Ok(SochValue::Null),
849 _ => Err(SqlError::TypeError(
850 "LENGTH requires string argument".into(),
851 )),
852 }
853 }
854
855 "UPPER" => {
856 if func.args.len() != 1 {
857 return Err(SqlError::InvalidArgument(
858 "UPPER requires 1 argument".into(),
859 ));
860 }
861 let val = self.evaluate_expr(&func.args[0], row, params)?;
862 match val {
863 SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
864 SochValue::Null => Ok(SochValue::Null),
865 _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
866 }
867 }
868
869 "LOWER" => {
870 if func.args.len() != 1 {
871 return Err(SqlError::InvalidArgument(
872 "LOWER requires 1 argument".into(),
873 ));
874 }
875 let val = self.evaluate_expr(&func.args[0], row, params)?;
876 match val {
877 SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
878 SochValue::Null => Ok(SochValue::Null),
879 _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
880 }
881 }
882
883 "TRIM" => {
884 if func.args.len() != 1 {
885 return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
886 }
887 let val = self.evaluate_expr(&func.args[0], row, params)?;
888 match val {
889 SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
890 SochValue::Null => Ok(SochValue::Null),
891 _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
892 }
893 }
894
895 "SUBSTR" | "SUBSTRING" => {
896 if func.args.len() < 2 || func.args.len() > 3 {
897 return Err(SqlError::InvalidArgument(
898 "SUBSTR requires 2 or 3 arguments".into(),
899 ));
900 }
901 let val = self.evaluate_expr(&func.args[0], row, params)?;
902 let start = self.evaluate_expr(&func.args[1], row, params)?;
903 let len = if func.args.len() == 3 {
904 Some(self.evaluate_expr(&func.args[2], row, params)?)
905 } else {
906 None
907 };
908
909 match (val, start) {
910 (SochValue::Text(s), SochValue::Int(start)) => {
911 let start = (start.max(1) - 1) as usize;
912 if start >= s.len() {
913 return Ok(SochValue::Text(String::new()));
914 }
915 let result = if let Some(SochValue::Int(len)) = len {
916 s.chars().skip(start).take(len as usize).collect()
917 } else {
918 s.chars().skip(start).collect()
919 };
920 Ok(SochValue::Text(result))
921 }
922 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
923 _ => Err(SqlError::TypeError(
924 "SUBSTR requires string and integer arguments".into(),
925 )),
926 }
927 }
928
929 _ => Err(SqlError::NotImplemented(format!(
930 "Function {} not implemented",
931 func_name
932 ))),
933 }
934 }
935
936 fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
939 match (left, right) {
940 (SochValue::Null, _) | (_, SochValue::Null) => false,
941 (SochValue::Int(l), SochValue::Int(r)) => l == r,
942 (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
943 (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
944 (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
945 (SochValue::Text(l), SochValue::Text(r)) => l == r,
946 (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
947 (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
948 (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
949 (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
950 (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
951 _ => false,
952 }
953 }
954
955 fn compare_values(
956 &self,
957 left: Option<&SochValue>,
958 right: Option<&SochValue>,
959 ) -> std::cmp::Ordering {
960 match (left, right) {
961 (None, None) => std::cmp::Ordering::Equal,
962 (None, _) => std::cmp::Ordering::Less,
963 (_, None) => std::cmp::Ordering::Greater,
964 (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
965 (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
966 (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
967 l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
968 }
969 (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
970 .partial_cmp(r)
971 .unwrap_or(std::cmp::Ordering::Equal),
972 (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
973 .partial_cmp(&(*r as f64))
974 .unwrap_or(std::cmp::Ordering::Equal),
975 (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
976 (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
977 _ => std::cmp::Ordering::Equal,
978 }
979 }
980
981 fn arithmetic_op<FI, FF>(
982 &self,
983 left: &SochValue,
984 right: &SochValue,
985 int_op: FI,
986 float_op: FF,
987 ) -> SqlResult<SochValue>
988 where
989 FI: Fn(i64, i64) -> i64,
990 FF: Fn(f64, f64) -> f64,
991 {
992 match (left, right) {
993 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
994 (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
995 (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
996 (SochValue::Int(l), SochValue::Float(r)) => {
997 Ok(SochValue::Float(float_op(*l as f64, *r)))
998 }
999 (SochValue::Float(l), SochValue::Int(r)) => {
1000 Ok(SochValue::Float(float_op(*l, *r as f64)))
1001 }
1002 (SochValue::UInt(l), SochValue::UInt(r)) => {
1003 Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1004 }
1005 (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1006 (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1007 _ => Err(SqlError::TypeError(
1008 "Arithmetic requires numeric operands".into(),
1009 )),
1010 }
1011 }
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one statement and panics if it fails.
    fn run(executor: &mut SqlExecutor, sql: &str) -> ExecutionResult {
        executor.execute(sql).unwrap()
    }

    /// Builds a fresh executor and runs every setup statement in order.
    fn prepared(setup: &[&str]) -> SqlExecutor {
        let mut executor = SqlExecutor::new();
        for sql in setup {
            run(&mut executor, sql);
        }
        executor
    }

    /// Creates `nums (n INTEGER)` holding the values `1..=upper`, one row per insert.
    fn nums_up_to(upper: i64) -> SqlExecutor {
        let mut executor = prepared(&["CREATE TABLE nums (n INTEGER)"]);
        for i in 1..=upper {
            run(
                &mut executor,
                &format!("INSERT INTO nums (n) VALUES ({})", i),
            );
        }
        executor
    }

    /// Runs a query and returns the row count reported by `rows_affected`.
    fn count(executor: &mut SqlExecutor, sql: &str) -> usize {
        run(executor, sql).rows_affected()
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut executor = SqlExecutor::new();

        let created = run(
            &mut executor,
            "CREATE TABLE users (id INTEGER, name VARCHAR(100))",
        );
        assert!(matches!(created, ExecutionResult::Ok));

        let first = run(
            &mut executor,
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
        );
        assert_eq!(first.rows_affected(), 1);

        let second = run(
            &mut executor,
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        );
        assert_eq!(second.rows_affected(), 1);

        // For a SELECT, `rows_affected` is the number of returned rows.
        assert_eq!(count(&mut executor, "SELECT * FROM users"), 2);
    }

    #[test]
    fn test_select_with_where() {
        let mut executor = prepared(&[
            "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)",
            "INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)",
            "INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)",
            "INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)",
        ]);

        assert_eq!(
            count(&mut executor, "SELECT * FROM products WHERE price > 1.0"),
            2
        );
    }

    #[test]
    fn test_update() {
        let mut executor = prepared(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
        ]);

        assert_eq!(
            count(
                &mut executor,
                "UPDATE users SET name = 'Alicia' WHERE id = 1"
            ),
            1
        );
        assert_eq!(
            count(&mut executor, "SELECT * FROM users WHERE name = 'Alicia'"),
            1
        );
    }

    #[test]
    fn test_delete() {
        let mut executor = prepared(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ]);

        assert_eq!(count(&mut executor, "DELETE FROM users WHERE id = 1"), 1);
        assert_eq!(count(&mut executor, "SELECT * FROM users"), 1);
    }

    #[test]
    fn test_functions() {
        let mut executor = prepared(&[
            "CREATE TABLE t (s TEXT)",
            "INSERT INTO t (s) VALUES ('hello')",
        ]);

        let result = run(&mut executor, "SELECT UPPER(s) FROM t");
        let ExecutionResult::Rows { rows, .. } = result else {
            panic!("Expected rows");
        };

        let first_row = &rows[0];
        let has_upper = first_row
            .values()
            .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"));
        assert!(has_upper);
    }

    #[test]
    fn test_order_by() {
        let mut executor = prepared(&[
            "CREATE TABLE nums (n INTEGER)",
            "INSERT INTO nums (n) VALUES (3)",
            "INSERT INTO nums (n) VALUES (1)",
            "INSERT INTO nums (n) VALUES (2)",
        ]);

        let result = run(&mut executor, "SELECT * FROM nums ORDER BY n ASC");
        let ExecutionResult::Rows { rows, .. } = result else {
            panic!("Expected rows");
        };

        // Collect the integer `n` of every row, in result order.
        let mut values: Vec<i64> = Vec::new();
        for row in &rows {
            if let Some(SochValue::Int(n)) = row.get("n") {
                values.push(*n);
            }
        }
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut executor = nums_up_to(10);

        assert_eq!(
            count(&mut executor, "SELECT * FROM nums LIMIT 3 OFFSET 2"),
            3
        );
    }

    #[test]
    fn test_between() {
        let mut executor = nums_up_to(10);

        assert_eq!(
            count(&mut executor, "SELECT * FROM nums WHERE n BETWEEN 3 AND 7"),
            5
        );
    }

    #[test]
    fn test_in_list() {
        let mut executor = nums_up_to(5);

        assert_eq!(
            count(&mut executor, "SELECT * FROM nums WHERE n IN (1, 3, 5)"),
            3
        );
    }
}