sochdb_query/sql/
mod.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! SQL Module - Complete SQL support for SochDB
16//!
17//! Provides SQL parsing, planning, and execution.
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use sochdb_query::sql::{Parser, Statement};
23//!
24//! let stmt = Parser::parse("SELECT * FROM users WHERE id = 1")?;
25//! ```
26
27pub mod ast;
28pub mod bridge;
29pub mod compatibility;
30pub mod error;
31pub mod lexer;
32pub mod parser;
33pub mod token;
34
35pub use ast::*;
36pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
37pub use compatibility::{CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support};
38pub use error::{SqlError, SqlResult};
39pub use lexer::{LexError, Lexer};
40pub use parser::{ParseError, Parser};
41pub use token::{Span, Token, TokenKind};
42
43use std::collections::HashMap;
44use sochdb_core::SochValue;
45
46/// Result of SQL execution
47#[derive(Debug, Clone)]
48pub enum ExecutionResult {
49    /// Query returned rows
50    Rows {
51        columns: Vec<String>,
52        rows: Vec<HashMap<String, SochValue>>,
53    },
54    /// Statement affected N rows
55    RowsAffected(usize),
56    /// Statement completed successfully
57    Ok,
58}
59
60impl ExecutionResult {
61    /// Get rows if this is a query result
62    pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
63        match self {
64            ExecutionResult::Rows { rows, .. } => Some(rows),
65            _ => None,
66        }
67    }
68
69    /// Get column names if this is a query result
70    pub fn columns(&self) -> Option<&Vec<String>> {
71        match self {
72            ExecutionResult::Rows { columns, .. } => Some(columns),
73            _ => None,
74        }
75    }
76
77    /// Get affected row count
78    pub fn rows_affected(&self) -> usize {
79        match self {
80            ExecutionResult::RowsAffected(n) => *n,
81            ExecutionResult::Rows { rows, .. } => rows.len(),
82            ExecutionResult::Ok => 0,
83        }
84    }
85}
86
87/// Simple SQL executor for standalone use
88///
89/// For full database integration, use the SqlConnection in sochdb-storage
90pub struct SqlExecutor {
91    /// Tables stored in memory (for testing/standalone use)
92    tables: HashMap<String, TableData>,
93}
94
95/// In-memory table data
96#[derive(Debug, Clone)]
97pub struct TableData {
98    pub columns: Vec<String>,
99    pub column_types: Vec<DataType>,
100    pub rows: Vec<Vec<SochValue>>,
101}
102
103impl Default for SqlExecutor {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl SqlExecutor {
110    /// Create a new SQL executor
111    pub fn new() -> Self {
112        Self {
113            tables: HashMap::new(),
114        }
115    }
116
117    /// Execute a SQL statement
118    pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
119        self.execute_with_params(sql, &[])
120    }
121
122    /// Execute a SQL statement with parameters
123    pub fn execute_with_params(
124        &mut self,
125        sql: &str,
126        params: &[SochValue],
127    ) -> SqlResult<ExecutionResult> {
128        let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
129        self.execute_statement(&stmt, params)
130    }
131
132    /// Execute a parsed statement
133    pub fn execute_statement(
134        &mut self,
135        stmt: &Statement,
136        params: &[SochValue],
137    ) -> SqlResult<ExecutionResult> {
138        match stmt {
139            Statement::Select(select) => self.execute_select(select, params),
140            Statement::Insert(insert) => self.execute_insert(insert, params),
141            Statement::Update(update) => self.execute_update(update, params),
142            Statement::Delete(delete) => self.execute_delete(delete, params),
143            Statement::CreateTable(create) => self.execute_create_table(create),
144            Statement::DropTable(drop) => self.execute_drop_table(drop),
145            Statement::Begin(_) => Ok(ExecutionResult::Ok),
146            Statement::Commit => Ok(ExecutionResult::Ok),
147            Statement::Rollback(_) => Ok(ExecutionResult::Ok),
148            _ => Err(SqlError::NotImplemented(
149                "Statement type not yet supported".into(),
150            )),
151        }
152    }
153
154    fn execute_select(
155        &self,
156        select: &SelectStmt,
157        params: &[SochValue],
158    ) -> SqlResult<ExecutionResult> {
159        // Get the table
160        let from = select
161            .from
162            .as_ref()
163            .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
164
165        if from.tables.len() != 1 {
166            return Err(SqlError::NotImplemented(
167                "Multi-table queries not yet supported".into(),
168            ));
169        }
170
171        let table_name = match &from.tables[0] {
172            TableRef::Table { name, .. } => name.name().to_string(),
173            _ => {
174                return Err(SqlError::NotImplemented(
175                    "Complex table references not yet supported".into(),
176                ));
177            }
178        };
179
180        let table = self
181            .tables
182            .get(&table_name)
183            .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
184
185        // Collect matching source rows
186        let mut source_rows = Vec::new();
187
188        for row in &table.rows {
189            // Build row as HashMap
190            let row_map: HashMap<String, SochValue> = table
191                .columns
192                .iter()
193                .zip(row.iter())
194                .map(|(col, val)| (col.clone(), val.clone()))
195                .collect();
196
197            // Apply WHERE filter
198            if let Some(where_clause) = &select.where_clause
199                && !self.evaluate_where(where_clause, &row_map, params)?
200            {
201                continue;
202            }
203
204            source_rows.push(row_map);
205        }
206
207        // Apply ORDER BY
208        if !select.order_by.is_empty() {
209            source_rows.sort_by(|a, b| {
210                for order_item in &select.order_by {
211                    if let Expr::Column(col_ref) = &order_item.expr {
212                        let a_val = a.get(&col_ref.column);
213                        let b_val = b.get(&col_ref.column);
214
215                        let cmp = self.compare_values(a_val, b_val);
216                        if cmp != std::cmp::Ordering::Equal {
217                            return if order_item.asc { cmp } else { cmp.reverse() };
218                        }
219                    }
220                }
221                std::cmp::Ordering::Equal
222            });
223        }
224
225        // Apply OFFSET
226        if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
227            let n = *n as usize;
228            if n < source_rows.len() {
229                source_rows = source_rows.into_iter().skip(n).collect();
230            } else {
231                source_rows.clear();
232            }
233        }
234
235        // Apply LIMIT
236        if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
237            source_rows.truncate(*n as usize);
238        }
239
240        // Determine output columns and evaluate SELECT expressions
241        let mut output_columns: Vec<String> = Vec::new();
242        let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
243
244        // Check for SELECT *
245        let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
246
247        if is_wildcard {
248            output_columns = table.columns.clone();
249            result_rows = source_rows;
250        } else {
251            // Determine column names first
252            for item in &select.columns {
253                match item {
254                    SelectItem::Wildcard => output_columns.push("*".to_string()),
255                    SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
256                    SelectItem::Expr { expr, alias } => {
257                        let col_name = alias.clone().unwrap_or_else(|| match expr {
258                            Expr::Column(col) => col.column.clone(),
259                            Expr::Function(func) => format!("{}()", func.name.name()),
260                            _ => "?column?".to_string(),
261                        });
262                        output_columns.push(col_name);
263                    }
264                }
265            }
266
267            // Now evaluate expressions for each row
268            for source_row in &source_rows {
269                let mut result_row = HashMap::new();
270
271                for (idx, item) in select.columns.iter().enumerate() {
272                    let col_name = &output_columns[idx];
273
274                    match item {
275                        SelectItem::Wildcard => {
276                            // Add all columns from source row
277                            for (k, v) in source_row {
278                                result_row.insert(k.clone(), v.clone());
279                            }
280                        }
281                        SelectItem::QualifiedWildcard(_) => {
282                            // For now, treat same as wildcard
283                            for (k, v) in source_row {
284                                result_row.insert(k.clone(), v.clone());
285                            }
286                        }
287                        SelectItem::Expr { expr, .. } => {
288                            let value = self.evaluate_expr(expr, source_row, params)?;
289                            result_row.insert(col_name.clone(), value);
290                        }
291                    }
292                }
293
294                result_rows.push(result_row);
295            }
296        }
297
298        Ok(ExecutionResult::Rows {
299            columns: output_columns,
300            rows: result_rows,
301        })
302    }
303
304    fn execute_insert(
305        &mut self,
306        insert: &InsertStmt,
307        params: &[SochValue],
308    ) -> SqlResult<ExecutionResult> {
309        let table_name = insert.table.name().to_string();
310
311        // First check if table exists and get column info
312        let table_columns = {
313            let table = self
314                .tables
315                .get(&table_name)
316                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
317            table.columns.clone()
318        };
319
320        let mut rows_affected = 0;
321        let mut new_rows = Vec::new();
322
323        match &insert.source {
324            InsertSource::Values(rows) => {
325                for value_exprs in rows {
326                    let mut row_values = Vec::new();
327
328                    if let Some(columns) = &insert.columns {
329                        if columns.len() != value_exprs.len() {
330                            return Err(SqlError::InvalidArgument(format!(
331                                "Column count ({}) doesn't match value count ({})",
332                                columns.len(),
333                                value_exprs.len()
334                            )));
335                        }
336
337                        // Map columns to table column order
338                        for table_col in &table_columns {
339                            if let Some(pos) = columns.iter().position(|c| c == table_col) {
340                                let value =
341                                    self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
342                                row_values.push(value);
343                            } else {
344                                row_values.push(SochValue::Null);
345                            }
346                        }
347                    } else {
348                        // Insert in column order
349                        for expr in value_exprs {
350                            let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
351                            row_values.push(value);
352                        }
353                    }
354
355                    new_rows.push(row_values);
356                    rows_affected += 1;
357                }
358            }
359            InsertSource::Query(_) => {
360                return Err(SqlError::NotImplemented(
361                    "INSERT ... SELECT not yet supported".into(),
362                ));
363            }
364            InsertSource::Default => {
365                return Err(SqlError::NotImplemented(
366                    "INSERT DEFAULT VALUES not yet supported".into(),
367                ));
368            }
369        }
370
371        // Now add the rows to the table
372        let table = self.tables.get_mut(&table_name).unwrap();
373        for row in new_rows {
374            table.rows.push(row);
375        }
376
377        Ok(ExecutionResult::RowsAffected(rows_affected))
378    }
379
380    fn execute_update(
381        &mut self,
382        update: &UpdateStmt,
383        params: &[SochValue],
384    ) -> SqlResult<ExecutionResult> {
385        let table_name = update.table.name().to_string();
386
387        // First, collect table info and rows that need updating
388        let (_table_columns, updates_to_apply) = {
389            let table = self
390                .tables
391                .get(&table_name)
392                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
393
394            let mut updates = Vec::new();
395
396            for row_idx in 0..table.rows.len() {
397                // Build row as HashMap for WHERE evaluation
398                let row_map: HashMap<String, SochValue> = table
399                    .columns
400                    .iter()
401                    .zip(table.rows[row_idx].iter())
402                    .map(|(col, val)| (col.clone(), val.clone()))
403                    .collect();
404
405                // Check WHERE condition
406                let matches = if let Some(where_clause) = &update.where_clause {
407                    self.evaluate_where(where_clause, &row_map, params)?
408                } else {
409                    true
410                };
411
412                if matches {
413                    // Collect the updates for this row
414                    let mut row_updates = Vec::new();
415                    for assignment in &update.assignments {
416                        if let Some(col_idx) =
417                            table.columns.iter().position(|c| c == &assignment.column)
418                        {
419                            let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
420                            row_updates.push((col_idx, value));
421                        }
422                    }
423                    updates.push((row_idx, row_updates));
424                }
425            }
426
427            (table.columns.clone(), updates)
428        };
429
430        let rows_affected = updates_to_apply.len();
431
432        // Now apply the updates
433        let table = self.tables.get_mut(&table_name).unwrap();
434        for (row_idx, row_updates) in updates_to_apply {
435            for (col_idx, value) in row_updates {
436                table.rows[row_idx][col_idx] = value;
437            }
438        }
439
440        Ok(ExecutionResult::RowsAffected(rows_affected))
441    }
442
443    fn execute_delete(
444        &mut self,
445        delete: &DeleteStmt,
446        params: &[SochValue],
447    ) -> SqlResult<ExecutionResult> {
448        let table_name = delete.table.name().to_string();
449
450        // First, determine which rows to delete
451        let indices_to_remove = {
452            let table = self
453                .tables
454                .get(&table_name)
455                .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
456
457            let mut indices = Vec::new();
458
459            for (row_idx, row) in table.rows.iter().enumerate() {
460                // Build row as HashMap for WHERE evaluation
461                let row_map: HashMap<String, SochValue> = table
462                    .columns
463                    .iter()
464                    .zip(row.iter())
465                    .map(|(col, val)| (col.clone(), val.clone()))
466                    .collect();
467
468                // Check WHERE condition
469                let matches = if let Some(where_clause) = &delete.where_clause {
470                    self.evaluate_where(where_clause, &row_map, params)?
471                } else {
472                    true
473                };
474
475                if matches {
476                    indices.push(row_idx);
477                }
478            }
479
480            indices
481        };
482
483        let rows_affected = indices_to_remove.len();
484
485        // Now remove the rows
486        let table = self.tables.get_mut(&table_name).unwrap();
487        // Remove in reverse order to preserve indices
488        for idx in indices_to_remove.into_iter().rev() {
489            table.rows.remove(idx);
490        }
491
492        Ok(ExecutionResult::RowsAffected(rows_affected))
493    }
494
495    fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
496        let table_name = create.name.name().to_string();
497
498        if self.tables.contains_key(&table_name) {
499            if create.if_not_exists {
500                return Ok(ExecutionResult::Ok);
501            }
502            return Err(SqlError::ConstraintViolation(format!(
503                "Table '{}' already exists",
504                table_name
505            )));
506        }
507
508        let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
509        let column_types: Vec<DataType> =
510            create.columns.iter().map(|c| c.data_type.clone()).collect();
511
512        self.tables.insert(
513            table_name,
514            TableData {
515                columns,
516                column_types,
517                rows: Vec::new(),
518            },
519        );
520
521        Ok(ExecutionResult::Ok)
522    }
523
524    fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
525        for name in &drop.names {
526            let table_name = name.name().to_string();
527            if self.tables.remove(&table_name).is_none() && !drop.if_exists {
528                return Err(SqlError::TableNotFound(table_name));
529            }
530        }
531
532        Ok(ExecutionResult::Ok)
533    }
534
535    // ========== Expression Evaluation ==========
536
537    fn evaluate_where(
538        &self,
539        expr: &Expr,
540        row: &HashMap<String, SochValue>,
541        params: &[SochValue],
542    ) -> SqlResult<bool> {
543        let value = self.evaluate_expr(expr, row, params)?;
544        match value {
545            SochValue::Bool(b) => Ok(b),
546            SochValue::Null => Ok(false),
547            _ => Err(SqlError::TypeError(
548                "WHERE clause must evaluate to boolean".into(),
549            )),
550        }
551    }
552
553    fn evaluate_expr(
554        &self,
555        expr: &Expr,
556        row: &HashMap<String, SochValue>,
557        params: &[SochValue],
558    ) -> SqlResult<SochValue> {
559        match expr {
560            Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
561
562            Expr::Column(col_ref) => row
563                .get(&col_ref.column)
564                .cloned()
565                .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
566
567            Expr::Placeholder(n) => params
568                .get((*n as usize).saturating_sub(1))
569                .cloned()
570                .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
571
572            Expr::BinaryOp { left, op, right } => {
573                let left_val = self.evaluate_expr(left, row, params)?;
574                let right_val = self.evaluate_expr(right, row, params)?;
575                self.evaluate_binary_op(&left_val, op, &right_val)
576            }
577
578            Expr::UnaryOp { op, expr } => {
579                let val = self.evaluate_expr(expr, row, params)?;
580                self.evaluate_unary_op(op, &val)
581            }
582
583            Expr::IsNull { expr, negated } => {
584                let val = self.evaluate_expr(expr, row, params)?;
585                let is_null = matches!(val, SochValue::Null);
586                Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
587            }
588
589            Expr::InList {
590                expr,
591                list,
592                negated,
593            } => {
594                let val = self.evaluate_expr(expr, row, params)?;
595                let mut found = false;
596                for item in list {
597                    let item_val = self.evaluate_expr(item, row, params)?;
598                    if self.values_equal(&val, &item_val) {
599                        found = true;
600                        break;
601                    }
602                }
603                Ok(SochValue::Bool(if *negated { !found } else { found }))
604            }
605
606            Expr::Between {
607                expr,
608                low,
609                high,
610                negated,
611            } => {
612                let val = self.evaluate_expr(expr, row, params)?;
613                let low_val = self.evaluate_expr(low, row, params)?;
614                let high_val = self.evaluate_expr(high, row, params)?;
615
616                let cmp_low = self.compare_values(Some(&val), Some(&low_val));
617                let cmp_high = self.compare_values(Some(&val), Some(&high_val));
618
619                let in_range =
620                    cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
621
622                Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
623            }
624
625            Expr::Like {
626                expr,
627                pattern,
628                negated,
629                ..
630            } => {
631                let val = self.evaluate_expr(expr, row, params)?;
632                let pattern_val = self.evaluate_expr(pattern, row, params)?;
633
634                match (&val, &pattern_val) {
635                    (SochValue::Text(s), SochValue::Text(p)) => {
636                        let regex_pattern = p.replace('%', ".*").replace('_', ".");
637                        let matches = regex::Regex::new(&format!("^{}$", regex_pattern))
638                            .map(|re| re.is_match(s))
639                            .unwrap_or(false);
640                        Ok(SochValue::Bool(if *negated { !matches } else { matches }))
641                    }
642                    _ => Ok(SochValue::Bool(false)),
643                }
644            }
645
646            Expr::Function(func) => self.evaluate_function(func, row, params),
647
648            Expr::Case {
649                operand,
650                conditions,
651                else_result,
652            } => {
653                if let Some(op) = operand {
654                    // Simple CASE
655                    let op_val = self.evaluate_expr(op, row, params)?;
656                    for (when_expr, then_expr) in conditions {
657                        let when_val = self.evaluate_expr(when_expr, row, params)?;
658                        if self.values_equal(&op_val, &when_val) {
659                            return self.evaluate_expr(then_expr, row, params);
660                        }
661                    }
662                } else {
663                    // Searched CASE
664                    for (when_expr, then_expr) in conditions {
665                        let when_val = self.evaluate_expr(when_expr, row, params)?;
666                        if matches!(when_val, SochValue::Bool(true)) {
667                            return self.evaluate_expr(then_expr, row, params);
668                        }
669                    }
670                }
671
672                if let Some(else_expr) = else_result {
673                    self.evaluate_expr(else_expr, row, params)
674                } else {
675                    Ok(SochValue::Null)
676                }
677            }
678
679            _ => Err(SqlError::NotImplemented(format!(
680                "Expression type {:?} not yet supported",
681                expr
682            ))),
683        }
684    }
685
686    fn literal_to_value(&self, lit: &Literal) -> SochValue {
687        match lit {
688            Literal::Null => SochValue::Null,
689            Literal::Boolean(b) => SochValue::Bool(*b),
690            Literal::Integer(n) => SochValue::Int(*n),
691            Literal::Float(f) => SochValue::Float(*f),
692            Literal::String(s) => SochValue::Text(s.clone()),
693            Literal::Blob(b) => SochValue::Binary(b.clone()),
694        }
695    }
696
697    fn evaluate_binary_op(
698        &self,
699        left: &SochValue,
700        op: &BinaryOperator,
701        right: &SochValue,
702    ) -> SqlResult<SochValue> {
703        match op {
704            BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
705            BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
706            BinaryOperator::Lt => Ok(SochValue::Bool(
707                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
708            )),
709            BinaryOperator::Le => Ok(SochValue::Bool(
710                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
711            )),
712            BinaryOperator::Gt => Ok(SochValue::Bool(
713                self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
714            )),
715            BinaryOperator::Ge => Ok(SochValue::Bool(
716                self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
717            )),
718
719            BinaryOperator::And => match (left, right) {
720                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
721                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
722                _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
723            },
724
725            BinaryOperator::Or => match (left, right) {
726                (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
727                (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
728                    Ok(SochValue::Bool(true))
729                }
730                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
731                _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
732            },
733
734            BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
735            BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
736            BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
737            BinaryOperator::Divide => self.arithmetic_op(
738                left,
739                right,
740                |a, b| if b != 0 { a / b } else { 0 },
741                |a, b| a / b,
742            ),
743            BinaryOperator::Modulo => self.arithmetic_op(
744                left,
745                right,
746                |a, b| if b != 0 { a % b } else { 0 },
747                |a, b| a % b,
748            ),
749
750            BinaryOperator::Concat => match (left, right) {
751                (SochValue::Text(l), SochValue::Text(r)) => {
752                    Ok(SochValue::Text(format!("{}{}", l, r)))
753                }
754                (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
755                _ => Err(SqlError::TypeError("|| requires string operands".into())),
756            },
757
758            _ => Err(SqlError::NotImplemented(format!(
759                "Operator {:?} not implemented",
760                op
761            ))),
762        }
763    }
764
765    fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
766        match op {
767            UnaryOperator::Not => match val {
768                SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
769                SochValue::Null => Ok(SochValue::Null),
770                _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
771            },
772            UnaryOperator::Minus => match val {
773                SochValue::Int(n) => Ok(SochValue::Int(-n)),
774                SochValue::Float(f) => Ok(SochValue::Float(-f)),
775                SochValue::Null => Ok(SochValue::Null),
776                _ => Err(SqlError::TypeError(
777                    "Unary minus requires numeric operand".into(),
778                )),
779            },
780            UnaryOperator::Plus => Ok(val.clone()),
781            UnaryOperator::BitNot => match val {
782                SochValue::Int(n) => Ok(SochValue::Int(!n)),
783                _ => Err(SqlError::TypeError("~ requires integer operand".into())),
784            },
785        }
786    }
787
788    fn evaluate_function(
789        &self,
790        func: &FunctionCall,
791        row: &HashMap<String, SochValue>,
792        params: &[SochValue],
793    ) -> SqlResult<SochValue> {
794        let func_name = func.name.name().to_uppercase();
795
796        match func_name.as_str() {
797            "COALESCE" => {
798                for arg in &func.args {
799                    let val = self.evaluate_expr(arg, row, params)?;
800                    if !matches!(val, SochValue::Null) {
801                        return Ok(val);
802                    }
803                }
804                Ok(SochValue::Null)
805            }
806
807            "NULLIF" => {
808                if func.args.len() != 2 {
809                    return Err(SqlError::InvalidArgument(
810                        "NULLIF requires 2 arguments".into(),
811                    ));
812                }
813                let val1 = self.evaluate_expr(&func.args[0], row, params)?;
814                let val2 = self.evaluate_expr(&func.args[1], row, params)?;
815                if self.values_equal(&val1, &val2) {
816                    Ok(SochValue::Null)
817                } else {
818                    Ok(val1)
819                }
820            }
821
822            "ABS" => {
823                if func.args.len() != 1 {
824                    return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
825                }
826                let val = self.evaluate_expr(&func.args[0], row, params)?;
827                match val {
828                    SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
829                    SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
830                    SochValue::Null => Ok(SochValue::Null),
831                    _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
832                }
833            }
834
835            "LENGTH" | "LEN" => {
836                if func.args.len() != 1 {
837                    return Err(SqlError::InvalidArgument(
838                        "LENGTH requires 1 argument".into(),
839                    ));
840                }
841                let val = self.evaluate_expr(&func.args[0], row, params)?;
842                match val {
843                    SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
844                    SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
845                    SochValue::Null => Ok(SochValue::Null),
846                    _ => Err(SqlError::TypeError(
847                        "LENGTH requires string argument".into(),
848                    )),
849                }
850            }
851
852            "UPPER" => {
853                if func.args.len() != 1 {
854                    return Err(SqlError::InvalidArgument(
855                        "UPPER requires 1 argument".into(),
856                    ));
857                }
858                let val = self.evaluate_expr(&func.args[0], row, params)?;
859                match val {
860                    SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
861                    SochValue::Null => Ok(SochValue::Null),
862                    _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
863                }
864            }
865
866            "LOWER" => {
867                if func.args.len() != 1 {
868                    return Err(SqlError::InvalidArgument(
869                        "LOWER requires 1 argument".into(),
870                    ));
871                }
872                let val = self.evaluate_expr(&func.args[0], row, params)?;
873                match val {
874                    SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
875                    SochValue::Null => Ok(SochValue::Null),
876                    _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
877                }
878            }
879
880            "TRIM" => {
881                if func.args.len() != 1 {
882                    return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
883                }
884                let val = self.evaluate_expr(&func.args[0], row, params)?;
885                match val {
886                    SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
887                    SochValue::Null => Ok(SochValue::Null),
888                    _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
889                }
890            }
891
892            "SUBSTR" | "SUBSTRING" => {
893                if func.args.len() < 2 || func.args.len() > 3 {
894                    return Err(SqlError::InvalidArgument(
895                        "SUBSTR requires 2 or 3 arguments".into(),
896                    ));
897                }
898                let val = self.evaluate_expr(&func.args[0], row, params)?;
899                let start = self.evaluate_expr(&func.args[1], row, params)?;
900                let len = if func.args.len() == 3 {
901                    Some(self.evaluate_expr(&func.args[2], row, params)?)
902                } else {
903                    None
904                };
905
906                match (val, start) {
907                    (SochValue::Text(s), SochValue::Int(start)) => {
908                        let start = (start.max(1) - 1) as usize;
909                        if start >= s.len() {
910                            return Ok(SochValue::Text(String::new()));
911                        }
912                        let result = if let Some(SochValue::Int(len)) = len {
913                            s.chars().skip(start).take(len as usize).collect()
914                        } else {
915                            s.chars().skip(start).collect()
916                        };
917                        Ok(SochValue::Text(result))
918                    }
919                    (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
920                    _ => Err(SqlError::TypeError(
921                        "SUBSTR requires string and integer arguments".into(),
922                    )),
923                }
924            }
925
926            _ => Err(SqlError::NotImplemented(format!(
927                "Function {} not implemented",
928                func_name
929            ))),
930        }
931    }
932
933    // ========== Helper Methods ==========
934
935    fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
936        match (left, right) {
937            (SochValue::Null, _) | (_, SochValue::Null) => false,
938            (SochValue::Int(l), SochValue::Int(r)) => l == r,
939            (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
940            (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
941            (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
942            (SochValue::Text(l), SochValue::Text(r)) => l == r,
943            (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
944            (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
945            (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
946            (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
947            (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
948            _ => false,
949        }
950    }
951
952    fn compare_values(
953        &self,
954        left: Option<&SochValue>,
955        right: Option<&SochValue>,
956    ) -> std::cmp::Ordering {
957        match (left, right) {
958            (None, None) => std::cmp::Ordering::Equal,
959            (None, _) => std::cmp::Ordering::Less,
960            (_, None) => std::cmp::Ordering::Greater,
961            (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
962            (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
963            (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
964                l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
965            }
966            (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
967                .partial_cmp(r)
968                .unwrap_or(std::cmp::Ordering::Equal),
969            (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
970                .partial_cmp(&(*r as f64))
971                .unwrap_or(std::cmp::Ordering::Equal),
972            (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
973            (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
974            _ => std::cmp::Ordering::Equal,
975        }
976    }
977
978    fn arithmetic_op<FI, FF>(
979        &self,
980        left: &SochValue,
981        right: &SochValue,
982        int_op: FI,
983        float_op: FF,
984    ) -> SqlResult<SochValue>
985    where
986        FI: Fn(i64, i64) -> i64,
987        FF: Fn(f64, f64) -> f64,
988    {
989        match (left, right) {
990            (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
991            (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
992            (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
993            (SochValue::Int(l), SochValue::Float(r)) => {
994                Ok(SochValue::Float(float_op(*l as f64, *r)))
995            }
996            (SochValue::Float(l), SochValue::Int(r)) => {
997                Ok(SochValue::Float(float_op(*l, *r as f64)))
998            }
999            (SochValue::UInt(l), SochValue::UInt(r)) => {
1000                Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1001            }
1002            (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1003            (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1004            _ => Err(SqlError::TypeError(
1005                "Arithmetic requires numeric operands".into(),
1006            )),
1007        }
1008    }
1009}
1010
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes one statement and unwraps its result; any SQL error fails the test.
    fn run(executor: &mut SqlExecutor, sql: &str) -> ExecutionResult {
        executor.execute(sql).unwrap()
    }

    /// Unwraps a `Rows` result, failing the test for any other variant.
    fn expect_rows(result: ExecutionResult) -> Vec<HashMap<String, SochValue>> {
        match result {
            ExecutionResult::Rows { rows, .. } => rows,
            _ => panic!("Expected rows"),
        }
    }

    /// Builds an executor with a `nums (n INTEGER)` table holding `1..=count`.
    fn nums_from_one_to(count: i32) -> SqlExecutor {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE nums (n INTEGER)");
        for n in 1..=count {
            run(&mut db, &format!("INSERT INTO nums (n) VALUES ({})", n));
        }
        db
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut db = SqlExecutor::new();

        let created = run(&mut db, "CREATE TABLE users (id INTEGER, name VARCHAR(100))");
        assert!(matches!(created, ExecutionResult::Ok));

        // Each single-row INSERT reports exactly one affected row.
        for insert in [
            "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            "INSERT INTO users (id, name) VALUES (2, 'Bob')",
        ] {
            assert_eq!(run(&mut db, insert).rows_affected(), 1);
        }

        let everyone = run(&mut db, "SELECT * FROM users");
        assert_eq!(everyone.rows_affected(), 2);
    }

    #[test]
    fn test_select_with_where() {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)");
        for insert in [
            "INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)",
            "INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)",
            "INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)",
        ] {
            run(&mut db, insert);
        }

        // Apple (1.50) and Orange (2.00) pass the filter; Banana does not.
        let pricey = run(&mut db, "SELECT * FROM products WHERE price > 1.0");
        assert_eq!(pricey.rows_affected(), 2);
    }

    #[test]
    fn test_update() {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE users (id INTEGER, name TEXT)");
        run(&mut db, "INSERT INTO users (id, name) VALUES (1, 'Alice')");

        let updated = run(&mut db, "UPDATE users SET name = 'Alicia' WHERE id = 1");
        assert_eq!(updated.rows_affected(), 1);

        let renamed = run(&mut db, "SELECT * FROM users WHERE name = 'Alicia'");
        assert_eq!(renamed.rows_affected(), 1);
    }

    #[test]
    fn test_delete() {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE users (id INTEGER, name TEXT)");
        run(&mut db, "INSERT INTO users (id, name) VALUES (1, 'Alice')");
        run(&mut db, "INSERT INTO users (id, name) VALUES (2, 'Bob')");

        let deleted = run(&mut db, "DELETE FROM users WHERE id = 1");
        assert_eq!(deleted.rows_affected(), 1);

        let remaining = run(&mut db, "SELECT * FROM users");
        assert_eq!(remaining.rows_affected(), 1);
    }

    #[test]
    fn test_functions() {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE t (s TEXT)");
        run(&mut db, "INSERT INTO t (s) VALUES ('hello')");

        let rows = expect_rows(run(&mut db, "SELECT UPPER(s) FROM t"));
        // The output column may be named `UPPER(s)` or similar, so search every value.
        let has_upper = rows[0]
            .values()
            .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"));
        assert!(has_upper);
    }

    #[test]
    fn test_order_by() {
        let mut db = SqlExecutor::new();
        run(&mut db, "CREATE TABLE nums (n INTEGER)");
        // Insert deliberately out of order so the sort has work to do.
        for n in [3, 1, 2] {
            run(&mut db, &format!("INSERT INTO nums (n) VALUES ({})", n));
        }

        let rows = expect_rows(run(&mut db, "SELECT * FROM nums ORDER BY n ASC"));
        let ordered: Vec<i64> = rows
            .iter()
            .filter_map(|row| match row.get("n") {
                Some(SochValue::Int(n)) => Some(*n),
                _ => None,
            })
            .collect();
        assert_eq!(ordered, vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut db = nums_from_one_to(10);
        let page = run(&mut db, "SELECT * FROM nums LIMIT 3 OFFSET 2");
        assert_eq!(page.rows_affected(), 3);
    }

    #[test]
    fn test_between() {
        let mut db = nums_from_one_to(10);
        // BETWEEN is inclusive on both ends: 3, 4, 5, 6, 7.
        let window = run(&mut db, "SELECT * FROM nums WHERE n BETWEEN 3 AND 7");
        assert_eq!(window.rows_affected(), 5);
    }

    #[test]
    fn test_in_list() {
        let mut db = nums_from_one_to(5);
        let picked = run(&mut db, "SELECT * FROM nums WHERE n IN (1, 3, 5)");
        assert_eq!(picked.rows_affected(), 3);
    }
}