sochdb_query/sql/
bridge.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # SQL Execution Bridge
16//!
17//! Unified SQL execution pipeline that routes all SQL through a single AST.
18//!
19//! ## Architecture
20//!
21//! ```text
22//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
23//! │   SQL Text  │ --> │   Lexer     │ --> │   Parser    │ --> │    AST      │
24//! └─────────────┘     └─────────────┘     └─────────────┘     └─────────────┘
25//!                                                                    │
26//!                     ┌──────────────────────────────────────────────┘
27//!                     │
28//!                     v
29//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
30//! │  Executor   │ <-- │  Planner    │ <-- │  Validator  │
31//! └─────────────┘     └─────────────┘     └─────────────┘
32//!       │
33//!       v
34//! ┌─────────────┐
35//! │   Result    │
36//! └─────────────┘
37//! ```
38//!
39//! ## Benefits
40//!
41//! 1. **Single parser**: All SQL goes through one lexer/parser
42//! 2. **Type-safe AST**: Structured representation of all queries
43//! 3. **Dialect normalization**: MySQL/PostgreSQL/SQLite → canonical AST
44//! 4. **Extensible**: Add new features by extending AST, not string parsing
45
46use super::ast::*;
47use super::compatibility::SqlDialect;
48use super::error::{SqlError, SqlResult};
49use super::parser::Parser;
50use std::collections::HashMap;
51use sochdb_core::SochValue;
52
53/// Execution result types
54#[derive(Debug, Clone)]
55pub enum ExecutionResult {
56    /// SELECT query result
57    Rows {
58        columns: Vec<String>,
59        rows: Vec<HashMap<String, SochValue>>,
60    },
61    /// DML result (INSERT/UPDATE/DELETE)
62    RowsAffected(usize),
63    /// DDL result (CREATE/DROP/ALTER)
64    Ok,
65    /// Transaction control result
66    TransactionOk,
67}
68
69impl ExecutionResult {
70    /// Get rows if this is a SELECT result
71    pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
72        match self {
73            ExecutionResult::Rows { rows, .. } => Some(rows),
74            _ => None,
75        }
76    }
77
78    /// Get column names if this is a SELECT result
79    pub fn columns(&self) -> Option<&Vec<String>> {
80        match self {
81            ExecutionResult::Rows { columns, .. } => Some(columns),
82            _ => None,
83        }
84    }
85
86    /// Get affected row count
87    pub fn rows_affected(&self) -> usize {
88        match self {
89            ExecutionResult::RowsAffected(n) => *n,
90            ExecutionResult::Rows { rows, .. } => rows.len(),
91            _ => 0,
92        }
93    }
94}
95
96/// Storage connection trait for executing SQL against actual storage
97///
98/// Implementations of this trait provide the bridge between parsed SQL
99/// and the underlying storage engine.
100pub trait SqlConnection {
101    /// Execute a SELECT query
102    fn select(
103        &self,
104        table: &str,
105        columns: &[String],
106        where_clause: Option<&Expr>,
107        order_by: &[OrderByItem],
108        limit: Option<usize>,
109        offset: Option<usize>,
110        params: &[SochValue],
111    ) -> SqlResult<ExecutionResult>;
112
113    /// Execute an INSERT
114    fn insert(
115        &mut self,
116        table: &str,
117        columns: Option<&[String]>,
118        rows: &[Vec<Expr>],
119        on_conflict: Option<&OnConflict>,
120        params: &[SochValue],
121    ) -> SqlResult<ExecutionResult>;
122
123    /// Execute an UPDATE
124    fn update(
125        &mut self,
126        table: &str,
127        assignments: &[Assignment],
128        where_clause: Option<&Expr>,
129        params: &[SochValue],
130    ) -> SqlResult<ExecutionResult>;
131
132    /// Execute a DELETE
133    fn delete(
134        &mut self,
135        table: &str,
136        where_clause: Option<&Expr>,
137        params: &[SochValue],
138    ) -> SqlResult<ExecutionResult>;
139
140    /// Create a table
141    fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;
142
143    /// Drop a table
144    fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;
145
146    /// Create an index
147    fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;
148
149    /// Drop an index
150    fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;
151
152    /// Begin transaction
153    fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;
154
155    /// Commit transaction
156    fn commit(&mut self) -> SqlResult<ExecutionResult>;
157
158    /// Rollback transaction
159    fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;
160
161    /// Check if table exists
162    fn table_exists(&self, table: &str) -> SqlResult<bool>;
163
164    /// Check if index exists
165    fn index_exists(&self, index: &str) -> SqlResult<bool>;
166}
167
168/// Unified SQL executor that routes through AST
169pub struct SqlBridge<C: SqlConnection> {
170    conn: C,
171}
172
173impl<C: SqlConnection> SqlBridge<C> {
174    /// Create a new SQL bridge with the given connection
175    pub fn new(conn: C) -> Self {
176        Self { conn }
177    }
178
179    /// Execute a SQL statement
180    pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
181        self.execute_with_params(sql, &[])
182    }
183
184    /// Execute a SQL statement with parameters
185    pub fn execute_with_params(
186        &mut self,
187        sql: &str,
188        params: &[SochValue],
189    ) -> SqlResult<ExecutionResult> {
190        // Detect dialect for better error messages
191        let _dialect = SqlDialect::detect(sql);
192
193        // Parse SQL into AST
194        let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
195
196        // Validate placeholder count
197        let max_placeholder = self.find_max_placeholder(&stmt);
198        if max_placeholder as usize > params.len() {
199            return Err(SqlError::InvalidArgument(format!(
200                "Query contains {} placeholders but only {} parameters provided",
201                max_placeholder,
202                params.len()
203            )));
204        }
205
206        // Execute statement
207        self.execute_statement(&stmt, params)
208    }
209
210    /// Execute a parsed statement
211    pub fn execute_statement(
212        &mut self,
213        stmt: &Statement,
214        params: &[SochValue],
215    ) -> SqlResult<ExecutionResult> {
216        match stmt {
217            Statement::Select(select) => self.execute_select(select, params),
218            Statement::Insert(insert) => self.execute_insert(insert, params),
219            Statement::Update(update) => self.execute_update(update, params),
220            Statement::Delete(delete) => self.execute_delete(delete, params),
221            Statement::CreateTable(create) => self.execute_create_table(create),
222            Statement::DropTable(drop) => self.execute_drop_table(drop),
223            Statement::CreateIndex(create) => self.execute_create_index(create),
224            Statement::DropIndex(drop) => self.execute_drop_index(drop),
225            Statement::AlterTable(_alter) => Err(SqlError::NotImplemented(
226                "ALTER TABLE not yet implemented".into(),
227            )),
228            Statement::Begin(begin) => self.conn.begin(begin),
229            Statement::Commit => self.conn.commit(),
230            Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
231            Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
232                "SAVEPOINT not yet implemented".into(),
233            )),
234            Statement::Release(_name) => Err(SqlError::NotImplemented(
235                "RELEASE SAVEPOINT not yet implemented".into(),
236            )),
237            Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
238                "EXPLAIN not yet implemented".into(),
239            )),
240        }
241    }
242
243    fn execute_select(
244        &self,
245        select: &SelectStmt,
246        params: &[SochValue],
247    ) -> SqlResult<ExecutionResult> {
248        // Get table from FROM clause
249        let from = select
250            .from
251            .as_ref()
252            .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
253
254        if from.tables.len() != 1 {
255            return Err(SqlError::NotImplemented(
256                "Multi-table queries not yet supported".into(),
257            ));
258        }
259
260        let table_name = match &from.tables[0] {
261            TableRef::Table { name, .. } => name.name().to_string(),
262            TableRef::Subquery { .. } => {
263                return Err(SqlError::NotImplemented(
264                    "Subqueries not yet supported".into(),
265                ));
266            }
267            TableRef::Join { .. } => {
268                return Err(SqlError::NotImplemented(
269                    "JOINs not yet supported".into(),
270                ));
271            }
272            TableRef::Function { .. } => {
273                return Err(SqlError::NotImplemented(
274                    "Table functions not yet supported".into(),
275                ));
276            }
277        };
278
279        // Extract column names
280        let columns = self.extract_select_columns(&select.columns)?;
281
282        // Extract LIMIT/OFFSET
283        let limit = self.extract_limit(&select.limit)?;
284        let offset = self.extract_limit(&select.offset)?;
285
286        self.conn.select(
287            &table_name,
288            &columns,
289            select.where_clause.as_ref(),
290            &select.order_by,
291            limit,
292            offset,
293            params,
294        )
295    }
296
297    fn execute_insert(
298        &mut self,
299        insert: &InsertStmt,
300        params: &[SochValue],
301    ) -> SqlResult<ExecutionResult> {
302        let table_name = insert.table.name();
303
304        let rows = match &insert.source {
305            InsertSource::Values(values) => values,
306            InsertSource::Query(_) => {
307                return Err(SqlError::NotImplemented(
308                    "INSERT ... SELECT not yet supported".into(),
309                ));
310            }
311            InsertSource::Default => {
312                return Err(SqlError::NotImplemented(
313                    "INSERT DEFAULT VALUES not yet supported".into(),
314                ));
315            }
316        };
317
318        self.conn.insert(
319            table_name,
320            insert.columns.as_deref(),
321            rows,
322            insert.on_conflict.as_ref(),
323            params,
324        )
325    }
326
327    fn execute_update(
328        &mut self,
329        update: &UpdateStmt,
330        params: &[SochValue],
331    ) -> SqlResult<ExecutionResult> {
332        let table_name = update.table.name();
333
334        self.conn.update(
335            table_name,
336            &update.assignments,
337            update.where_clause.as_ref(),
338            params,
339        )
340    }
341
342    fn execute_delete(
343        &mut self,
344        delete: &DeleteStmt,
345        params: &[SochValue],
346    ) -> SqlResult<ExecutionResult> {
347        let table_name = delete.table.name();
348
349        self.conn.delete(
350            table_name,
351            delete.where_clause.as_ref(),
352            params,
353        )
354    }
355
356    fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
357        // Handle IF NOT EXISTS
358        if stmt.if_not_exists {
359            let table_name = stmt.name.name();
360            if self.conn.table_exists(table_name)? {
361                return Ok(ExecutionResult::Ok);
362            }
363        }
364
365        self.conn.create_table(stmt)
366    }
367
368    fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
369        // Handle IF EXISTS
370        if stmt.if_exists {
371            for name in &stmt.names {
372                if !self.conn.table_exists(name.name())? {
373                    return Ok(ExecutionResult::Ok);
374                }
375            }
376        }
377
378        self.conn.drop_table(stmt)
379    }
380
381    fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
382        // Handle IF NOT EXISTS
383        if stmt.if_not_exists {
384            if self.conn.index_exists(&stmt.name)? {
385                return Ok(ExecutionResult::Ok);
386            }
387        }
388
389        self.conn.create_index(stmt)
390    }
391
392    fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
393        // Handle IF EXISTS
394        if stmt.if_exists {
395            if !self.conn.index_exists(&stmt.name)? {
396                return Ok(ExecutionResult::Ok);
397            }
398        }
399
400        self.conn.drop_index(stmt)
401    }
402
403    /// Extract column names from SELECT list
404    fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
405        let mut columns = Vec::new();
406
407        for item in items {
408            match item {
409                SelectItem::Wildcard => columns.push("*".to_string()),
410                SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
411                SelectItem::Expr { expr, alias } => {
412                    let name = alias.clone().unwrap_or_else(|| match expr {
413                        Expr::Column(col) => col.column.clone(),
414                        Expr::Function(func) => format!("{}()", func.name.name()),
415                        _ => "?column?".to_string(),
416                    });
417                    columns.push(name);
418                }
419            }
420        }
421
422        Ok(columns)
423    }
424
425    /// Extract LIMIT/OFFSET value
426    fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
427        match expr {
428            Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
429            Some(_) => Err(SqlError::InvalidArgument(
430                "LIMIT/OFFSET must be an integer literal".into(),
431            )),
432            None => Ok(None),
433        }
434    }
435
436    /// Find the maximum placeholder index in a statement
437    fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
438        let mut visitor = PlaceholderVisitor::new();
439        visitor.visit_statement(stmt);
440        visitor.max_placeholder
441    }
442}
443
444/// Visitor to find maximum placeholder index
445struct PlaceholderVisitor {
446    max_placeholder: u32,
447}
448
449impl PlaceholderVisitor {
450    fn new() -> Self {
451        Self { max_placeholder: 0 }
452    }
453
454    fn visit_statement(&mut self, stmt: &Statement) {
455        match stmt {
456            Statement::Select(s) => self.visit_select(s),
457            Statement::Insert(i) => self.visit_insert(i),
458            Statement::Update(u) => self.visit_update(u),
459            Statement::Delete(d) => self.visit_delete(d),
460            _ => {}
461        }
462    }
463
464    fn visit_select(&mut self, select: &SelectStmt) {
465        for item in &select.columns {
466            if let SelectItem::Expr { expr, .. } = item {
467                self.visit_expr(expr);
468            }
469        }
470        if let Some(where_clause) = &select.where_clause {
471            self.visit_expr(where_clause);
472        }
473        if let Some(having) = &select.having {
474            self.visit_expr(having);
475        }
476        for order in &select.order_by {
477            self.visit_expr(&order.expr);
478        }
479        if let Some(limit) = &select.limit {
480            self.visit_expr(limit);
481        }
482        if let Some(offset) = &select.offset {
483            self.visit_expr(offset);
484        }
485    }
486
487    fn visit_insert(&mut self, insert: &InsertStmt) {
488        if let InsertSource::Values(rows) = &insert.source {
489            for row in rows {
490                for expr in row {
491                    self.visit_expr(expr);
492                }
493            }
494        }
495    }
496
497    fn visit_update(&mut self, update: &UpdateStmt) {
498        for assign in &update.assignments {
499            self.visit_expr(&assign.value);
500        }
501        if let Some(where_clause) = &update.where_clause {
502            self.visit_expr(where_clause);
503        }
504    }
505
506    fn visit_delete(&mut self, delete: &DeleteStmt) {
507        if let Some(where_clause) = &delete.where_clause {
508            self.visit_expr(where_clause);
509        }
510    }
511
512    fn visit_expr(&mut self, expr: &Expr) {
513        match expr {
514            Expr::Placeholder(n) => {
515                self.max_placeholder = self.max_placeholder.max(*n);
516            }
517            Expr::BinaryOp { left, right, .. } => {
518                self.visit_expr(left);
519                self.visit_expr(right);
520            }
521            Expr::UnaryOp { expr, .. } => {
522                self.visit_expr(expr);
523            }
524            Expr::Function(func) => {
525                for arg in &func.args {
526                    self.visit_expr(arg);
527                }
528            }
529            Expr::Case { operand, conditions, else_result } => {
530                if let Some(op) = operand {
531                    self.visit_expr(op);
532                }
533                for (when, then) in conditions {
534                    self.visit_expr(when);
535                    self.visit_expr(then);
536                }
537                if let Some(else_expr) = else_result {
538                    self.visit_expr(else_expr);
539                }
540            }
541            Expr::InList { expr, list, .. } => {
542                self.visit_expr(expr);
543                for item in list {
544                    self.visit_expr(item);
545                }
546            }
547            Expr::Between { expr, low, high, .. } => {
548                self.visit_expr(expr);
549                self.visit_expr(low);
550                self.visit_expr(high);
551            }
552            Expr::Cast { expr, .. } => {
553                self.visit_expr(expr);
554            }
555            _ => {}
556        }
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` and return the highest placeholder index the visitor reports.
    fn max_placeholder_of(sql: &str) -> u32 {
        let stmt = Parser::parse(sql).unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        visitor.max_placeholder
    }

    #[test]
    fn test_placeholder_visitor() {
        // Numbered `$n` placeholders.
        assert_eq!(
            max_placeholder_of("SELECT * FROM users WHERE id = $1 AND name = $2"),
            2
        );
    }

    #[test]
    fn test_question_mark_placeholders() {
        // Two anonymous `?` placeholders report a maximum index of 2.
        assert_eq!(
            max_placeholder_of("SELECT * FROM users WHERE id = ? AND name = ?"),
            2
        );
    }

    #[test]
    fn test_dialect_detection() {
        let cases = [
            ("SELECT * FROM users", SqlDialect::Standard),
            ("INSERT IGNORE INTO users VALUES (1)", SqlDialect::MySQL),
            ("INSERT OR IGNORE INTO users VALUES (1)", SqlDialect::SQLite),
        ];
        for (sql, expected) in cases {
            assert_eq!(SqlDialect::detect(sql), expected, "dialect of {:?}", sql);
        }
    }
}