Skip to main content

sochdb_query/sql/
mod.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! SQL Module - Complete SQL support for SochDB
19//!
20//! Provides SQL parsing, planning, and execution.
21//!
22//! # Example
23//!
24//! ```rust,ignore
25//! use sochdb_query::sql::{Parser, Statement};
26//!
27//! let stmt = Parser::parse("SELECT * FROM users WHERE id = 1")?;
28//! ```
29
30pub mod ast;
31pub mod bridge;
32pub mod compatibility;
33pub mod error;
34pub mod lexer;
35pub mod parser;
36pub mod token;
37
38pub use ast::*;
39pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
40pub use compatibility::{CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support};
41pub use error::{SqlError, SqlResult};
42pub use lexer::{LexError, Lexer};
43pub use parser::{ParseError, Parser};
44pub use token::{Span, Token, TokenKind};
45
46use std::collections::HashMap;
47use sochdb_core::SochValue;
48
49/// Result of SQL execution
50#[derive(Debug, Clone)]
51pub enum ExecutionResult {
52    /// Query returned rows
53    Rows {
54        columns: Vec<String>,
55        rows: Vec<HashMap<String, SochValue>>,
56    },
57    /// Statement affected N rows
58    RowsAffected(usize),
59    /// Statement completed successfully
60    Ok,
61}
62
63impl ExecutionResult {
64    /// Get rows if this is a query result
65    pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
66        match self {
67            ExecutionResult::Rows { rows, .. } => Some(rows),
68            _ => None,
69        }
70    }
71
72    /// Get column names if this is a query result
73    pub fn columns(&self) -> Option<&Vec<String>> {
74        match self {
75            ExecutionResult::Rows { columns, .. } => Some(columns),
76            _ => None,
77        }
78    }
79
80    /// Get affected row count
81    pub fn rows_affected(&self) -> usize {
82        match self {
83            ExecutionResult::RowsAffected(n) => *n,
84            ExecutionResult::Rows { rows, .. } => rows.len(),
85            ExecutionResult::Ok => 0,
86        }
87    }
88}
89
90/// Simple SQL executor for standalone use
91///
92/// For full database integration, use the SqlConnection in sochdb-storage
93pub struct SqlExecutor {
94    /// Tables stored in memory (for testing/standalone use)
95    tables: HashMap<String, TableData>,
96}
97
98/// In-memory table data
99#[derive(Debug, Clone)]
100pub struct TableData {
101    pub columns: Vec<String>,
102    pub column_types: Vec<DataType>,
103    pub rows: Vec<Vec<SochValue>>,
104}
105
106impl Default for SqlExecutor {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl SqlExecutor {
113    /// Create a new SQL executor
114    pub fn new() -> Self {
115        Self {
116            tables: HashMap::new(),
117        }
118    }
119
120    /// Execute a SQL statement
121    pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
122        self.execute_with_params(sql, &[])
123    }
124
125    /// Execute a SQL statement with parameters
126    pub fn execute_with_params(
127        &mut self,
128        sql: &str,
129        params: &[SochValue],
130    ) -> SqlResult<ExecutionResult> {
131        let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
132        self.execute_statement(&stmt, params)
133    }
134
135    /// Execute a parsed statement
136    pub fn execute_statement(
137        &mut self,
138        stmt: &Statement,
139        params: &[SochValue],
140    ) -> SqlResult<ExecutionResult> {
141        match stmt {
142            Statement::Select(select) => self.execute_select(select, params),
143            Statement::Insert(insert) => self.execute_insert(insert, params),
144            Statement::Update(update) => self.execute_update(update, params),
145            Statement::Delete(delete) => self.execute_delete(delete, params),
146            Statement::CreateTable(create) => self.execute_create_table(create),
147            Statement::DropTable(drop) => self.execute_drop_table(drop),
148            Statement::Begin(_) => Ok(ExecutionResult::Ok),
149            Statement::Commit => Ok(ExecutionResult::Ok),
150            Statement::Rollback(_) => Ok(ExecutionResult::Ok),
151            _ => Err(SqlError::NotImplemented(
152                "Statement type not yet supported".into(),
153            )),
154        }
155    }
156
157    fn execute_select(
158        &self,
159        select: &SelectStmt,
160        params: &[SochValue],
161    ) -> SqlResult<ExecutionResult> {
162        // Get the table
163        let from = select
164            .from
165            .as_ref()
166            .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
167
168        if from.tables.len() != 1 {
169            return Err(SqlError::NotImplemented(
170                "Multi-table queries not yet supported".into(),
171            ));
172        }
173
174        let table_name = match &from.tables[0] {
175            TableRef::Table { name, .. } => name.name().to_string(),
176            _ => {
177                return Err(SqlError::NotImplemented(
178                    "Complex table references not yet supported".into(),
179                ));
180            }
181        };
182
183        let table = self
184            .tables
185            .get(&table_name)
186            .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
187
188        // Collect matching source rows
189        let mut source_rows = Vec::new();
190
191        for row in &table.rows {
192            // Build row as HashMap
193            let row_map: HashMap<String, SochValue> = table
194                .columns
195                .iter()
196                .zip(row.iter())
197                .map(|(col, val)| (col.clone(), val.clone()))
198                .collect();
199
200            // Apply WHERE filter
201            if let Some(where_clause) = &select.where_clause
202                && !self.evaluate_where(where_clause, &row_map, params)?
203            {
204                continue;
205            }
206
207            source_rows.push(row_map);
208        }
209
210        // Apply ORDER BY
211        if !select.order_by.is_empty() {
212            source_rows.sort_by(|a, b| {
213                for order_item in &select.order_by {
214                    if let Expr::Column(col_ref) = &order_item.expr {
215                        let a_val = a.get(&col_ref.column);
216                        let b_val = b.get(&col_ref.column);
217
218                        let cmp = self.compare_values(a_val, b_val);
219                        if cmp != std::cmp::Ordering::Equal {
220                            return if order_item.asc { cmp } else { cmp.reverse() };
221                        }
222                    }
223                }
224                std::cmp::Ordering::Equal
225            });
226        }
227
228        // Apply OFFSET
229        if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
230            let n = *n as usize;
231            if n < source_rows.len() {
232                source_rows = source_rows.into_iter().skip(n).collect();
233            } else {
234                source_rows.clear();
235            }
236        }
237
238        // Apply LIMIT
239        if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
240            source_rows.truncate(*n as usize);
241        }
242
243        // Determine output columns and evaluate SELECT expressions
244        let mut output_columns: Vec<String> = Vec::new();
245        let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
246
247        // Check for SELECT *
248        let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
249
250        if is_wildcard {
251            output_columns = table.columns.clone();
252            result_rows = source_rows;
253        } else {
254            // Determine column names first
255            for item in &select.columns {
256                match item {
257                    SelectItem::Wildcard => output_columns.push("*".to_string()),
258                    SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
259                    SelectItem::Expr { expr, alias } => {
260                        let col_name = alias.clone().unwrap_or_else(|| match expr {
261                            Expr::Column(col) => col.column.clone(),
262                            Expr::Function(func) => format!("{}()", func.name.name()),
263                            _ => "?column?".to_string(),
264                        });
265                        output_columns.push(col_name);
266                    }
267                }
268            }
269
270            // Now evaluate expressions for each row
271            for source_row in &source_rows {
272                let mut result_row = HashMap::new();
273
274                for (idx, item) in select.columns.iter().enumerate() {
275                    let col_name = &output_columns[idx];
276
277                    match item {
278                        SelectItem::Wildcard => {
279                            // Add all columns from source row
280                            for (k, v) in source_row {
281                                result_row.insert(k.clone(), v.clone());
282                            }
283                        }
284                        SelectItem::QualifiedWildcard(_) => {
285                            // For now, treat same as wildcard
286                            for (k, v) in source_row {
287                                result_row.insert(k.clone(), v.clone());
288                            }
289                        }
290                        SelectItem::Expr { expr, .. } => {
291                            let value = self.evaluate_expr(expr, source_row, params)?;
292                            result_row.insert(col_name.clone(), value);
293                        }
294                    }
295                }
296
297                result_rows.push(result_row);
298            }
299        }
300
301        Ok(ExecutionResult::Rows {
302            columns: output_columns,
303            rows: result_rows,
304        })
305    }
306
307    fn execute_insert(
308        &mut self,
309        insert: &InsertStmt,
310        params: &[SochValue],
311    ) -> SqlResult<ExecutionResult> {
312        let table_name = insert.table.name().to_string();
313
314        // First check if table exists and get column info
315        let table_columns = {
316            let table = self
317                .tables
318                .get(&table_name)
319                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
320            table.columns.clone()
321        };
322
323        let mut rows_affected = 0;
324        let mut new_rows = Vec::new();
325
326        match &insert.source {
327            InsertSource::Values(rows) => {
328                for value_exprs in rows {
329                    let mut row_values = Vec::new();
330
331                    if let Some(columns) = &insert.columns {
332                        if columns.len() != value_exprs.len() {
333                            return Err(SqlError::InvalidArgument(format!(
334                                "Column count ({}) doesn't match value count ({})",
335                                columns.len(),
336                                value_exprs.len()
337                            )));
338                        }
339
340                        // Map columns to table column order
341                        for table_col in &table_columns {
342                            if let Some(pos) = columns.iter().position(|c| c == table_col) {
343                                let value =
344                                    self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
345                                row_values.push(value);
346                            } else {
347                                row_values.push(SochValue::Null);
348                            }
349                        }
350                    } else {
351                        // Insert in column order
352                        for expr in value_exprs {
353                            let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
354                            row_values.push(value);
355                        }
356                    }
357
358                    new_rows.push(row_values);
359                    rows_affected += 1;
360                }
361            }
362            InsertSource::Query(_) => {
363                return Err(SqlError::NotImplemented(
364                    "INSERT ... SELECT not yet supported".into(),
365                ));
366            }
367            InsertSource::Default => {
368                return Err(SqlError::NotImplemented(
369                    "INSERT DEFAULT VALUES not yet supported".into(),
370                ));
371            }
372        }
373
374        // Now add the rows to the table
375        let table = self.tables.get_mut(&table_name).unwrap();
376        for row in new_rows {
377            table.rows.push(row);
378        }
379
380        Ok(ExecutionResult::RowsAffected(rows_affected))
381    }
382
383    fn execute_update(
384        &mut self,
385        update: &UpdateStmt,
386        params: &[SochValue],
387    ) -> SqlResult<ExecutionResult> {
388        let table_name = update.table.name().to_string();
389
390        // First, collect table info and rows that need updating
391        let (_table_columns, updates_to_apply) = {
392            let table = self
393                .tables
394                .get(&table_name)
395                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
396
397            let mut updates = Vec::new();
398
399            for row_idx in 0..table.rows.len() {
400                // Build row as HashMap for WHERE evaluation
401                let row_map: HashMap<String, SochValue> = table
402                    .columns
403                    .iter()
404                    .zip(table.rows[row_idx].iter())
405                    .map(|(col, val)| (col.clone(), val.clone()))
406                    .collect();
407
408                // Check WHERE condition
409                let matches = if let Some(where_clause) = &update.where_clause {
410                    self.evaluate_where(where_clause, &row_map, params)?
411                } else {
412                    true
413                };
414
415                if matches {
416                    // Collect the updates for this row
417                    let mut row_updates = Vec::new();
418                    for assignment in &update.assignments {
419                        if let Some(col_idx) =
420                            table.columns.iter().position(|c| c == &assignment.column)
421                        {
422                            let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
423                            row_updates.push((col_idx, value));
424                        }
425                    }
426                    updates.push((row_idx, row_updates));
427                }
428            }
429
430            (table.columns.clone(), updates)
431        };
432
433        let rows_affected = updates_to_apply.len();
434
435        // Now apply the updates
436        let table = self.tables.get_mut(&table_name).unwrap();
437        for (row_idx, row_updates) in updates_to_apply {
438            for (col_idx, value) in row_updates {
439                table.rows[row_idx][col_idx] = value;
440            }
441        }
442
443        Ok(ExecutionResult::RowsAffected(rows_affected))
444    }
445
446    fn execute_delete(
447        &mut self,
448        delete: &DeleteStmt,
449        params: &[SochValue],
450    ) -> SqlResult<ExecutionResult> {
451        let table_name = delete.table.name().to_string();
452
453        // First, determine which rows to delete
454        let indices_to_remove = {
455            let table = self
456                .tables
457                .get(&table_name)
458                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
459
460            let mut indices = Vec::new();
461
462            for (row_idx, row) in table.rows.iter().enumerate() {
463                // Build row as HashMap for WHERE evaluation
464                let row_map: HashMap<String, SochValue> = table
465                    .columns
466                    .iter()
467                    .zip(row.iter())
468                    .map(|(col, val)| (col.clone(), val.clone()))
469                    .collect();
470
471                // Check WHERE condition
472                let matches = if let Some(where_clause) = &delete.where_clause {
473                    self.evaluate_where(where_clause, &row_map, params)?
474                } else {
475                    true
476                };
477
478                if matches {
479                    indices.push(row_idx);
480                }
481            }
482
483            indices
484        };
485
486        let rows_affected = indices_to_remove.len();
487
488        // Now remove the rows
489        let table = self.tables.get_mut(&table_name).unwrap();
490        // Remove in reverse order to preserve indices
491        for idx in indices_to_remove.into_iter().rev() {
492            table.rows.remove(idx);
493        }
494
495        Ok(ExecutionResult::RowsAffected(rows_affected))
496    }
497
498    fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
499        let table_name = create.name.name().to_string();
500
501        if self.tables.contains_key(&table_name) {
502            if create.if_not_exists {
503                return Ok(ExecutionResult::Ok);
504            }
505            return Err(SqlError::ConstraintViolation(format!(
506                "Table '{}' already exists",
507                table_name
508            )));
509        }
510
511        let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
512        let column_types: Vec<DataType> =
513            create.columns.iter().map(|c| c.data_type.clone()).collect();
514
515        self.tables.insert(
516            table_name,
517            TableData {
518                columns,
519                column_types,
520                rows: Vec::new(),
521            },
522        );
523
524        Ok(ExecutionResult::Ok)
525    }
526
527    fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
528        for name in &drop.names {
529            let table_name = name.name().to_string();
530            if self.tables.remove(&table_name).is_none() && !drop.if_exists {
531                return Err(SqlError::TableNotFound(table_name));
532            }
533        }
534
535        Ok(ExecutionResult::Ok)
536    }
537
538    // ========== Expression Evaluation ==========
539
540    fn evaluate_where(
541        &self,
542        expr: &Expr,
543        row: &HashMap<String, SochValue>,
544        params: &[SochValue],
545    ) -> SqlResult<bool> {
546        let value = self.evaluate_expr(expr, row, params)?;
547        match value {
548            SochValue::Bool(b) => Ok(b),
549            SochValue::Null => Ok(false),
550            _ => Err(SqlError::TypeError(
551                "WHERE clause must evaluate to boolean".into(),
552            )),
553        }
554    }
555
556    fn evaluate_expr(
557        &self,
558        expr: &Expr,
559        row: &HashMap<String, SochValue>,
560        params: &[SochValue],
561    ) -> SqlResult<SochValue> {
562        match expr {
563            Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
564
565            Expr::Column(col_ref) => row
566                .get(&col_ref.column)
567                .cloned()
568                .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
569
570            Expr::Placeholder(n) => params
571                .get((*n as usize).saturating_sub(1))
572                .cloned()
573                .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
574
575            Expr::BinaryOp { left, op, right } => {
576                let left_val = self.evaluate_expr(left, row, params)?;
577                let right_val = self.evaluate_expr(right, row, params)?;
578                self.evaluate_binary_op(&left_val, op, &right_val)
579            }
580
581            Expr::UnaryOp { op, expr } => {
582                let val = self.evaluate_expr(expr, row, params)?;
583                self.evaluate_unary_op(op, &val)
584            }
585
586            Expr::IsNull { expr, negated } => {
587                let val = self.evaluate_expr(expr, row, params)?;
588                let is_null = matches!(val, SochValue::Null);
589                Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
590            }
591
592            Expr::InList {
593                expr,
594                list,
595                negated,
596            } => {
597                let val = self.evaluate_expr(expr, row, params)?;
598                let mut found = false;
599                for item in list {
600                    let item_val = self.evaluate_expr(item, row, params)?;
601                    if self.values_equal(&val, &item_val) {
602                        found = true;
603                        break;
604                    }
605                }
606                Ok(SochValue::Bool(if *negated { !found } else { found }))
607            }
608
609            Expr::Between {
610                expr,
611                low,
612                high,
613                negated,
614            } => {
615                let val = self.evaluate_expr(expr, row, params)?;
616                let low_val = self.evaluate_expr(low, row, params)?;
617                let high_val = self.evaluate_expr(high, row, params)?;
618
619                let cmp_low = self.compare_values(Some(&val), Some(&low_val));
620                let cmp_high = self.compare_values(Some(&val), Some(&high_val));
621
622                let in_range =
623                    cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
624
625                Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
626            }
627
628            Expr::Like {
629                expr,
630                pattern,
631                negated,
632                ..
633            } => {
634                let val = self.evaluate_expr(expr, row, params)?;
635                let pattern_val = self.evaluate_expr(pattern, row, params)?;
636
637                match (&val, &pattern_val) {
638                    (SochValue::Text(s), SochValue::Text(p)) => {
639                        let regex_pattern = p.replace('%', ".*").replace('_', ".");
640                        let matches = regex::Regex::new(&format!("^{}$", regex_pattern))
641                            .map(|re| re.is_match(s))
642                            .unwrap_or(false);
643                        Ok(SochValue::Bool(if *negated { !matches } else { matches }))
644                    }
645                    _ => Ok(SochValue::Bool(false)),
646                }
647            }
648
649            Expr::Function(func) => self.evaluate_function(func, row, params),
650
651            Expr::Case {
652                operand,
653                conditions,
654                else_result,
655            } => {
656                if let Some(op) = operand {
657                    // Simple CASE
658                    let op_val = self.evaluate_expr(op, row, params)?;
659                    for (when_expr, then_expr) in conditions {
660                        let when_val = self.evaluate_expr(when_expr, row, params)?;
661                        if self.values_equal(&op_val, &when_val) {
662                            return self.evaluate_expr(then_expr, row, params);
663                        }
664                    }
665                } else {
666                    // Searched CASE
667                    for (when_expr, then_expr) in conditions {
668                        let when_val = self.evaluate_expr(when_expr, row, params)?;
669                        if matches!(when_val, SochValue::Bool(true)) {
670                            return self.evaluate_expr(then_expr, row, params);
671                        }
672                    }
673                }
674
675                if let Some(else_expr) = else_result {
676                    self.evaluate_expr(else_expr, row, params)
677                } else {
678                    Ok(SochValue::Null)
679                }
680            }
681
682            _ => Err(SqlError::NotImplemented(format!(
683                "Expression type {:?} not yet supported",
684                expr
685            ))),
686        }
687    }
688
689    fn literal_to_value(&self, lit: &Literal) -> SochValue {
690        match lit {
691            Literal::Null => SochValue::Null,
692            Literal::Boolean(b) => SochValue::Bool(*b),
693            Literal::Integer(n) => SochValue::Int(*n),
694            Literal::Float(f) => SochValue::Float(*f),
695            Literal::String(s) => SochValue::Text(s.clone()),
696            Literal::Blob(b) => SochValue::Binary(b.clone()),
697        }
698    }
699
700    fn evaluate_binary_op(
701        &self,
702        left: &SochValue,
703        op: &BinaryOperator,
704        right: &SochValue,
705    ) -> SqlResult<SochValue> {
706        match op {
707            BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
708            BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
709            BinaryOperator::Lt => Ok(SochValue::Bool(
710                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
711            )),
712            BinaryOperator::Le => Ok(SochValue::Bool(
713                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
714            )),
715            BinaryOperator::Gt => Ok(SochValue::Bool(
716                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
717            )),
718            BinaryOperator::Ge => Ok(SochValue::Bool(
719                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
720            )),
721
722            BinaryOperator::And => match (left, right) {
723                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
724                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
725                _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
726            },
727
728            BinaryOperator::Or => match (left, right) {
729                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
730                (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
731                    Ok(SochValue::Bool(true))
732                }
733                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
734                _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
735            },
736
737            BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
738            BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
739            BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
740            BinaryOperator::Divide => self.arithmetic_op(
741                left,
742                right,
743                |a, b| if b != 0 { a / b } else { 0 },
744                |a, b| a / b,
745            ),
746            BinaryOperator::Modulo => self.arithmetic_op(
747                left,
748                right,
749                |a, b| if b != 0 { a % b } else { 0 },
750                |a, b| a % b,
751            ),
752
753            BinaryOperator::Concat => match (left, right) {
754                (SochValue::Text(l), SochValue::Text(r)) => {
755                    Ok(SochValue::Text(format!("{}{}", l, r)))
756                }
757                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
758                _ => Err(SqlError::TypeError("|| requires string operands".into())),
759            },
760
761            _ => Err(SqlError::NotImplemented(format!(
762                "Operator {:?} not implemented",
763                op
764            ))),
765        }
766    }
767
768    fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
769        match op {
770            UnaryOperator::Not => match val {
771                SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
772                SochValue::Null => Ok(SochValue::Null),
773                _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
774            },
775            UnaryOperator::Minus => match val {
776                SochValue::Int(n) => Ok(SochValue::Int(-n)),
777                SochValue::Float(f) => Ok(SochValue::Float(-f)),
778                SochValue::Null => Ok(SochValue::Null),
779                _ => Err(SqlError::TypeError(
780                    "Unary minus requires numeric operand".into(),
781                )),
782            },
783            UnaryOperator::Plus => Ok(val.clone()),
784            UnaryOperator::BitNot => match val {
785                SochValue::Int(n) => Ok(SochValue::Int(!n)),
786                _ => Err(SqlError::TypeError("~ requires integer operand".into())),
787            },
788        }
789    }
790
791    fn evaluate_function(
792        &self,
793        func: &FunctionCall,
794        row: &HashMap<String, SochValue>,
795        params: &[SochValue],
796    ) -> SqlResult<SochValue> {
797        let func_name = func.name.name().to_uppercase();
798
799        match func_name.as_str() {
800            "COALESCE" => {
801                for arg in &func.args {
802                    let val = self.evaluate_expr(arg, row, params)?;
803                    if !matches!(val, SochValue::Null) {
804                        return Ok(val);
805                    }
806                }
807                Ok(SochValue::Null)
808            }
809
810            "NULLIF" => {
811                if func.args.len() != 2 {
812                    return Err(SqlError::InvalidArgument(
813                        "NULLIF requires 2 arguments".into(),
814                    ));
815                }
816                let val1 = self.evaluate_expr(&func.args[0], row, params)?;
817                let val2 = self.evaluate_expr(&func.args[1], row, params)?;
818                if self.values_equal(&val1, &val2) {
819                    Ok(SochValue::Null)
820                } else {
821                    Ok(val1)
822                }
823            }
824
825            "ABS" => {
826                if func.args.len() != 1 {
827                    return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
828                }
829                let val = self.evaluate_expr(&func.args[0], row, params)?;
830                match val {
831                    SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
832                    SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
833                    SochValue::Null => Ok(SochValue::Null),
834                    _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
835                }
836            }
837
838            "LENGTH" | "LEN" => {
839                if func.args.len() != 1 {
840                    return Err(SqlError::InvalidArgument(
841                        "LENGTH requires 1 argument".into(),
842                    ));
843                }
844                let val = self.evaluate_expr(&func.args[0], row, params)?;
845                match val {
846                    SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
847                    SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
848                    SochValue::Null => Ok(SochValue::Null),
849                    _ => Err(SqlError::TypeError(
850                        "LENGTH requires string argument".into(),
851                    )),
852                }
853            }
854
855            "UPPER" => {
856                if func.args.len() != 1 {
857                    return Err(SqlError::InvalidArgument(
858                        "UPPER requires 1 argument".into(),
859                    ));
860                }
861                let val = self.evaluate_expr(&func.args[0], row, params)?;
862                match val {
863                    SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
864                    SochValue::Null => Ok(SochValue::Null),
865                    _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
866                }
867            }
868
869            "LOWER" => {
870                if func.args.len() != 1 {
871                    return Err(SqlError::InvalidArgument(
872                        "LOWER requires 1 argument".into(),
873                    ));
874                }
875                let val = self.evaluate_expr(&func.args[0], row, params)?;
876                match val {
877                    SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
878                    SochValue::Null => Ok(SochValue::Null),
879                    _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
880                }
881            }
882
883            "TRIM" => {
884                if func.args.len() != 1 {
885                    return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
886                }
887                let val = self.evaluate_expr(&func.args[0], row, params)?;
888                match val {
889                    SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
890                    SochValue::Null => Ok(SochValue::Null),
891                    _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
892                }
893            }
894
895            "SUBSTR" | "SUBSTRING" => {
896                if func.args.len() < 2 || func.args.len() > 3 {
897                    return Err(SqlError::InvalidArgument(
898                        "SUBSTR requires 2 or 3 arguments".into(),
899                    ));
900                }
901                let val = self.evaluate_expr(&func.args[0], row, params)?;
902                let start = self.evaluate_expr(&func.args[1], row, params)?;
903                let len = if func.args.len() == 3 {
904                    Some(self.evaluate_expr(&func.args[2], row, params)?)
905                } else {
906                    None
907                };
908
909                match (val, start) {
910                    (SochValue::Text(s), SochValue::Int(start)) => {
911                        let start = (start.max(1) - 1) as usize;
912                        if start >= s.len() {
913                            return Ok(SochValue::Text(String::new()));
914                        }
915                        let result = if let Some(SochValue::Int(len)) = len {
916                            s.chars().skip(start).take(len as usize).collect()
917                        } else {
918                            s.chars().skip(start).collect()
919                        };
920                        Ok(SochValue::Text(result))
921                    }
922                    (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
923                    _ => Err(SqlError::TypeError(
924                        "SUBSTR requires string and integer arguments".into(),
925                    )),
926                }
927            }
928
929            _ => Err(SqlError::NotImplemented(format!(
930                "Function {} not implemented",
931                func_name
932            ))),
933        }
934    }
935
936    // ========== Helper Methods ==========
937
938    fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
939        match (left, right) {
940            (SochValue::Null, _) | (_, SochValue::Null) => false,
941            (SochValue::Int(l), SochValue::Int(r)) => l == r,
942            (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
943            (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
944            (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
945            (SochValue::Text(l), SochValue::Text(r)) => l == r,
946            (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
947            (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
948            (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
949            (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
950            (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
951            _ => false,
952        }
953    }
954
955    fn compare_values(
956        &self,
957        left: Option<&SochValue>,
958        right: Option<&SochValue>,
959    ) -> std::cmp::Ordering {
960        match (left, right) {
961            (None, None) => std::cmp::Ordering::Equal,
962            (None, _) => std::cmp::Ordering::Less,
963            (_, None) => std::cmp::Ordering::Greater,
964            (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
965            (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
966            (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
967                l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
968            }
969            (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
970                .partial_cmp(r)
971                .unwrap_or(std::cmp::Ordering::Equal),
972            (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
973                .partial_cmp(&(*r as f64))
974                .unwrap_or(std::cmp::Ordering::Equal),
975            (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
976            (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
977            _ => std::cmp::Ordering::Equal,
978        }
979    }
980
981    fn arithmetic_op<FI, FF>(
982        &self,
983        left: &SochValue,
984        right: &SochValue,
985        int_op: FI,
986        float_op: FF,
987    ) -> SqlResult<SochValue>
988    where
989        FI: Fn(i64, i64) -> i64,
990        FF: Fn(f64, f64) -> f64,
991    {
992        match (left, right) {
993            (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
994            (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
995            (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
996            (SochValue::Int(l), SochValue::Float(r)) => {
997                Ok(SochValue::Float(float_op(*l as f64, *r)))
998            }
999            (SochValue::Float(l), SochValue::Int(r)) => {
1000                Ok(SochValue::Float(float_op(*l, *r as f64)))
1001            }
1002            (SochValue::UInt(l), SochValue::UInt(r)) => {
1003                Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1004            }
1005            (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1006            (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1007            _ => Err(SqlError::TypeError(
1008                "Arithmetic requires numeric operands".into(),
1009            )),
1010        }
1011    }
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh executor and runs each statement in order, failing on any error.
    fn executor_after(statements: &[&str]) -> SqlExecutor {
        let mut executor = SqlExecutor::new();
        for &sql in statements {
            executor.execute(sql).unwrap();
        }
        executor
    }

    /// Fresh executor with `nums (n INTEGER)` holding the values `1..=upper`.
    fn nums_table(upper: i32) -> SqlExecutor {
        let mut executor = executor_after(&["CREATE TABLE nums (n INTEGER)"]);
        for value in 1..=upper {
            executor
                .execute(&format!("INSERT INTO nums (n) VALUES ({})", value))
                .unwrap();
        }
        executor
    }

    /// Extracts the row set from a query result, panicking for any other result kind.
    fn rows_of(result: ExecutionResult) -> Vec<HashMap<String, SochValue>> {
        match result {
            ExecutionResult::Rows { rows, .. } => rows,
            _ => panic!("Expected rows"),
        }
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut executor = SqlExecutor::new();

        let created = executor
            .execute("CREATE TABLE users (id INTEGER, name VARCHAR(100))")
            .unwrap();
        assert!(matches!(created, ExecutionResult::Ok));

        for insert in [
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ] {
            let inserted = executor.execute(insert).unwrap();
            assert_eq!(inserted.rows_affected(), 1);
        }

        let selected = executor.execute("SELECT * FROM users").unwrap();
        assert_eq!(selected.rows_affected(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let mut executor = executor_after(&[
            "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)",
            "INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)",
            "INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)",
            "INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)",
        ]);

        let pricey = executor
            .execute("SELECT * FROM products WHERE price > 1.0")
            .unwrap();
        assert_eq!(pricey.rows_affected(), 2);
    }

    #[test]
    fn test_update() {
        let mut executor = executor_after(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
        ]);

        let updated = executor
            .execute("UPDATE users SET name = 'Alicia' WHERE id = 1")
            .unwrap();
        assert_eq!(updated.rows_affected(), 1);

        let renamed = executor
            .execute("SELECT * FROM users WHERE name = 'Alicia'")
            .unwrap();
        assert_eq!(renamed.rows_affected(), 1);
    }

    #[test]
    fn test_delete() {
        let mut executor = executor_after(&[
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ]);

        let deleted = executor.execute("DELETE FROM users WHERE id = 1").unwrap();
        assert_eq!(deleted.rows_affected(), 1);

        let remaining = executor.execute("SELECT * FROM users").unwrap();
        assert_eq!(remaining.rows_affected(), 1);
    }

    #[test]
    fn test_functions() {
        let mut executor = executor_after(&[
            "CREATE TABLE t (s TEXT)",
            "INSERT INTO t (s) VALUES ('hello')",
        ]);

        let rows = rows_of(executor.execute("SELECT UPPER(s) FROM t").unwrap());
        let first = &rows[0];
        // The projected column name is engine-chosen (e.g. UPPER(s)), so scan every value.
        let has_upper = first
            .values()
            .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"));
        assert!(has_upper);
    }

    #[test]
    fn test_order_by() {
        let mut executor = executor_after(&[
            "CREATE TABLE nums (n INTEGER)",
            "INSERT INTO nums (n) VALUES (3)",
            "INSERT INTO nums (n) VALUES (1)",
            "INSERT INTO nums (n) VALUES (2)",
        ]);

        let rows = rows_of(
            executor
                .execute("SELECT * FROM nums ORDER BY n ASC")
                .unwrap(),
        );
        let ordered: Vec<i64> = rows
            .iter()
            .filter_map(|row| match row.get("n") {
                Some(SochValue::Int(n)) => Some(*n),
                _ => None,
            })
            .collect();
        assert_eq!(ordered, vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut executor = nums_table(10);

        let page = executor
            .execute("SELECT * FROM nums LIMIT 3 OFFSET 2")
            .unwrap();
        assert_eq!(page.rows_affected(), 3);
    }

    #[test]
    fn test_between() {
        let mut executor = nums_table(10);

        let middle = executor
            .execute("SELECT * FROM nums WHERE n BETWEEN 3 AND 7")
            .unwrap();
        assert_eq!(middle.rows_affected(), 5);
    }

    #[test]
    fn test_in_list() {
        let mut executor = nums_table(5);

        let odd = executor
            .execute("SELECT * FROM nums WHERE n IN (1, 3, 5)")
            .unwrap();
        assert_eq!(odd.rows_affected(), 3);
    }
}