Skip to main content

sochdb_query/sql/
mod.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! SQL Module - Complete SQL support for SochDB
19//!
20//! Provides SQL parsing, planning, and execution.
21//!
22//! # Example
23//!
24//! ```rust,ignore
25//! use sochdb_query::sql::{Parser, Statement};
26//!
27//! let stmt = Parser::parse("SELECT * FROM users WHERE id = 1")?;
28//! ```
29
30pub mod aggregate;
31pub mod ast;
32pub mod bridge;
33pub mod compatibility;
34pub mod error;
35pub mod lexer;
36pub mod parser;
37pub mod token;
38
39pub use ast::*;
40pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
41pub use compatibility::{
42    CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support,
43};
44pub use error::{SqlError, SqlResult};
45pub use lexer::{LexError, Lexer};
46pub use parser::{ParseError, Parser};
47pub use token::{Span, Token, TokenKind};
48
49use sochdb_core::SochValue;
50use std::collections::HashMap;
51
52/// Result of SQL execution
53#[derive(Debug, Clone)]
54pub enum ExecutionResult {
55    /// Query returned rows
56    Rows {
57        columns: Vec<String>,
58        rows: Vec<HashMap<String, SochValue>>,
59    },
60    /// Statement affected N rows
61    RowsAffected(usize),
62    /// Statement completed successfully
63    Ok,
64}
65
66impl ExecutionResult {
67    /// Get rows if this is a query result
68    pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
69        match self {
70            ExecutionResult::Rows { rows, .. } => Some(rows),
71            _ => None,
72        }
73    }
74
75    /// Get column names if this is a query result
76    pub fn columns(&self) -> Option<&Vec<String>> {
77        match self {
78            ExecutionResult::Rows { columns, .. } => Some(columns),
79            _ => None,
80        }
81    }
82
83    /// Get affected row count
84    pub fn rows_affected(&self) -> usize {
85        match self {
86            ExecutionResult::RowsAffected(n) => *n,
87            ExecutionResult::Rows { rows, .. } => rows.len(),
88            ExecutionResult::Ok => 0,
89        }
90    }
91}
92
93/// Simple SQL executor for standalone use
94///
95/// For full database integration, use the SqlConnection in sochdb-storage
96pub struct SqlExecutor {
97    /// Tables stored in memory (for testing/standalone use)
98    tables: HashMap<String, TableData>,
99}
100
101/// In-memory table data
102#[derive(Debug, Clone)]
103pub struct TableData {
104    pub columns: Vec<String>,
105    pub column_types: Vec<DataType>,
106    pub rows: Vec<Vec<SochValue>>,
107}
108
109impl Default for SqlExecutor {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl SqlExecutor {
116    /// Create a new SQL executor
117    pub fn new() -> Self {
118        Self {
119            tables: HashMap::new(),
120        }
121    }
122
123    /// Execute a SQL statement
124    pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
125        self.execute_with_params(sql, &[])
126    }
127
128    /// Execute a SQL statement with parameters
129    pub fn execute_with_params(
130        &mut self,
131        sql: &str,
132        params: &[SochValue],
133    ) -> SqlResult<ExecutionResult> {
134        let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
135        self.execute_statement(&stmt, params)
136    }
137
138    /// Execute a parsed statement
139    pub fn execute_statement(
140        &mut self,
141        stmt: &Statement,
142        params: &[SochValue],
143    ) -> SqlResult<ExecutionResult> {
144        match stmt {
145            Statement::Select(select) => self.execute_select(select, params),
146            Statement::Insert(insert) => self.execute_insert(insert, params),
147            Statement::Update(update) => self.execute_update(update, params),
148            Statement::Delete(delete) => self.execute_delete(delete, params),
149            Statement::CreateTable(create) => self.execute_create_table(create),
150            Statement::DropTable(drop) => self.execute_drop_table(drop),
151            Statement::Begin(_) => Ok(ExecutionResult::Ok),
152            Statement::Commit => Ok(ExecutionResult::Ok),
153            Statement::Rollback(_) => Ok(ExecutionResult::Ok),
154            _ => Err(SqlError::NotImplemented(
155                "Statement type not yet supported".into(),
156            )),
157        }
158    }
159
160    fn execute_select(
161        &self,
162        select: &SelectStmt,
163        params: &[SochValue],
164    ) -> SqlResult<ExecutionResult> {
165        // Get the table
166        let from = select
167            .from
168            .as_ref()
169            .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
170
171        if from.tables.len() != 1 {
172            return Err(SqlError::NotImplemented(
173                "Multi-table queries not yet supported".into(),
174            ));
175        }
176
177        let table_name = match &from.tables[0] {
178            TableRef::Table { name, .. } => name.name().to_string(),
179            _ => {
180                return Err(SqlError::NotImplemented(
181                    "Complex table references not yet supported".into(),
182                ));
183            }
184        };
185
186        let table = self
187            .tables
188            .get(&table_name)
189            .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
190
191        // Collect matching source rows
192        let mut source_rows = Vec::new();
193
194        for row in &table.rows {
195            // Build row as HashMap
196            let row_map: HashMap<String, SochValue> = table
197                .columns
198                .iter()
199                .zip(row.iter())
200                .map(|(col, val)| (col.clone(), val.clone()))
201                .collect();
202
203            // Apply WHERE filter
204            if let Some(where_clause) = &select.where_clause
205                && !self.evaluate_where(where_clause, &row_map, params)?
206            {
207                continue;
208            }
209
210            source_rows.push(row_map);
211        }
212
213        // Apply ORDER BY
214        if !select.order_by.is_empty() {
215            source_rows.sort_by(|a, b| {
216                for order_item in &select.order_by {
217                    if let Expr::Column(col_ref) = &order_item.expr {
218                        let a_val = a.get(&col_ref.column);
219                        let b_val = b.get(&col_ref.column);
220
221                        let cmp = self.compare_values(a_val, b_val);
222                        if cmp != std::cmp::Ordering::Equal {
223                            return if order_item.asc { cmp } else { cmp.reverse() };
224                        }
225                    }
226                }
227                std::cmp::Ordering::Equal
228            });
229        }
230
231        // Apply OFFSET
232        if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
233            let n = *n as usize;
234            if n < source_rows.len() {
235                source_rows = source_rows.into_iter().skip(n).collect();
236            } else {
237                source_rows.clear();
238            }
239        }
240
241        // Apply LIMIT
242        if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
243            source_rows.truncate(*n as usize);
244        }
245
246        // Determine output columns and evaluate SELECT expressions
247        let mut output_columns: Vec<String> = Vec::new();
248        let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
249
250        // Check for SELECT *
251        let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
252
253        if is_wildcard {
254            output_columns = table.columns.clone();
255            result_rows = source_rows;
256        } else {
257            // Determine column names first
258            for item in &select.columns {
259                match item {
260                    SelectItem::Wildcard => output_columns.push("*".to_string()),
261                    SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
262                    SelectItem::Expr { expr, alias } => {
263                        let col_name = alias.clone().unwrap_or_else(|| match expr {
264                            Expr::Column(col) => col.column.clone(),
265                            Expr::Function(func) => format!("{}()", func.name.name()),
266                            _ => "?column?".to_string(),
267                        });
268                        output_columns.push(col_name);
269                    }
270                }
271            }
272
273            // Now evaluate expressions for each row
274            for source_row in &source_rows {
275                let mut result_row = HashMap::new();
276
277                for (idx, item) in select.columns.iter().enumerate() {
278                    let col_name = &output_columns[idx];
279
280                    match item {
281                        SelectItem::Wildcard => {
282                            // Add all columns from source row
283                            for (k, v) in source_row {
284                                result_row.insert(k.clone(), v.clone());
285                            }
286                        }
287                        SelectItem::QualifiedWildcard(_) => {
288                            // For now, treat same as wildcard
289                            for (k, v) in source_row {
290                                result_row.insert(k.clone(), v.clone());
291                            }
292                        }
293                        SelectItem::Expr { expr, .. } => {
294                            let value = self.evaluate_expr(expr, source_row, params)?;
295                            result_row.insert(col_name.clone(), value);
296                        }
297                    }
298                }
299
300                result_rows.push(result_row);
301            }
302        }
303
304        Ok(ExecutionResult::Rows {
305            columns: output_columns,
306            rows: result_rows,
307        })
308    }
309
310    fn execute_insert(
311        &mut self,
312        insert: &InsertStmt,
313        params: &[SochValue],
314    ) -> SqlResult<ExecutionResult> {
315        let table_name = insert.table.name().to_string();
316
317        // First check if table exists and get column info
318        let table_columns = {
319            let table = self
320                .tables
321                .get(&table_name)
322                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
323            table.columns.clone()
324        };
325
326        let mut rows_affected = 0;
327        let mut new_rows = Vec::new();
328
329        match &insert.source {
330            InsertSource::Values(rows) => {
331                for value_exprs in rows {
332                    let mut row_values = Vec::new();
333
334                    if let Some(columns) = &insert.columns {
335                        if columns.len() != value_exprs.len() {
336                            return Err(SqlError::InvalidArgument(format!(
337                                "Column count ({}) doesn't match value count ({})",
338                                columns.len(),
339                                value_exprs.len()
340                            )));
341                        }
342
343                        // Map columns to table column order
344                        for table_col in &table_columns {
345                            if let Some(pos) = columns.iter().position(|c| c == table_col) {
346                                let value =
347                                    self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
348                                row_values.push(value);
349                            } else {
350                                row_values.push(SochValue::Null);
351                            }
352                        }
353                    } else {
354                        // Insert in column order
355                        for expr in value_exprs {
356                            let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
357                            row_values.push(value);
358                        }
359                    }
360
361                    new_rows.push(row_values);
362                    rows_affected += 1;
363                }
364            }
365            InsertSource::Query(_) => {
366                return Err(SqlError::NotImplemented(
367                    "INSERT ... SELECT not yet supported".into(),
368                ));
369            }
370            InsertSource::Default => {
371                return Err(SqlError::NotImplemented(
372                    "INSERT DEFAULT VALUES not yet supported".into(),
373                ));
374            }
375        }
376
377        // Now add the rows to the table
378        let table = self.tables.get_mut(&table_name).unwrap();
379        for row in new_rows {
380            table.rows.push(row);
381        }
382
383        Ok(ExecutionResult::RowsAffected(rows_affected))
384    }
385
386    fn execute_update(
387        &mut self,
388        update: &UpdateStmt,
389        params: &[SochValue],
390    ) -> SqlResult<ExecutionResult> {
391        let table_name = update.table.name().to_string();
392
393        // First, collect table info and rows that need updating
394        let (_table_columns, updates_to_apply) = {
395            let table = self
396                .tables
397                .get(&table_name)
398                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
399
400            let mut updates = Vec::new();
401
402            for row_idx in 0..table.rows.len() {
403                // Build row as HashMap for WHERE evaluation
404                let row_map: HashMap<String, SochValue> = table
405                    .columns
406                    .iter()
407                    .zip(table.rows[row_idx].iter())
408                    .map(|(col, val)| (col.clone(), val.clone()))
409                    .collect();
410
411                // Check WHERE condition
412                let matches = if let Some(where_clause) = &update.where_clause {
413                    self.evaluate_where(where_clause, &row_map, params)?
414                } else {
415                    true
416                };
417
418                if matches {
419                    // Collect the updates for this row
420                    let mut row_updates = Vec::new();
421                    for assignment in &update.assignments {
422                        if let Some(col_idx) =
423                            table.columns.iter().position(|c| c == &assignment.column)
424                        {
425                            let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
426                            row_updates.push((col_idx, value));
427                        }
428                    }
429                    updates.push((row_idx, row_updates));
430                }
431            }
432
433            (table.columns.clone(), updates)
434        };
435
436        let rows_affected = updates_to_apply.len();
437
438        // Now apply the updates
439        let table = self.tables.get_mut(&table_name).unwrap();
440        for (row_idx, row_updates) in updates_to_apply {
441            for (col_idx, value) in row_updates {
442                table.rows[row_idx][col_idx] = value;
443            }
444        }
445
446        Ok(ExecutionResult::RowsAffected(rows_affected))
447    }
448
449    fn execute_delete(
450        &mut self,
451        delete: &DeleteStmt,
452        params: &[SochValue],
453    ) -> SqlResult<ExecutionResult> {
454        let table_name = delete.table.name().to_string();
455
456        // First, determine which rows to delete
457        let indices_to_remove = {
458            let table = self
459                .tables
460                .get(&table_name)
461                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
462
463            let mut indices = Vec::new();
464
465            for (row_idx, row) in table.rows.iter().enumerate() {
466                // Build row as HashMap for WHERE evaluation
467                let row_map: HashMap<String, SochValue> = table
468                    .columns
469                    .iter()
470                    .zip(row.iter())
471                    .map(|(col, val)| (col.clone(), val.clone()))
472                    .collect();
473
474                // Check WHERE condition
475                let matches = if let Some(where_clause) = &delete.where_clause {
476                    self.evaluate_where(where_clause, &row_map, params)?
477                } else {
478                    true
479                };
480
481                if matches {
482                    indices.push(row_idx);
483                }
484            }
485
486            indices
487        };
488
489        let rows_affected = indices_to_remove.len();
490
491        // Now remove the rows
492        let table = self.tables.get_mut(&table_name).unwrap();
493        // Remove in reverse order to preserve indices
494        for idx in indices_to_remove.into_iter().rev() {
495            table.rows.remove(idx);
496        }
497
498        Ok(ExecutionResult::RowsAffected(rows_affected))
499    }
500
501    fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
502        let table_name = create.name.name().to_string();
503
504        if self.tables.contains_key(&table_name) {
505            if create.if_not_exists {
506                return Ok(ExecutionResult::Ok);
507            }
508            return Err(SqlError::ConstraintViolation(format!(
509                "Table '{}' already exists",
510                table_name
511            )));
512        }
513
514        let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
515        let column_types: Vec<DataType> =
516            create.columns.iter().map(|c| c.data_type.clone()).collect();
517
518        self.tables.insert(
519            table_name,
520            TableData {
521                columns,
522                column_types,
523                rows: Vec::new(),
524            },
525        );
526
527        Ok(ExecutionResult::Ok)
528    }
529
530    fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
531        for name in &drop.names {
532            let table_name = name.name().to_string();
533            if self.tables.remove(&table_name).is_none() && !drop.if_exists {
534                return Err(SqlError::TableNotFound(table_name));
535            }
536        }
537
538        Ok(ExecutionResult::Ok)
539    }
540
541    // ========== Expression Evaluation ==========
542
543    fn evaluate_where(
544        &self,
545        expr: &Expr,
546        row: &HashMap<String, SochValue>,
547        params: &[SochValue],
548    ) -> SqlResult<bool> {
549        let value = self.evaluate_expr(expr, row, params)?;
550        match value {
551            SochValue::Bool(b) => Ok(b),
552            SochValue::Null => Ok(false),
553            _ => Err(SqlError::TypeError(
554                "WHERE clause must evaluate to boolean".into(),
555            )),
556        }
557    }
558
559    fn evaluate_expr(
560        &self,
561        expr: &Expr,
562        row: &HashMap<String, SochValue>,
563        params: &[SochValue],
564    ) -> SqlResult<SochValue> {
565        match expr {
566            Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
567
568            Expr::Column(col_ref) => row
569                .get(&col_ref.column)
570                .cloned()
571                .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
572
573            Expr::Placeholder(n) => params
574                .get((*n as usize).saturating_sub(1))
575                .cloned()
576                .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
577
578            Expr::BinaryOp { left, op, right } => {
579                let left_val = self.evaluate_expr(left, row, params)?;
580                let right_val = self.evaluate_expr(right, row, params)?;
581                self.evaluate_binary_op(&left_val, op, &right_val)
582            }
583
584            Expr::UnaryOp { op, expr } => {
585                let val = self.evaluate_expr(expr, row, params)?;
586                self.evaluate_unary_op(op, &val)
587            }
588
589            Expr::IsNull { expr, negated } => {
590                let val = self.evaluate_expr(expr, row, params)?;
591                let is_null = matches!(val, SochValue::Null);
592                Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
593            }
594
595            Expr::InList {
596                expr,
597                list,
598                negated,
599            } => {
600                let val = self.evaluate_expr(expr, row, params)?;
601                let mut found = false;
602                for item in list {
603                    let item_val = self.evaluate_expr(item, row, params)?;
604                    if self.values_equal(&val, &item_val) {
605                        found = true;
606                        break;
607                    }
608                }
609                Ok(SochValue::Bool(if *negated { !found } else { found }))
610            }
611
612            Expr::Between {
613                expr,
614                low,
615                high,
616                negated,
617            } => {
618                let val = self.evaluate_expr(expr, row, params)?;
619                let low_val = self.evaluate_expr(low, row, params)?;
620                let high_val = self.evaluate_expr(high, row, params)?;
621
622                let cmp_low = self.compare_values(Some(&val), Some(&low_val));
623                let cmp_high = self.compare_values(Some(&val), Some(&high_val));
624
625                let in_range =
626                    cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
627
628                Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
629            }
630
631            Expr::Like {
632                expr,
633                pattern,
634                negated,
635                ..
636            } => {
637                let val = self.evaluate_expr(expr, row, params)?;
638                let pattern_val = self.evaluate_expr(pattern, row, params)?;
639
640                match (&val, &pattern_val) {
641                    (SochValue::Text(s), SochValue::Text(p)) => {
642                        // Route through the canonical LIKE matcher so SQL `LIKE`
643                        // behaves identically across every query path and treats
644                        // regex metacharacters (`.`, `*`, `[`, ...) as literals.
645                        let matches = crate::like::like_match(s, p);
646                        Ok(SochValue::Bool(if *negated { !matches } else { matches }))
647                    }
648                    _ => Ok(SochValue::Bool(false)),
649                }
650            }
651
652            Expr::Function(func) => self.evaluate_function(func, row, params),
653
654            Expr::Case {
655                operand,
656                conditions,
657                else_result,
658            } => {
659                if let Some(op) = operand {
660                    // Simple CASE
661                    let op_val = self.evaluate_expr(op, row, params)?;
662                    for (when_expr, then_expr) in conditions {
663                        let when_val = self.evaluate_expr(when_expr, row, params)?;
664                        if self.values_equal(&op_val, &when_val) {
665                            return self.evaluate_expr(then_expr, row, params);
666                        }
667                    }
668                } else {
669                    // Searched CASE
670                    for (when_expr, then_expr) in conditions {
671                        let when_val = self.evaluate_expr(when_expr, row, params)?;
672                        if matches!(when_val, SochValue::Bool(true)) {
673                            return self.evaluate_expr(then_expr, row, params);
674                        }
675                    }
676                }
677
678                if let Some(else_expr) = else_result {
679                    self.evaluate_expr(else_expr, row, params)
680                } else {
681                    Ok(SochValue::Null)
682                }
683            }
684
685            _ => Err(SqlError::NotImplemented(format!(
686                "Expression type {:?} not yet supported",
687                expr
688            ))),
689        }
690    }
691
692    fn literal_to_value(&self, lit: &Literal) -> SochValue {
693        match lit {
694            Literal::Null => SochValue::Null,
695            Literal::Boolean(b) => SochValue::Bool(*b),
696            Literal::Integer(n) => SochValue::Int(*n),
697            Literal::Float(f) => SochValue::Float(*f),
698            Literal::String(s) => SochValue::Text(s.clone()),
699            Literal::Blob(b) => SochValue::Binary(b.clone()),
700        }
701    }
702
703    fn evaluate_binary_op(
704        &self,
705        left: &SochValue,
706        op: &BinaryOperator,
707        right: &SochValue,
708    ) -> SqlResult<SochValue> {
709        match op {
710            BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
711            BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
712            BinaryOperator::Lt => Ok(SochValue::Bool(
713                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
714            )),
715            BinaryOperator::Le => Ok(SochValue::Bool(
716                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
717            )),
718            BinaryOperator::Gt => Ok(SochValue::Bool(
719                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
720            )),
721            BinaryOperator::Ge => Ok(SochValue::Bool(
722                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
723            )),
724
725            BinaryOperator::And => match (left, right) {
726                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
727                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
728                _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
729            },
730
731            BinaryOperator::Or => match (left, right) {
732                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
733                (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
734                    Ok(SochValue::Bool(true))
735                }
736                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
737                _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
738            },
739
740            BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
741            BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
742            BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
743            BinaryOperator::Divide => self.arithmetic_op(
744                left,
745                right,
746                |a, b| if b != 0 { a / b } else { 0 },
747                |a, b| a / b,
748            ),
749            BinaryOperator::Modulo => self.arithmetic_op(
750                left,
751                right,
752                |a, b| if b != 0 { a % b } else { 0 },
753                |a, b| a % b,
754            ),
755
756            BinaryOperator::Concat => match (left, right) {
757                (SochValue::Text(l), SochValue::Text(r)) => {
758                    Ok(SochValue::Text(format!("{}{}", l, r)))
759                }
760                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
761                _ => Err(SqlError::TypeError("|| requires string operands".into())),
762            },
763
764            _ => Err(SqlError::NotImplemented(format!(
765                "Operator {:?} not implemented",
766                op
767            ))),
768        }
769    }
770
771    fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
772        match op {
773            UnaryOperator::Not => match val {
774                SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
775                SochValue::Null => Ok(SochValue::Null),
776                _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
777            },
778            UnaryOperator::Minus => match val {
779                SochValue::Int(n) => Ok(SochValue::Int(-n)),
780                SochValue::Float(f) => Ok(SochValue::Float(-f)),
781                SochValue::Null => Ok(SochValue::Null),
782                _ => Err(SqlError::TypeError(
783                    "Unary minus requires numeric operand".into(),
784                )),
785            },
786            UnaryOperator::Plus => Ok(val.clone()),
787            UnaryOperator::BitNot => match val {
788                SochValue::Int(n) => Ok(SochValue::Int(!n)),
789                _ => Err(SqlError::TypeError("~ requires integer operand".into())),
790            },
791        }
792    }
793
794    fn evaluate_function(
795        &self,
796        func: &FunctionCall,
797        row: &HashMap<String, SochValue>,
798        params: &[SochValue],
799    ) -> SqlResult<SochValue> {
800        let func_name = func.name.name().to_uppercase();
801
802        match func_name.as_str() {
803            "COALESCE" => {
804                for arg in &func.args {
805                    let val = self.evaluate_expr(arg, row, params)?;
806                    if !matches!(val, SochValue::Null) {
807                        return Ok(val);
808                    }
809                }
810                Ok(SochValue::Null)
811            }
812
813            "NULLIF" => {
814                if func.args.len() != 2 {
815                    return Err(SqlError::InvalidArgument(
816                        "NULLIF requires 2 arguments".into(),
817                    ));
818                }
819                let val1 = self.evaluate_expr(&func.args[0], row, params)?;
820                let val2 = self.evaluate_expr(&func.args[1], row, params)?;
821                if self.values_equal(&val1, &val2) {
822                    Ok(SochValue::Null)
823                } else {
824                    Ok(val1)
825                }
826            }
827
828            "ABS" => {
829                if func.args.len() != 1 {
830                    return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
831                }
832                let val = self.evaluate_expr(&func.args[0], row, params)?;
833                match val {
834                    SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
835                    SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
836                    SochValue::Null => Ok(SochValue::Null),
837                    _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
838                }
839            }
840
841            "LENGTH" | "LEN" => {
842                if func.args.len() != 1 {
843                    return Err(SqlError::InvalidArgument(
844                        "LENGTH requires 1 argument".into(),
845                    ));
846                }
847                let val = self.evaluate_expr(&func.args[0], row, params)?;
848                match val {
849                    SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
850                    SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
851                    SochValue::Null => Ok(SochValue::Null),
852                    _ => Err(SqlError::TypeError(
853                        "LENGTH requires string argument".into(),
854                    )),
855                }
856            }
857
858            "UPPER" => {
859                if func.args.len() != 1 {
860                    return Err(SqlError::InvalidArgument(
861                        "UPPER requires 1 argument".into(),
862                    ));
863                }
864                let val = self.evaluate_expr(&func.args[0], row, params)?;
865                match val {
866                    SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
867                    SochValue::Null => Ok(SochValue::Null),
868                    _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
869                }
870            }
871
872            "LOWER" => {
873                if func.args.len() != 1 {
874                    return Err(SqlError::InvalidArgument(
875                        "LOWER requires 1 argument".into(),
876                    ));
877                }
878                let val = self.evaluate_expr(&func.args[0], row, params)?;
879                match val {
880                    SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
881                    SochValue::Null => Ok(SochValue::Null),
882                    _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
883                }
884            }
885
886            "TRIM" => {
887                if func.args.len() != 1 {
888                    return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
889                }
890                let val = self.evaluate_expr(&func.args[0], row, params)?;
891                match val {
892                    SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
893                    SochValue::Null => Ok(SochValue::Null),
894                    _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
895                }
896            }
897
898            "SUBSTR" | "SUBSTRING" => {
899                if func.args.len() < 2 || func.args.len() > 3 {
900                    return Err(SqlError::InvalidArgument(
901                        "SUBSTR requires 2 or 3 arguments".into(),
902                    ));
903                }
904                let val = self.evaluate_expr(&func.args[0], row, params)?;
905                let start = self.evaluate_expr(&func.args[1], row, params)?;
906                let len = if func.args.len() == 3 {
907                    Some(self.evaluate_expr(&func.args[2], row, params)?)
908                } else {
909                    None
910                };
911
912                match (val, start) {
913                    (SochValue::Text(s), SochValue::Int(start)) => {
914                        let start = (start.max(1) - 1) as usize;
915                        if start >= s.len() {
916                            return Ok(SochValue::Text(String::new()));
917                        }
918                        let result = if let Some(SochValue::Int(len)) = len {
919                            s.chars().skip(start).take(len as usize).collect()
920                        } else {
921                            s.chars().skip(start).collect()
922                        };
923                        Ok(SochValue::Text(result))
924                    }
925                    (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
926                    _ => Err(SqlError::TypeError(
927                        "SUBSTR requires string and integer arguments".into(),
928                    )),
929                }
930            }
931
932            _ => Err(SqlError::NotImplemented(format!(
933                "Function {} not implemented",
934                func_name
935            ))),
936        }
937    }
938
939    // ========== Helper Methods ==========
940
941    fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
942        match (left, right) {
943            (SochValue::Null, _) | (_, SochValue::Null) => false,
944            (SochValue::Int(l), SochValue::Int(r)) => l == r,
945            (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
946            (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
947            (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
948            (SochValue::Text(l), SochValue::Text(r)) => l == r,
949            (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
950            (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
951            (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
952            (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
953            (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
954            _ => false,
955        }
956    }
957
958    fn compare_values(
959        &self,
960        left: Option<&SochValue>,
961        right: Option<&SochValue>,
962    ) -> std::cmp::Ordering {
963        match (left, right) {
964            (None, None) => std::cmp::Ordering::Equal,
965            (None, _) => std::cmp::Ordering::Less,
966            (_, None) => std::cmp::Ordering::Greater,
967            (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
968            (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
969            (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
970                l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
971            }
972            (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
973                .partial_cmp(r)
974                .unwrap_or(std::cmp::Ordering::Equal),
975            (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
976                .partial_cmp(&(*r as f64))
977                .unwrap_or(std::cmp::Ordering::Equal),
978            (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
979            (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
980            _ => std::cmp::Ordering::Equal,
981        }
982    }
983
984    fn arithmetic_op<FI, FF>(
985        &self,
986        left: &SochValue,
987        right: &SochValue,
988        int_op: FI,
989        float_op: FF,
990    ) -> SqlResult<SochValue>
991    where
992        FI: Fn(i64, i64) -> i64,
993        FF: Fn(f64, f64) -> f64,
994    {
995        match (left, right) {
996            (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
997            (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
998            (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
999            (SochValue::Int(l), SochValue::Float(r)) => {
1000                Ok(SochValue::Float(float_op(*l as f64, *r)))
1001            }
1002            (SochValue::Float(l), SochValue::Int(r)) => {
1003                Ok(SochValue::Float(float_op(*l, *r as f64)))
1004            }
1005            (SochValue::UInt(l), SochValue::UInt(r)) => {
1006                Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1007            }
1008            (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1009            (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1010            _ => Err(SqlError::TypeError(
1011                "Arithmetic requires numeric operands".into(),
1012            )),
1013        }
1014    }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes `sql` against `executor`, failing the test on any error.
    fn run(executor: &mut SqlExecutor, sql: &str) -> ExecutionResult {
        executor.execute(sql).unwrap()
    }

    /// Builds an executor with a `nums (n INTEGER)` table holding `values`,
    /// inserted one row at a time in iteration order.
    fn nums_table(values: impl IntoIterator<Item = i64>) -> SqlExecutor {
        let mut executor = SqlExecutor::new();
        run(&mut executor, "CREATE TABLE nums (n INTEGER)");
        for value in values {
            run(&mut executor, &format!("INSERT INTO nums (n) VALUES ({value})"));
        }
        executor
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut executor = SqlExecutor::new();

        let created = run(
            &mut executor,
            "CREATE TABLE users (id INTEGER, name VARCHAR(100))",
        );
        assert!(matches!(created, ExecutionResult::Ok));

        for insert in [
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ] {
            assert_eq!(run(&mut executor, insert).rows_affected(), 1);
        }

        let everyone = run(&mut executor, "SELECT * FROM users");
        assert_eq!(everyone.rows_affected(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let mut executor = SqlExecutor::new();
        run(
            &mut executor,
            "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)",
        );
        for insert in [
            "INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)",
            "INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)",
            "INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)",
        ] {
            run(&mut executor, insert);
        }

        let pricey = run(&mut executor, "SELECT * FROM products WHERE price > 1.0");
        assert_eq!(pricey.rows_affected(), 2);
    }

    #[test]
    fn test_update() {
        let mut executor = SqlExecutor::new();
        run(&mut executor, "CREATE TABLE users (id INTEGER, name TEXT)");
        run(&mut executor, "INSERT INTO users (id, name) VALUES (1, 'Alice')");

        let updated = run(&mut executor, "UPDATE users SET name = 'Alicia' WHERE id = 1");
        assert_eq!(updated.rows_affected(), 1);

        let renamed = run(&mut executor, "SELECT * FROM users WHERE name = 'Alicia'");
        assert_eq!(renamed.rows_affected(), 1);
    }

    #[test]
    fn test_delete() {
        let mut executor = SqlExecutor::new();
        run(&mut executor, "CREATE TABLE users (id INTEGER, name TEXT)");
        run(&mut executor, "INSERT INTO users (id, name) VALUES (1, 'Alice')");
        run(&mut executor, "INSERT INTO users (id, name) VALUES (2, 'Bob')");

        let deleted = run(&mut executor, "DELETE FROM users WHERE id = 1");
        assert_eq!(deleted.rows_affected(), 1);

        let remaining = run(&mut executor, "SELECT * FROM users");
        assert_eq!(remaining.rows_affected(), 1);
    }

    #[test]
    fn test_functions() {
        let mut executor = SqlExecutor::new();
        run(&mut executor, "CREATE TABLE t (s TEXT)");
        run(&mut executor, "INSERT INTO t (s) VALUES ('hello')");

        let ExecutionResult::Rows { rows, .. } = run(&mut executor, "SELECT UPPER(s) FROM t")
        else {
            panic!("Expected rows");
        };

        // The projected column's name is not pinned down (e.g. `UPPER(s)`),
        // so search every value of the first row.
        let shouted = rows[0]
            .values()
            .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"));
        assert!(shouted);
    }

    #[test]
    fn test_order_by() {
        let mut executor = nums_table([3, 1, 2]);

        let ExecutionResult::Rows { rows, .. } =
            run(&mut executor, "SELECT * FROM nums ORDER BY n ASC")
        else {
            panic!("Expected rows");
        };

        let ordered: Vec<i64> = rows
            .iter()
            .filter_map(|row| match row.get("n") {
                Some(SochValue::Int(n)) => Some(*n),
                _ => None,
            })
            .collect();
        assert_eq!(ordered, vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut executor = nums_table(1..=10);
        let page = run(&mut executor, "SELECT * FROM nums LIMIT 3 OFFSET 2");
        assert_eq!(page.rows_affected(), 3);
    }

    #[test]
    fn test_between() {
        let mut executor = nums_table(1..=10);
        let ranged = run(&mut executor, "SELECT * FROM nums WHERE n BETWEEN 3 AND 7");
        assert_eq!(ranged.rows_affected(), 5);
    }

    #[test]
    fn test_in_list() {
        let mut executor = nums_table(1..=5);
        let picked = run(&mut executor, "SELECT * FROM nums WHERE n IN (1, 3, 5)");
        assert_eq!(picked.rows_affected(), 3);
    }
}