1pub mod ast;
28pub mod bridge;
29pub mod compatibility;
30pub mod error;
31pub mod lexer;
32pub mod parser;
33pub mod token;
34
35pub use ast::*;
36pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
37pub use compatibility::{CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support};
38pub use error::{SqlError, SqlResult};
39pub use lexer::{LexError, Lexer};
40pub use parser::{ParseError, Parser};
41pub use token::{Span, Token, TokenKind};
42
43use std::collections::HashMap;
44use sochdb_core::SochValue;
45
/// Outcome of executing one SQL statement through [`SqlExecutor`].
#[derive(Debug, Clone)]
pub enum ExecutionResult {
    /// A result set, as returned by `SELECT`.
    Rows {
        /// Names of the output columns.
        columns: Vec<String>,
        /// One map per result row, keyed by column name.
        rows: Vec<HashMap<String, SochValue>>,
    },
    /// Number of rows touched by `INSERT`, `UPDATE` or `DELETE`.
    RowsAffected(usize),
    /// The statement succeeded and has nothing else to report
    /// (`CREATE TABLE`, `DROP TABLE`, `BEGIN`, `COMMIT`, `ROLLBACK`).
    Ok,
}
59
60impl ExecutionResult {
61 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
63 match self {
64 ExecutionResult::Rows { rows, .. } => Some(rows),
65 _ => None,
66 }
67 }
68
69 pub fn columns(&self) -> Option<&Vec<String>> {
71 match self {
72 ExecutionResult::Rows { columns, .. } => Some(columns),
73 _ => None,
74 }
75 }
76
77 pub fn rows_affected(&self) -> usize {
79 match self {
80 ExecutionResult::RowsAffected(n) => *n,
81 ExecutionResult::Rows { rows, .. } => rows.len(),
82 ExecutionResult::Ok => 0,
83 }
84 }
85}
86
/// SQL executor that keeps every table in process memory.
pub struct SqlExecutor {
    /// Table storage, keyed by table name.
    tables: HashMap<String, TableData>,
}
94
/// Schema and contents of one in-memory table.
#[derive(Debug, Clone)]
pub struct TableData {
    /// Column names, in declaration order.
    pub columns: Vec<String>,
    /// Declared type of each column, parallel to `columns`.
    pub column_types: Vec<DataType>,
    /// Stored rows; each row holds its values positionally, parallel to `columns`.
    pub rows: Vec<Vec<SochValue>>,
}
102
103impl Default for SqlExecutor {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl SqlExecutor {
110 pub fn new() -> Self {
112 Self {
113 tables: HashMap::new(),
114 }
115 }
116
117 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
119 self.execute_with_params(sql, &[])
120 }
121
122 pub fn execute_with_params(
124 &mut self,
125 sql: &str,
126 params: &[SochValue],
127 ) -> SqlResult<ExecutionResult> {
128 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
129 self.execute_statement(&stmt, params)
130 }
131
132 pub fn execute_statement(
134 &mut self,
135 stmt: &Statement,
136 params: &[SochValue],
137 ) -> SqlResult<ExecutionResult> {
138 match stmt {
139 Statement::Select(select) => self.execute_select(select, params),
140 Statement::Insert(insert) => self.execute_insert(insert, params),
141 Statement::Update(update) => self.execute_update(update, params),
142 Statement::Delete(delete) => self.execute_delete(delete, params),
143 Statement::CreateTable(create) => self.execute_create_table(create),
144 Statement::DropTable(drop) => self.execute_drop_table(drop),
145 Statement::Begin(_) => Ok(ExecutionResult::Ok),
146 Statement::Commit => Ok(ExecutionResult::Ok),
147 Statement::Rollback(_) => Ok(ExecutionResult::Ok),
148 _ => Err(SqlError::NotImplemented(
149 "Statement type not yet supported".into(),
150 )),
151 }
152 }
153
154 fn execute_select(
155 &self,
156 select: &SelectStmt,
157 params: &[SochValue],
158 ) -> SqlResult<ExecutionResult> {
159 let from = select
161 .from
162 .as_ref()
163 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
164
165 if from.tables.len() != 1 {
166 return Err(SqlError::NotImplemented(
167 "Multi-table queries not yet supported".into(),
168 ));
169 }
170
171 let table_name = match &from.tables[0] {
172 TableRef::Table { name, .. } => name.name().to_string(),
173 _ => {
174 return Err(SqlError::NotImplemented(
175 "Complex table references not yet supported".into(),
176 ));
177 }
178 };
179
180 let table = self
181 .tables
182 .get(&table_name)
183 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
184
185 let mut source_rows = Vec::new();
187
188 for row in &table.rows {
189 let row_map: HashMap<String, SochValue> = table
191 .columns
192 .iter()
193 .zip(row.iter())
194 .map(|(col, val)| (col.clone(), val.clone()))
195 .collect();
196
197 if let Some(where_clause) = &select.where_clause
199 && !self.evaluate_where(where_clause, &row_map, params)?
200 {
201 continue;
202 }
203
204 source_rows.push(row_map);
205 }
206
207 if !select.order_by.is_empty() {
209 source_rows.sort_by(|a, b| {
210 for order_item in &select.order_by {
211 if let Expr::Column(col_ref) = &order_item.expr {
212 let a_val = a.get(&col_ref.column);
213 let b_val = b.get(&col_ref.column);
214
215 let cmp = self.compare_values(a_val, b_val);
216 if cmp != std::cmp::Ordering::Equal {
217 return if order_item.asc { cmp } else { cmp.reverse() };
218 }
219 }
220 }
221 std::cmp::Ordering::Equal
222 });
223 }
224
225 if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
227 let n = *n as usize;
228 if n < source_rows.len() {
229 source_rows = source_rows.into_iter().skip(n).collect();
230 } else {
231 source_rows.clear();
232 }
233 }
234
235 if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
237 source_rows.truncate(*n as usize);
238 }
239
240 let mut output_columns: Vec<String> = Vec::new();
242 let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
243
244 let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
246
247 if is_wildcard {
248 output_columns = table.columns.clone();
249 result_rows = source_rows;
250 } else {
251 for item in &select.columns {
253 match item {
254 SelectItem::Wildcard => output_columns.push("*".to_string()),
255 SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
256 SelectItem::Expr { expr, alias } => {
257 let col_name = alias.clone().unwrap_or_else(|| match expr {
258 Expr::Column(col) => col.column.clone(),
259 Expr::Function(func) => format!("{}()", func.name.name()),
260 _ => "?column?".to_string(),
261 });
262 output_columns.push(col_name);
263 }
264 }
265 }
266
267 for source_row in &source_rows {
269 let mut result_row = HashMap::new();
270
271 for (idx, item) in select.columns.iter().enumerate() {
272 let col_name = &output_columns[idx];
273
274 match item {
275 SelectItem::Wildcard => {
276 for (k, v) in source_row {
278 result_row.insert(k.clone(), v.clone());
279 }
280 }
281 SelectItem::QualifiedWildcard(_) => {
282 for (k, v) in source_row {
284 result_row.insert(k.clone(), v.clone());
285 }
286 }
287 SelectItem::Expr { expr, .. } => {
288 let value = self.evaluate_expr(expr, source_row, params)?;
289 result_row.insert(col_name.clone(), value);
290 }
291 }
292 }
293
294 result_rows.push(result_row);
295 }
296 }
297
298 Ok(ExecutionResult::Rows {
299 columns: output_columns,
300 rows: result_rows,
301 })
302 }
303
304 fn execute_insert(
305 &mut self,
306 insert: &InsertStmt,
307 params: &[SochValue],
308 ) -> SqlResult<ExecutionResult> {
309 let table_name = insert.table.name().to_string();
310
311 let table_columns = {
313 let table = self
314 .tables
315 .get(&table_name)
316 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
317 table.columns.clone()
318 };
319
320 let mut rows_affected = 0;
321 let mut new_rows = Vec::new();
322
323 match &insert.source {
324 InsertSource::Values(rows) => {
325 for value_exprs in rows {
326 let mut row_values = Vec::new();
327
328 if let Some(columns) = &insert.columns {
329 if columns.len() != value_exprs.len() {
330 return Err(SqlError::InvalidArgument(format!(
331 "Column count ({}) doesn't match value count ({})",
332 columns.len(),
333 value_exprs.len()
334 )));
335 }
336
337 for table_col in &table_columns {
339 if let Some(pos) = columns.iter().position(|c| c == table_col) {
340 let value =
341 self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
342 row_values.push(value);
343 } else {
344 row_values.push(SochValue::Null);
345 }
346 }
347 } else {
348 for expr in value_exprs {
350 let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
351 row_values.push(value);
352 }
353 }
354
355 new_rows.push(row_values);
356 rows_affected += 1;
357 }
358 }
359 InsertSource::Query(_) => {
360 return Err(SqlError::NotImplemented(
361 "INSERT ... SELECT not yet supported".into(),
362 ));
363 }
364 InsertSource::Default => {
365 return Err(SqlError::NotImplemented(
366 "INSERT DEFAULT VALUES not yet supported".into(),
367 ));
368 }
369 }
370
371 let table = self.tables.get_mut(&table_name).unwrap();
373 for row in new_rows {
374 table.rows.push(row);
375 }
376
377 Ok(ExecutionResult::RowsAffected(rows_affected))
378 }
379
380 fn execute_update(
381 &mut self,
382 update: &UpdateStmt,
383 params: &[SochValue],
384 ) -> SqlResult<ExecutionResult> {
385 let table_name = update.table.name().to_string();
386
387 let (_table_columns, updates_to_apply) = {
389 let table = self
390 .tables
391 .get(&table_name)
392 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
393
394 let mut updates = Vec::new();
395
396 for row_idx in 0..table.rows.len() {
397 let row_map: HashMap<String, SochValue> = table
399 .columns
400 .iter()
401 .zip(table.rows[row_idx].iter())
402 .map(|(col, val)| (col.clone(), val.clone()))
403 .collect();
404
405 let matches = if let Some(where_clause) = &update.where_clause {
407 self.evaluate_where(where_clause, &row_map, params)?
408 } else {
409 true
410 };
411
412 if matches {
413 let mut row_updates = Vec::new();
415 for assignment in &update.assignments {
416 if let Some(col_idx) =
417 table.columns.iter().position(|c| c == &assignment.column)
418 {
419 let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
420 row_updates.push((col_idx, value));
421 }
422 }
423 updates.push((row_idx, row_updates));
424 }
425 }
426
427 (table.columns.clone(), updates)
428 };
429
430 let rows_affected = updates_to_apply.len();
431
432 let table = self.tables.get_mut(&table_name).unwrap();
434 for (row_idx, row_updates) in updates_to_apply {
435 for (col_idx, value) in row_updates {
436 table.rows[row_idx][col_idx] = value;
437 }
438 }
439
440 Ok(ExecutionResult::RowsAffected(rows_affected))
441 }
442
443 fn execute_delete(
444 &mut self,
445 delete: &DeleteStmt,
446 params: &[SochValue],
447 ) -> SqlResult<ExecutionResult> {
448 let table_name = delete.table.name().to_string();
449
450 let indices_to_remove = {
452 let table = self
453 .tables
454 .get(&table_name)
455 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
456
457 let mut indices = Vec::new();
458
459 for (row_idx, row) in table.rows.iter().enumerate() {
460 let row_map: HashMap<String, SochValue> = table
462 .columns
463 .iter()
464 .zip(row.iter())
465 .map(|(col, val)| (col.clone(), val.clone()))
466 .collect();
467
468 let matches = if let Some(where_clause) = &delete.where_clause {
470 self.evaluate_where(where_clause, &row_map, params)?
471 } else {
472 true
473 };
474
475 if matches {
476 indices.push(row_idx);
477 }
478 }
479
480 indices
481 };
482
483 let rows_affected = indices_to_remove.len();
484
485 let table = self.tables.get_mut(&table_name).unwrap();
487 for idx in indices_to_remove.into_iter().rev() {
489 table.rows.remove(idx);
490 }
491
492 Ok(ExecutionResult::RowsAffected(rows_affected))
493 }
494
495 fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
496 let table_name = create.name.name().to_string();
497
498 if self.tables.contains_key(&table_name) {
499 if create.if_not_exists {
500 return Ok(ExecutionResult::Ok);
501 }
502 return Err(SqlError::ConstraintViolation(format!(
503 "Table '{}' already exists",
504 table_name
505 )));
506 }
507
508 let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
509 let column_types: Vec<DataType> =
510 create.columns.iter().map(|c| c.data_type.clone()).collect();
511
512 self.tables.insert(
513 table_name,
514 TableData {
515 columns,
516 column_types,
517 rows: Vec::new(),
518 },
519 );
520
521 Ok(ExecutionResult::Ok)
522 }
523
524 fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
525 for name in &drop.names {
526 let table_name = name.name().to_string();
527 if self.tables.remove(&table_name).is_none() && !drop.if_exists {
528 return Err(SqlError::TableNotFound(table_name));
529 }
530 }
531
532 Ok(ExecutionResult::Ok)
533 }
534
535 fn evaluate_where(
538 &self,
539 expr: &Expr,
540 row: &HashMap<String, SochValue>,
541 params: &[SochValue],
542 ) -> SqlResult<bool> {
543 let value = self.evaluate_expr(expr, row, params)?;
544 match value {
545 SochValue::Bool(b) => Ok(b),
546 SochValue::Null => Ok(false),
547 _ => Err(SqlError::TypeError(
548 "WHERE clause must evaluate to boolean".into(),
549 )),
550 }
551 }
552
553 fn evaluate_expr(
554 &self,
555 expr: &Expr,
556 row: &HashMap<String, SochValue>,
557 params: &[SochValue],
558 ) -> SqlResult<SochValue> {
559 match expr {
560 Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
561
562 Expr::Column(col_ref) => row
563 .get(&col_ref.column)
564 .cloned()
565 .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
566
567 Expr::Placeholder(n) => params
568 .get((*n as usize).saturating_sub(1))
569 .cloned()
570 .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
571
572 Expr::BinaryOp { left, op, right } => {
573 let left_val = self.evaluate_expr(left, row, params)?;
574 let right_val = self.evaluate_expr(right, row, params)?;
575 self.evaluate_binary_op(&left_val, op, &right_val)
576 }
577
578 Expr::UnaryOp { op, expr } => {
579 let val = self.evaluate_expr(expr, row, params)?;
580 self.evaluate_unary_op(op, &val)
581 }
582
583 Expr::IsNull { expr, negated } => {
584 let val = self.evaluate_expr(expr, row, params)?;
585 let is_null = matches!(val, SochValue::Null);
586 Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
587 }
588
589 Expr::InList {
590 expr,
591 list,
592 negated,
593 } => {
594 let val = self.evaluate_expr(expr, row, params)?;
595 let mut found = false;
596 for item in list {
597 let item_val = self.evaluate_expr(item, row, params)?;
598 if self.values_equal(&val, &item_val) {
599 found = true;
600 break;
601 }
602 }
603 Ok(SochValue::Bool(if *negated { !found } else { found }))
604 }
605
606 Expr::Between {
607 expr,
608 low,
609 high,
610 negated,
611 } => {
612 let val = self.evaluate_expr(expr, row, params)?;
613 let low_val = self.evaluate_expr(low, row, params)?;
614 let high_val = self.evaluate_expr(high, row, params)?;
615
616 let cmp_low = self.compare_values(Some(&val), Some(&low_val));
617 let cmp_high = self.compare_values(Some(&val), Some(&high_val));
618
619 let in_range =
620 cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
621
622 Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
623 }
624
625 Expr::Like {
626 expr,
627 pattern,
628 negated,
629 ..
630 } => {
631 let val = self.evaluate_expr(expr, row, params)?;
632 let pattern_val = self.evaluate_expr(pattern, row, params)?;
633
634 match (&val, &pattern_val) {
635 (SochValue::Text(s), SochValue::Text(p)) => {
636 let regex_pattern = p.replace('%', ".*").replace('_', ".");
637 let matches = regex::Regex::new(&format!("^{}$", regex_pattern))
638 .map(|re| re.is_match(s))
639 .unwrap_or(false);
640 Ok(SochValue::Bool(if *negated { !matches } else { matches }))
641 }
642 _ => Ok(SochValue::Bool(false)),
643 }
644 }
645
646 Expr::Function(func) => self.evaluate_function(func, row, params),
647
648 Expr::Case {
649 operand,
650 conditions,
651 else_result,
652 } => {
653 if let Some(op) = operand {
654 let op_val = self.evaluate_expr(op, row, params)?;
656 for (when_expr, then_expr) in conditions {
657 let when_val = self.evaluate_expr(when_expr, row, params)?;
658 if self.values_equal(&op_val, &when_val) {
659 return self.evaluate_expr(then_expr, row, params);
660 }
661 }
662 } else {
663 for (when_expr, then_expr) in conditions {
665 let when_val = self.evaluate_expr(when_expr, row, params)?;
666 if matches!(when_val, SochValue::Bool(true)) {
667 return self.evaluate_expr(then_expr, row, params);
668 }
669 }
670 }
671
672 if let Some(else_expr) = else_result {
673 self.evaluate_expr(else_expr, row, params)
674 } else {
675 Ok(SochValue::Null)
676 }
677 }
678
679 _ => Err(SqlError::NotImplemented(format!(
680 "Expression type {:?} not yet supported",
681 expr
682 ))),
683 }
684 }
685
686 fn literal_to_value(&self, lit: &Literal) -> SochValue {
687 match lit {
688 Literal::Null => SochValue::Null,
689 Literal::Boolean(b) => SochValue::Bool(*b),
690 Literal::Integer(n) => SochValue::Int(*n),
691 Literal::Float(f) => SochValue::Float(*f),
692 Literal::String(s) => SochValue::Text(s.clone()),
693 Literal::Blob(b) => SochValue::Binary(b.clone()),
694 }
695 }
696
697 fn evaluate_binary_op(
698 &self,
699 left: &SochValue,
700 op: &BinaryOperator,
701 right: &SochValue,
702 ) -> SqlResult<SochValue> {
703 match op {
704 BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
705 BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
706 BinaryOperator::Lt => Ok(SochValue::Bool(
707 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
708 )),
709 BinaryOperator::Le => Ok(SochValue::Bool(
710 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
711 )),
712 BinaryOperator::Gt => Ok(SochValue::Bool(
713 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
714 )),
715 BinaryOperator::Ge => Ok(SochValue::Bool(
716 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
717 )),
718
719 BinaryOperator::And => match (left, right) {
720 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
721 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
722 _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
723 },
724
725 BinaryOperator::Or => match (left, right) {
726 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
727 (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
728 Ok(SochValue::Bool(true))
729 }
730 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
731 _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
732 },
733
734 BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
735 BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
736 BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
737 BinaryOperator::Divide => self.arithmetic_op(
738 left,
739 right,
740 |a, b| if b != 0 { a / b } else { 0 },
741 |a, b| a / b,
742 ),
743 BinaryOperator::Modulo => self.arithmetic_op(
744 left,
745 right,
746 |a, b| if b != 0 { a % b } else { 0 },
747 |a, b| a % b,
748 ),
749
750 BinaryOperator::Concat => match (left, right) {
751 (SochValue::Text(l), SochValue::Text(r)) => {
752 Ok(SochValue::Text(format!("{}{}", l, r)))
753 }
754 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
755 _ => Err(SqlError::TypeError("|| requires string operands".into())),
756 },
757
758 _ => Err(SqlError::NotImplemented(format!(
759 "Operator {:?} not implemented",
760 op
761 ))),
762 }
763 }
764
765 fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
766 match op {
767 UnaryOperator::Not => match val {
768 SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
769 SochValue::Null => Ok(SochValue::Null),
770 _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
771 },
772 UnaryOperator::Minus => match val {
773 SochValue::Int(n) => Ok(SochValue::Int(-n)),
774 SochValue::Float(f) => Ok(SochValue::Float(-f)),
775 SochValue::Null => Ok(SochValue::Null),
776 _ => Err(SqlError::TypeError(
777 "Unary minus requires numeric operand".into(),
778 )),
779 },
780 UnaryOperator::Plus => Ok(val.clone()),
781 UnaryOperator::BitNot => match val {
782 SochValue::Int(n) => Ok(SochValue::Int(!n)),
783 _ => Err(SqlError::TypeError("~ requires integer operand".into())),
784 },
785 }
786 }
787
788 fn evaluate_function(
789 &self,
790 func: &FunctionCall,
791 row: &HashMap<String, SochValue>,
792 params: &[SochValue],
793 ) -> SqlResult<SochValue> {
794 let func_name = func.name.name().to_uppercase();
795
796 match func_name.as_str() {
797 "COALESCE" => {
798 for arg in &func.args {
799 let val = self.evaluate_expr(arg, row, params)?;
800 if !matches!(val, SochValue::Null) {
801 return Ok(val);
802 }
803 }
804 Ok(SochValue::Null)
805 }
806
807 "NULLIF" => {
808 if func.args.len() != 2 {
809 return Err(SqlError::InvalidArgument(
810 "NULLIF requires 2 arguments".into(),
811 ));
812 }
813 let val1 = self.evaluate_expr(&func.args[0], row, params)?;
814 let val2 = self.evaluate_expr(&func.args[1], row, params)?;
815 if self.values_equal(&val1, &val2) {
816 Ok(SochValue::Null)
817 } else {
818 Ok(val1)
819 }
820 }
821
822 "ABS" => {
823 if func.args.len() != 1 {
824 return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
825 }
826 let val = self.evaluate_expr(&func.args[0], row, params)?;
827 match val {
828 SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
829 SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
830 SochValue::Null => Ok(SochValue::Null),
831 _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
832 }
833 }
834
835 "LENGTH" | "LEN" => {
836 if func.args.len() != 1 {
837 return Err(SqlError::InvalidArgument(
838 "LENGTH requires 1 argument".into(),
839 ));
840 }
841 let val = self.evaluate_expr(&func.args[0], row, params)?;
842 match val {
843 SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
844 SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
845 SochValue::Null => Ok(SochValue::Null),
846 _ => Err(SqlError::TypeError(
847 "LENGTH requires string argument".into(),
848 )),
849 }
850 }
851
852 "UPPER" => {
853 if func.args.len() != 1 {
854 return Err(SqlError::InvalidArgument(
855 "UPPER requires 1 argument".into(),
856 ));
857 }
858 let val = self.evaluate_expr(&func.args[0], row, params)?;
859 match val {
860 SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
861 SochValue::Null => Ok(SochValue::Null),
862 _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
863 }
864 }
865
866 "LOWER" => {
867 if func.args.len() != 1 {
868 return Err(SqlError::InvalidArgument(
869 "LOWER requires 1 argument".into(),
870 ));
871 }
872 let val = self.evaluate_expr(&func.args[0], row, params)?;
873 match val {
874 SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
875 SochValue::Null => Ok(SochValue::Null),
876 _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
877 }
878 }
879
880 "TRIM" => {
881 if func.args.len() != 1 {
882 return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
883 }
884 let val = self.evaluate_expr(&func.args[0], row, params)?;
885 match val {
886 SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
887 SochValue::Null => Ok(SochValue::Null),
888 _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
889 }
890 }
891
892 "SUBSTR" | "SUBSTRING" => {
893 if func.args.len() < 2 || func.args.len() > 3 {
894 return Err(SqlError::InvalidArgument(
895 "SUBSTR requires 2 or 3 arguments".into(),
896 ));
897 }
898 let val = self.evaluate_expr(&func.args[0], row, params)?;
899 let start = self.evaluate_expr(&func.args[1], row, params)?;
900 let len = if func.args.len() == 3 {
901 Some(self.evaluate_expr(&func.args[2], row, params)?)
902 } else {
903 None
904 };
905
906 match (val, start) {
907 (SochValue::Text(s), SochValue::Int(start)) => {
908 let start = (start.max(1) - 1) as usize;
909 if start >= s.len() {
910 return Ok(SochValue::Text(String::new()));
911 }
912 let result = if let Some(SochValue::Int(len)) = len {
913 s.chars().skip(start).take(len as usize).collect()
914 } else {
915 s.chars().skip(start).collect()
916 };
917 Ok(SochValue::Text(result))
918 }
919 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
920 _ => Err(SqlError::TypeError(
921 "SUBSTR requires string and integer arguments".into(),
922 )),
923 }
924 }
925
926 _ => Err(SqlError::NotImplemented(format!(
927 "Function {} not implemented",
928 func_name
929 ))),
930 }
931 }
932
933 fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
936 match (left, right) {
937 (SochValue::Null, _) | (_, SochValue::Null) => false,
938 (SochValue::Int(l), SochValue::Int(r)) => l == r,
939 (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
940 (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
941 (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
942 (SochValue::Text(l), SochValue::Text(r)) => l == r,
943 (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
944 (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
945 (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
946 (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
947 (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
948 _ => false,
949 }
950 }
951
952 fn compare_values(
953 &self,
954 left: Option<&SochValue>,
955 right: Option<&SochValue>,
956 ) -> std::cmp::Ordering {
957 match (left, right) {
958 (None, None) => std::cmp::Ordering::Equal,
959 (None, _) => std::cmp::Ordering::Less,
960 (_, None) => std::cmp::Ordering::Greater,
961 (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
962 (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
963 (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
964 l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
965 }
966 (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
967 .partial_cmp(r)
968 .unwrap_or(std::cmp::Ordering::Equal),
969 (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
970 .partial_cmp(&(*r as f64))
971 .unwrap_or(std::cmp::Ordering::Equal),
972 (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
973 (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
974 _ => std::cmp::Ordering::Equal,
975 }
976 }
977
978 fn arithmetic_op<FI, FF>(
979 &self,
980 left: &SochValue,
981 right: &SochValue,
982 int_op: FI,
983 float_op: FF,
984 ) -> SqlResult<SochValue>
985 where
986 FI: Fn(i64, i64) -> i64,
987 FF: Fn(f64, f64) -> f64,
988 {
989 match (left, right) {
990 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
991 (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
992 (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
993 (SochValue::Int(l), SochValue::Float(r)) => {
994 Ok(SochValue::Float(float_op(*l as f64, *r)))
995 }
996 (SochValue::Float(l), SochValue::Int(r)) => {
997 Ok(SochValue::Float(float_op(*l, *r as f64)))
998 }
999 (SochValue::UInt(l), SochValue::UInt(r)) => {
1000 Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1001 }
1002 (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1003 (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1004 _ => Err(SqlError::TypeError(
1005 "Arithmetic requires numeric operands".into(),
1006 )),
1007 }
1008 }
1009}
1010
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh executor and runs each setup statement in order,
    /// panicking on the first failure.
    fn executor_with(setup: &[&str]) -> SqlExecutor {
        let mut executor = SqlExecutor::new();
        for sql in setup {
            executor.execute(sql).unwrap();
        }
        executor
    }

    /// Builds `nums (n INTEGER)` holding the integers `1..=upto`, one INSERT each.
    fn nums_table(upto: i64) -> SqlExecutor {
        let mut executor = executor_with(&["CREATE TABLE nums (n INTEGER)"]);
        for i in 1..=upto {
            let insert = format!("INSERT INTO nums (n) VALUES ({})", i);
            executor.execute(&insert).unwrap();
        }
        executor
    }

    /// Runs `sql` and returns the row count reported for it.
    fn count_for(executor: &mut SqlExecutor, sql: &str) -> usize {
        executor.execute(sql).unwrap().rows_affected()
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut executor = SqlExecutor::new();

        let created = executor
            .execute("CREATE TABLE users (id INTEGER, name VARCHAR(100))")
            .unwrap();
        assert!(matches!(created, ExecutionResult::Ok));

        let first_insert = count_for(
            &mut executor,
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
        );
        assert_eq!(first_insert, 1);

        let second_insert = count_for(
            &mut executor,
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        );
        assert_eq!(second_insert, 1);

        assert_eq!(count_for(&mut executor, "SELECT * FROM users"), 2);
    }

    #[test]
    fn test_select_with_where() {
        let mut executor = executor_with(&[
            "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)",
            "INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)",
            "INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)",
            "INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)",
        ]);

        let matching = count_for(&mut executor, "SELECT * FROM products WHERE price > 1.0");
        assert_eq!(matching, 2);
    }

    #[test]
    fn test_update() {
        let mut executor = executor_with(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
        ]);

        let updated = count_for(
            &mut executor,
            "UPDATE users SET name = 'Alicia' WHERE id = 1",
        );
        assert_eq!(updated, 1);

        let renamed = count_for(&mut executor, "SELECT * FROM users WHERE name = 'Alicia'");
        assert_eq!(renamed, 1);
    }

    #[test]
    fn test_delete() {
        let mut executor = executor_with(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ]);

        assert_eq!(
            count_for(&mut executor, "DELETE FROM users WHERE id = 1"),
            1
        );
        assert_eq!(count_for(&mut executor, "SELECT * FROM users"), 1);
    }

    #[test]
    fn test_functions() {
        let mut executor = executor_with(&[
            "CREATE TABLE t (s TEXT)",
            "INSERT INTO t (s) VALUES ('hello')",
        ]);

        let result = executor.execute("SELECT UPPER(s) FROM t").unwrap();
        let ExecutionResult::Rows { rows, .. } = result else {
            panic!("Expected rows");
        };

        let has_upper = rows[0]
            .values()
            .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"));
        assert!(has_upper);
    }

    #[test]
    fn test_order_by() {
        let mut executor = executor_with(&[
            "CREATE TABLE nums (n INTEGER)",
            "INSERT INTO nums (n) VALUES (3)",
            "INSERT INTO nums (n) VALUES (1)",
            "INSERT INTO nums (n) VALUES (2)",
        ]);

        let result = executor
            .execute("SELECT * FROM nums ORDER BY n ASC")
            .unwrap();
        let ExecutionResult::Rows { rows, .. } = result else {
            panic!("Expected rows");
        };

        let values: Vec<i64> = rows
            .iter()
            .filter_map(|r| match r.get("n") {
                Some(SochValue::Int(n)) => Some(*n),
                _ => None,
            })
            .collect();
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut executor = nums_table(10);

        let page = count_for(&mut executor, "SELECT * FROM nums LIMIT 3 OFFSET 2");
        assert_eq!(page, 3);
    }

    #[test]
    fn test_between() {
        let mut executor = nums_table(10);

        let in_range = count_for(&mut executor, "SELECT * FROM nums WHERE n BETWEEN 3 AND 7");
        assert_eq!(in_range, 5);
    }

    #[test]
    fn test_in_list() {
        let mut executor = nums_table(5);

        let listed = count_for(&mut executor, "SELECT * FROM nums WHERE n IN (1, 3, 5)");
        assert_eq!(listed, 3);
    }
}