// sochdb_query/sql/bridge.rs
// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! # SQL Execution Bridge
//!
//! Unified SQL execution pipeline that routes all SQL through a single AST.
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
//! │   SQL Text  │ --> │   Lexer     │ --> │   Parser    │ --> │    AST      │
//! └─────────────┘     └─────────────┘     └─────────────┘     └─────────────┘
//!                                                                    │
//!                     ┌──────────────────────────────────────────────┘
//!                     │
//!                     v
//! ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
//! │  Executor   │ <-- │  Planner    │ <-- │  Validator  │
//! └─────────────┘     └─────────────┘     └─────────────┘
//!       │
//!       v
//! ┌─────────────┐
//! │   Result    │
//! └─────────────┘
//! ```
//!
//! In this module, validation is limited to structural checks (such as verifying
//! that enough parameters were supplied for the statement's placeholders); planning
//! and execution are delegated to the [`SqlConnection`] implementation.
//!
//! ## Benefits
//!
//! 1. **Single parser**: All SQL goes through one lexer/parser
//! 2. **Type-safe AST**: Structured representation of all queries
//! 3. **Dialect normalization**: MySQL/PostgreSQL/SQLite → canonical AST
//! 4. **Extensible**: Add new features by extending AST, not string parsing
49use super::ast::*;
50use super::compatibility::SqlDialect;
51use super::error::{SqlError, SqlResult};
52use super::parser::Parser;
53use std::collections::HashMap;
54use sochdb_core::SochValue;
55
56/// Execution result types
57#[derive(Debug, Clone)]
58pub enum ExecutionResult {
59    /// SELECT query result
60    Rows {
61        columns: Vec<String>,
62        rows: Vec<HashMap<String, SochValue>>,
63    },
64    /// DML result (INSERT/UPDATE/DELETE)
65    RowsAffected(usize),
66    /// DDL result (CREATE/DROP/ALTER)
67    Ok,
68    /// Transaction control result
69    TransactionOk,
70}
71
72impl ExecutionResult {
73    /// Get rows if this is a SELECT result
74    pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
75        match self {
76            ExecutionResult::Rows { rows, .. } => Some(rows),
77            _ => None,
78        }
79    }
80
81    /// Get column names if this is a SELECT result
82    pub fn columns(&self) -> Option<&Vec<String>> {
83        match self {
84            ExecutionResult::Rows { columns, .. } => Some(columns),
85            _ => None,
86        }
87    }
88
89    /// Get affected row count
90    pub fn rows_affected(&self) -> usize {
91        match self {
92            ExecutionResult::RowsAffected(n) => *n,
93            ExecutionResult::Rows { rows, .. } => rows.len(),
94            _ => 0,
95        }
96    }
97}
98
99/// Storage connection trait for executing SQL against actual storage
100///
101/// Implementations of this trait provide the bridge between parsed SQL
102/// and the underlying storage engine.
103pub trait SqlConnection {
104    /// Execute a SELECT query
105    fn select(
106        &self,
107        table: &str,
108        columns: &[String],
109        where_clause: Option<&Expr>,
110        order_by: &[OrderByItem],
111        limit: Option<usize>,
112        offset: Option<usize>,
113        params: &[SochValue],
114    ) -> SqlResult<ExecutionResult>;
115
116    /// Execute an INSERT
117    fn insert(
118        &mut self,
119        table: &str,
120        columns: Option<&[String]>,
121        rows: &[Vec<Expr>],
122        on_conflict: Option<&OnConflict>,
123        params: &[SochValue],
124    ) -> SqlResult<ExecutionResult>;
125
126    /// Execute an UPDATE
127    fn update(
128        &mut self,
129        table: &str,
130        assignments: &[Assignment],
131        where_clause: Option<&Expr>,
132        params: &[SochValue],
133    ) -> SqlResult<ExecutionResult>;
134
135    /// Execute a DELETE
136    fn delete(
137        &mut self,
138        table: &str,
139        where_clause: Option<&Expr>,
140        params: &[SochValue],
141    ) -> SqlResult<ExecutionResult>;
142
143    /// Create a table
144    fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;
145
146    /// Drop a table
147    fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;
148
149    /// Create an index
150    fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;
151
152    /// Drop an index
153    fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;
154
155    /// Begin transaction
156    fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;
157
158    /// Commit transaction
159    fn commit(&mut self) -> SqlResult<ExecutionResult>;
160
161    /// Rollback transaction
162    fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;
163
164    /// Check if table exists
165    fn table_exists(&self, table: &str) -> SqlResult<bool>;
166
167    /// Check if index exists
168    fn index_exists(&self, index: &str) -> SqlResult<bool>;
169}
170
171/// Unified SQL executor that routes through AST
172pub struct SqlBridge<C: SqlConnection> {
173    conn: C,
174}
175
176impl<C: SqlConnection> SqlBridge<C> {
177    /// Create a new SQL bridge with the given connection
178    pub fn new(conn: C) -> Self {
179        Self { conn }
180    }
181
182    /// Execute a SQL statement
183    pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
184        self.execute_with_params(sql, &[])
185    }
186
187    /// Execute a SQL statement with parameters
188    pub fn execute_with_params(
189        &mut self,
190        sql: &str,
191        params: &[SochValue],
192    ) -> SqlResult<ExecutionResult> {
193        // Detect dialect for better error messages
194        let _dialect = SqlDialect::detect(sql);
195
196        // Parse SQL into AST
197        let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
198
199        // Validate placeholder count
200        let max_placeholder = self.find_max_placeholder(&stmt);
201        if max_placeholder as usize > params.len() {
202            return Err(SqlError::InvalidArgument(format!(
203                "Query contains {} placeholders but only {} parameters provided",
204                max_placeholder,
205                params.len()
206            )));
207        }
208
209        // Execute statement
210        self.execute_statement(&stmt, params)
211    }
212
213    /// Execute a parsed statement
214    pub fn execute_statement(
215        &mut self,
216        stmt: &Statement,
217        params: &[SochValue],
218    ) -> SqlResult<ExecutionResult> {
219        match stmt {
220            Statement::Select(select) => self.execute_select(select, params),
221            Statement::Insert(insert) => self.execute_insert(insert, params),
222            Statement::Update(update) => self.execute_update(update, params),
223            Statement::Delete(delete) => self.execute_delete(delete, params),
224            Statement::CreateTable(create) => self.execute_create_table(create),
225            Statement::DropTable(drop) => self.execute_drop_table(drop),
226            Statement::CreateIndex(create) => self.execute_create_index(create),
227            Statement::DropIndex(drop) => self.execute_drop_index(drop),
228            Statement::AlterTable(_alter) => Err(SqlError::NotImplemented(
229                "ALTER TABLE not yet implemented".into(),
230            )),
231            Statement::Begin(begin) => self.conn.begin(begin),
232            Statement::Commit => self.conn.commit(),
233            Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
234            Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
235                "SAVEPOINT not yet implemented".into(),
236            )),
237            Statement::Release(_name) => Err(SqlError::NotImplemented(
238                "RELEASE SAVEPOINT not yet implemented".into(),
239            )),
240            Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
241                "EXPLAIN not yet implemented".into(),
242            )),
243        }
244    }
245
246    fn execute_select(
247        &self,
248        select: &SelectStmt,
249        params: &[SochValue],
250    ) -> SqlResult<ExecutionResult> {
251        // Get table from FROM clause
252        let from = select
253            .from
254            .as_ref()
255            .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
256
257        if from.tables.len() != 1 {
258            return Err(SqlError::NotImplemented(
259                "Multi-table queries not yet supported".into(),
260            ));
261        }
262
263        let table_name = match &from.tables[0] {
264            TableRef::Table { name, .. } => name.name().to_string(),
265            TableRef::Subquery { .. } => {
266                return Err(SqlError::NotImplemented(
267                    "Subqueries not yet supported".into(),
268                ));
269            }
270            TableRef::Join { .. } => {
271                return Err(SqlError::NotImplemented(
272                    "JOINs not yet supported".into(),
273                ));
274            }
275            TableRef::Function { .. } => {
276                return Err(SqlError::NotImplemented(
277                    "Table functions not yet supported".into(),
278                ));
279            }
280        };
281
282        // Extract column names
283        let columns = self.extract_select_columns(&select.columns)?;
284
285        // Extract LIMIT/OFFSET
286        let limit = self.extract_limit(&select.limit)?;
287        let offset = self.extract_limit(&select.offset)?;
288
289        self.conn.select(
290            &table_name,
291            &columns,
292            select.where_clause.as_ref(),
293            &select.order_by,
294            limit,
295            offset,
296            params,
297        )
298    }
299
300    fn execute_insert(
301        &mut self,
302        insert: &InsertStmt,
303        params: &[SochValue],
304    ) -> SqlResult<ExecutionResult> {
305        let table_name = insert.table.name();
306
307        let rows = match &insert.source {
308            InsertSource::Values(values) => values,
309            InsertSource::Query(_) => {
310                return Err(SqlError::NotImplemented(
311                    "INSERT ... SELECT not yet supported".into(),
312                ));
313            }
314            InsertSource::Default => {
315                return Err(SqlError::NotImplemented(
316                    "INSERT DEFAULT VALUES not yet supported".into(),
317                ));
318            }
319        };
320
321        self.conn.insert(
322            table_name,
323            insert.columns.as_deref(),
324            rows,
325            insert.on_conflict.as_ref(),
326            params,
327        )
328    }
329
330    fn execute_update(
331        &mut self,
332        update: &UpdateStmt,
333        params: &[SochValue],
334    ) -> SqlResult<ExecutionResult> {
335        let table_name = update.table.name();
336
337        self.conn.update(
338            table_name,
339            &update.assignments,
340            update.where_clause.as_ref(),
341            params,
342        )
343    }
344
345    fn execute_delete(
346        &mut self,
347        delete: &DeleteStmt,
348        params: &[SochValue],
349    ) -> SqlResult<ExecutionResult> {
350        let table_name = delete.table.name();
351
352        self.conn.delete(
353            table_name,
354            delete.where_clause.as_ref(),
355            params,
356        )
357    }
358
359    fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
360        // Handle IF NOT EXISTS
361        if stmt.if_not_exists {
362            let table_name = stmt.name.name();
363            if self.conn.table_exists(table_name)? {
364                return Ok(ExecutionResult::Ok);
365            }
366        }
367
368        self.conn.create_table(stmt)
369    }
370
371    fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
372        // Handle IF EXISTS
373        if stmt.if_exists {
374            for name in &stmt.names {
375                if !self.conn.table_exists(name.name())? {
376                    return Ok(ExecutionResult::Ok);
377                }
378            }
379        }
380
381        self.conn.drop_table(stmt)
382    }
383
384    fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
385        // Handle IF NOT EXISTS
386        if stmt.if_not_exists {
387            if self.conn.index_exists(&stmt.name)? {
388                return Ok(ExecutionResult::Ok);
389            }
390        }
391
392        self.conn.create_index(stmt)
393    }
394
395    fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
396        // Handle IF EXISTS
397        if stmt.if_exists {
398            if !self.conn.index_exists(&stmt.name)? {
399                return Ok(ExecutionResult::Ok);
400            }
401        }
402
403        self.conn.drop_index(stmt)
404    }
405
406    /// Extract column names from SELECT list
407    fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
408        let mut columns = Vec::new();
409
410        for item in items {
411            match item {
412                SelectItem::Wildcard => columns.push("*".to_string()),
413                SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
414                SelectItem::Expr { expr, alias } => {
415                    let name = alias.clone().unwrap_or_else(|| match expr {
416                        Expr::Column(col) => col.column.clone(),
417                        Expr::Function(func) => format!("{}()", func.name.name()),
418                        _ => "?column?".to_string(),
419                    });
420                    columns.push(name);
421                }
422            }
423        }
424
425        Ok(columns)
426    }
427
428    /// Extract LIMIT/OFFSET value
429    fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
430        match expr {
431            Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
432            Some(_) => Err(SqlError::InvalidArgument(
433                "LIMIT/OFFSET must be an integer literal".into(),
434            )),
435            None => Ok(None),
436        }
437    }
438
439    /// Find the maximum placeholder index in a statement
440    fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
441        let mut visitor = PlaceholderVisitor::new();
442        visitor.visit_statement(stmt);
443        visitor.max_placeholder
444    }
445}
/// AST walker that tracks the largest placeholder index (`$n`, or `?` as numbered
/// by the parser) appearing in a statement.
struct PlaceholderVisitor {
    /// Highest placeholder index seen so far; `0` means none has been seen.
    max_placeholder: u32,
}
451
452impl PlaceholderVisitor {
453    fn new() -> Self {
454        Self { max_placeholder: 0 }
455    }
456
457    fn visit_statement(&mut self, stmt: &Statement) {
458        match stmt {
459            Statement::Select(s) => self.visit_select(s),
460            Statement::Insert(i) => self.visit_insert(i),
461            Statement::Update(u) => self.visit_update(u),
462            Statement::Delete(d) => self.visit_delete(d),
463            _ => {}
464        }
465    }
466
467    fn visit_select(&mut self, select: &SelectStmt) {
468        for item in &select.columns {
469            if let SelectItem::Expr { expr, .. } = item {
470                self.visit_expr(expr);
471            }
472        }
473        if let Some(where_clause) = &select.where_clause {
474            self.visit_expr(where_clause);
475        }
476        if let Some(having) = &select.having {
477            self.visit_expr(having);
478        }
479        for order in &select.order_by {
480            self.visit_expr(&order.expr);
481        }
482        if let Some(limit) = &select.limit {
483            self.visit_expr(limit);
484        }
485        if let Some(offset) = &select.offset {
486            self.visit_expr(offset);
487        }
488    }
489
490    fn visit_insert(&mut self, insert: &InsertStmt) {
491        if let InsertSource::Values(rows) = &insert.source {
492            for row in rows {
493                for expr in row {
494                    self.visit_expr(expr);
495                }
496            }
497        }
498    }
499
500    fn visit_update(&mut self, update: &UpdateStmt) {
501        for assign in &update.assignments {
502            self.visit_expr(&assign.value);
503        }
504        if let Some(where_clause) = &update.where_clause {
505            self.visit_expr(where_clause);
506        }
507    }
508
509    fn visit_delete(&mut self, delete: &DeleteStmt) {
510        if let Some(where_clause) = &delete.where_clause {
511            self.visit_expr(where_clause);
512        }
513    }
514
515    fn visit_expr(&mut self, expr: &Expr) {
516        match expr {
517            Expr::Placeholder(n) => {
518                self.max_placeholder = self.max_placeholder.max(*n);
519            }
520            Expr::BinaryOp { left, right, .. } => {
521                self.visit_expr(left);
522                self.visit_expr(right);
523            }
524            Expr::UnaryOp { expr, .. } => {
525                self.visit_expr(expr);
526            }
527            Expr::Function(func) => {
528                for arg in &func.args {
529                    self.visit_expr(arg);
530                }
531            }
532            Expr::Case { operand, conditions, else_result } => {
533                if let Some(op) = operand {
534                    self.visit_expr(op);
535                }
536                for (when, then) in conditions {
537                    self.visit_expr(when);
538                    self.visit_expr(then);
539                }
540                if let Some(else_expr) = else_result {
541                    self.visit_expr(else_expr);
542                }
543            }
544            Expr::InList { expr, list, .. } => {
545                self.visit_expr(expr);
546                for item in list {
547                    self.visit_expr(item);
548                }
549            }
550            Expr::Between { expr, low, high, .. } => {
551                self.visit_expr(expr);
552                self.visit_expr(low);
553                self.visit_expr(high);
554            }
555            Expr::Cast { expr, .. } => {
556                self.visit_expr(expr);
557            }
558            _ => {}
559        }
560    }
561}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the highest placeholder index found by the visitor.
    fn max_placeholder_of(sql: &str) -> u32 {
        let stmt = Parser::parse(sql).unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        visitor.max_placeholder
    }

    #[test]
    fn test_placeholder_visitor() {
        assert_eq!(
            max_placeholder_of("SELECT * FROM users WHERE id = $1 AND name = $2"),
            2
        );
    }

    #[test]
    fn test_question_mark_placeholders() {
        assert_eq!(
            max_placeholder_of("SELECT * FROM users WHERE id = ? AND name = ?"),
            2
        );
    }

    #[test]
    fn test_dialect_detection() {
        let cases = [
            ("SELECT * FROM users", SqlDialect::Standard),
            ("INSERT IGNORE INTO users VALUES (1)", SqlDialect::MySQL),
            ("INSERT OR IGNORE INTO users VALUES (1)", SqlDialect::SQLite),
        ];
        for (sql, expected) in cases {
            assert_eq!(SqlDialect::detect(sql), expected, "dialect of {:?}", sql);
        }
    }
}