1use super::ast::*;
47use super::compatibility::SqlDialect;
48use super::error::{SqlError, SqlResult};
49use super::parser::Parser;
50use std::collections::HashMap;
51use sochdb_core::SochValue;
52
53#[derive(Debug, Clone)]
55pub enum ExecutionResult {
56 Rows {
58 columns: Vec<String>,
59 rows: Vec<HashMap<String, SochValue>>,
60 },
61 RowsAffected(usize),
63 Ok,
65 TransactionOk,
67}
68
69impl ExecutionResult {
70 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
72 match self {
73 ExecutionResult::Rows { rows, .. } => Some(rows),
74 _ => None,
75 }
76 }
77
78 pub fn columns(&self) -> Option<&Vec<String>> {
80 match self {
81 ExecutionResult::Rows { columns, .. } => Some(columns),
82 _ => None,
83 }
84 }
85
86 pub fn rows_affected(&self) -> usize {
88 match self {
89 ExecutionResult::RowsAffected(n) => *n,
90 ExecutionResult::Rows { rows, .. } => rows.len(),
91 _ => 0,
92 }
93 }
94}
95
/// Storage backend that `SqlBridge` dispatches parsed statements to.
///
/// Implementors receive unevaluated AST fragments (`Expr`, `Assignment`,
/// `OrderByItem`, ...) together with the positional `params` slice. The
/// bridge itself does not evaluate expressions or substitute placeholders.
pub trait SqlConnection {
    /// Runs a single-table query.
    ///
    /// `columns` holds the projected output names as derived by the bridge:
    /// `"*"` for `*`, `"t.*"` for a qualified wildcard, otherwise the alias,
    /// the column name, `"name()"` for a function call, or `"?column?"`.
    /// `limit`/`offset` come from integer literals in the statement.
    fn select(
        &self,
        table: &str,
        columns: &[String],
        where_clause: Option<&Expr>,
        order_by: &[OrderByItem],
        limit: Option<usize>,
        offset: Option<usize>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Inserts the `VALUES` tuples in `rows` into `table`.
    ///
    /// `columns` is `None` when the statement carried no column list
    /// (NOTE(review): inferred from `insert.columns.as_deref()`; confirm in the parser).
    fn insert(
        &mut self,
        table: &str,
        columns: Option<&[String]>,
        rows: &[Vec<Expr>],
        on_conflict: Option<&OnConflict>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Applies `assignments` to the rows of `table` matching `where_clause`
    /// (`None` when the statement has no WHERE clause).
    fn update(
        &mut self,
        table: &str,
        assignments: &[Assignment],
        where_clause: Option<&Expr>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Deletes the rows of `table` matching `where_clause`
    /// (`None` when the statement has no WHERE clause).
    fn delete(
        &mut self,
        table: &str,
        where_clause: Option<&Expr>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Executes CREATE TABLE. The bridge skips this call when
    /// IF NOT EXISTS is set and `table_exists` reports the table.
    fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;

    /// Executes DROP TABLE. The full statement, including its `if_exists`
    /// flag, is passed through; the bridge may skip the call entirely when
    /// IF EXISTS applies and the targets are absent.
    fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;

    /// Executes CREATE INDEX. The bridge skips this call when
    /// IF NOT EXISTS is set and `index_exists` reports the index.
    fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;

    /// Executes DROP INDEX. The bridge skips this call when IF EXISTS is
    /// set and `index_exists` reports the index as absent.
    fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;

    /// Starts a transaction as described by `stmt`.
    fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;

    /// Commits the current transaction.
    fn commit(&mut self) -> SqlResult<ExecutionResult>;

    /// Rolls back the current transaction, or to `savepoint` when the
    /// statement names one.
    fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;

    /// Reports whether `table` exists; the bridge uses it for IF [NOT] EXISTS handling.
    fn table_exists(&self, table: &str) -> SqlResult<bool>;

    /// Reports whether `index` exists; the bridge uses it for IF [NOT] EXISTS handling.
    fn index_exists(&self, index: &str) -> SqlResult<bool>;
}
167
168pub struct SqlBridge<C: SqlConnection> {
170 conn: C,
171}
172
173impl<C: SqlConnection> SqlBridge<C> {
174 pub fn new(conn: C) -> Self {
176 Self { conn }
177 }
178
179 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
181 self.execute_with_params(sql, &[])
182 }
183
184 pub fn execute_with_params(
186 &mut self,
187 sql: &str,
188 params: &[SochValue],
189 ) -> SqlResult<ExecutionResult> {
190 let _dialect = SqlDialect::detect(sql);
192
193 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
195
196 let max_placeholder = self.find_max_placeholder(&stmt);
198 if max_placeholder as usize > params.len() {
199 return Err(SqlError::InvalidArgument(format!(
200 "Query contains {} placeholders but only {} parameters provided",
201 max_placeholder,
202 params.len()
203 )));
204 }
205
206 self.execute_statement(&stmt, params)
208 }
209
210 pub fn execute_statement(
212 &mut self,
213 stmt: &Statement,
214 params: &[SochValue],
215 ) -> SqlResult<ExecutionResult> {
216 match stmt {
217 Statement::Select(select) => self.execute_select(select, params),
218 Statement::Insert(insert) => self.execute_insert(insert, params),
219 Statement::Update(update) => self.execute_update(update, params),
220 Statement::Delete(delete) => self.execute_delete(delete, params),
221 Statement::CreateTable(create) => self.execute_create_table(create),
222 Statement::DropTable(drop) => self.execute_drop_table(drop),
223 Statement::CreateIndex(create) => self.execute_create_index(create),
224 Statement::DropIndex(drop) => self.execute_drop_index(drop),
225 Statement::AlterTable(_alter) => Err(SqlError::NotImplemented(
226 "ALTER TABLE not yet implemented".into(),
227 )),
228 Statement::Begin(begin) => self.conn.begin(begin),
229 Statement::Commit => self.conn.commit(),
230 Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
231 Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
232 "SAVEPOINT not yet implemented".into(),
233 )),
234 Statement::Release(_name) => Err(SqlError::NotImplemented(
235 "RELEASE SAVEPOINT not yet implemented".into(),
236 )),
237 Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
238 "EXPLAIN not yet implemented".into(),
239 )),
240 }
241 }
242
243 fn execute_select(
244 &self,
245 select: &SelectStmt,
246 params: &[SochValue],
247 ) -> SqlResult<ExecutionResult> {
248 let from = select
250 .from
251 .as_ref()
252 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
253
254 if from.tables.len() != 1 {
255 return Err(SqlError::NotImplemented(
256 "Multi-table queries not yet supported".into(),
257 ));
258 }
259
260 let table_name = match &from.tables[0] {
261 TableRef::Table { name, .. } => name.name().to_string(),
262 TableRef::Subquery { .. } => {
263 return Err(SqlError::NotImplemented(
264 "Subqueries not yet supported".into(),
265 ));
266 }
267 TableRef::Join { .. } => {
268 return Err(SqlError::NotImplemented(
269 "JOINs not yet supported".into(),
270 ));
271 }
272 TableRef::Function { .. } => {
273 return Err(SqlError::NotImplemented(
274 "Table functions not yet supported".into(),
275 ));
276 }
277 };
278
279 let columns = self.extract_select_columns(&select.columns)?;
281
282 let limit = self.extract_limit(&select.limit)?;
284 let offset = self.extract_limit(&select.offset)?;
285
286 self.conn.select(
287 &table_name,
288 &columns,
289 select.where_clause.as_ref(),
290 &select.order_by,
291 limit,
292 offset,
293 params,
294 )
295 }
296
297 fn execute_insert(
298 &mut self,
299 insert: &InsertStmt,
300 params: &[SochValue],
301 ) -> SqlResult<ExecutionResult> {
302 let table_name = insert.table.name();
303
304 let rows = match &insert.source {
305 InsertSource::Values(values) => values,
306 InsertSource::Query(_) => {
307 return Err(SqlError::NotImplemented(
308 "INSERT ... SELECT not yet supported".into(),
309 ));
310 }
311 InsertSource::Default => {
312 return Err(SqlError::NotImplemented(
313 "INSERT DEFAULT VALUES not yet supported".into(),
314 ));
315 }
316 };
317
318 self.conn.insert(
319 table_name,
320 insert.columns.as_deref(),
321 rows,
322 insert.on_conflict.as_ref(),
323 params,
324 )
325 }
326
327 fn execute_update(
328 &mut self,
329 update: &UpdateStmt,
330 params: &[SochValue],
331 ) -> SqlResult<ExecutionResult> {
332 let table_name = update.table.name();
333
334 self.conn.update(
335 table_name,
336 &update.assignments,
337 update.where_clause.as_ref(),
338 params,
339 )
340 }
341
342 fn execute_delete(
343 &mut self,
344 delete: &DeleteStmt,
345 params: &[SochValue],
346 ) -> SqlResult<ExecutionResult> {
347 let table_name = delete.table.name();
348
349 self.conn.delete(
350 table_name,
351 delete.where_clause.as_ref(),
352 params,
353 )
354 }
355
356 fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
357 if stmt.if_not_exists {
359 let table_name = stmt.name.name();
360 if self.conn.table_exists(table_name)? {
361 return Ok(ExecutionResult::Ok);
362 }
363 }
364
365 self.conn.create_table(stmt)
366 }
367
368 fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
369 if stmt.if_exists {
371 for name in &stmt.names {
372 if !self.conn.table_exists(name.name())? {
373 return Ok(ExecutionResult::Ok);
374 }
375 }
376 }
377
378 self.conn.drop_table(stmt)
379 }
380
381 fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
382 if stmt.if_not_exists {
384 if self.conn.index_exists(&stmt.name)? {
385 return Ok(ExecutionResult::Ok);
386 }
387 }
388
389 self.conn.create_index(stmt)
390 }
391
392 fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
393 if stmt.if_exists {
395 if !self.conn.index_exists(&stmt.name)? {
396 return Ok(ExecutionResult::Ok);
397 }
398 }
399
400 self.conn.drop_index(stmt)
401 }
402
403 fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
405 let mut columns = Vec::new();
406
407 for item in items {
408 match item {
409 SelectItem::Wildcard => columns.push("*".to_string()),
410 SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
411 SelectItem::Expr { expr, alias } => {
412 let name = alias.clone().unwrap_or_else(|| match expr {
413 Expr::Column(col) => col.column.clone(),
414 Expr::Function(func) => format!("{}()", func.name.name()),
415 _ => "?column?".to_string(),
416 });
417 columns.push(name);
418 }
419 }
420 }
421
422 Ok(columns)
423 }
424
425 fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
427 match expr {
428 Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
429 Some(_) => Err(SqlError::InvalidArgument(
430 "LIMIT/OFFSET must be an integer literal".into(),
431 )),
432 None => Ok(None),
433 }
434 }
435
436 fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
438 let mut visitor = PlaceholderVisitor::new();
439 visitor.visit_statement(stmt);
440 visitor.max_placeholder
441 }
442}
443
444struct PlaceholderVisitor {
446 max_placeholder: u32,
447}
448
449impl PlaceholderVisitor {
450 fn new() -> Self {
451 Self { max_placeholder: 0 }
452 }
453
454 fn visit_statement(&mut self, stmt: &Statement) {
455 match stmt {
456 Statement::Select(s) => self.visit_select(s),
457 Statement::Insert(i) => self.visit_insert(i),
458 Statement::Update(u) => self.visit_update(u),
459 Statement::Delete(d) => self.visit_delete(d),
460 _ => {}
461 }
462 }
463
464 fn visit_select(&mut self, select: &SelectStmt) {
465 for item in &select.columns {
466 if let SelectItem::Expr { expr, .. } = item {
467 self.visit_expr(expr);
468 }
469 }
470 if let Some(where_clause) = &select.where_clause {
471 self.visit_expr(where_clause);
472 }
473 if let Some(having) = &select.having {
474 self.visit_expr(having);
475 }
476 for order in &select.order_by {
477 self.visit_expr(&order.expr);
478 }
479 if let Some(limit) = &select.limit {
480 self.visit_expr(limit);
481 }
482 if let Some(offset) = &select.offset {
483 self.visit_expr(offset);
484 }
485 }
486
487 fn visit_insert(&mut self, insert: &InsertStmt) {
488 if let InsertSource::Values(rows) = &insert.source {
489 for row in rows {
490 for expr in row {
491 self.visit_expr(expr);
492 }
493 }
494 }
495 }
496
497 fn visit_update(&mut self, update: &UpdateStmt) {
498 for assign in &update.assignments {
499 self.visit_expr(&assign.value);
500 }
501 if let Some(where_clause) = &update.where_clause {
502 self.visit_expr(where_clause);
503 }
504 }
505
506 fn visit_delete(&mut self, delete: &DeleteStmt) {
507 if let Some(where_clause) = &delete.where_clause {
508 self.visit_expr(where_clause);
509 }
510 }
511
512 fn visit_expr(&mut self, expr: &Expr) {
513 match expr {
514 Expr::Placeholder(n) => {
515 self.max_placeholder = self.max_placeholder.max(*n);
516 }
517 Expr::BinaryOp { left, right, .. } => {
518 self.visit_expr(left);
519 self.visit_expr(right);
520 }
521 Expr::UnaryOp { expr, .. } => {
522 self.visit_expr(expr);
523 }
524 Expr::Function(func) => {
525 for arg in &func.args {
526 self.visit_expr(arg);
527 }
528 }
529 Expr::Case { operand, conditions, else_result } => {
530 if let Some(op) = operand {
531 self.visit_expr(op);
532 }
533 for (when, then) in conditions {
534 self.visit_expr(when);
535 self.visit_expr(then);
536 }
537 if let Some(else_expr) = else_result {
538 self.visit_expr(else_expr);
539 }
540 }
541 Expr::InList { expr, list, .. } => {
542 self.visit_expr(expr);
543 for item in list {
544 self.visit_expr(item);
545 }
546 }
547 Expr::Between { expr, low, high, .. } => {
548 self.visit_expr(expr);
549 self.visit_expr(low);
550 self.visit_expr(high);
551 }
552 Expr::Cast { expr, .. } => {
553 self.visit_expr(expr);
554 }
555 _ => {}
556 }
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the highest placeholder index the visitor reports.
    fn max_placeholder_in(sql: &str) -> u32 {
        let stmt = Parser::parse(sql).unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        visitor.max_placeholder
    }

    #[test]
    fn test_placeholder_visitor() {
        let max = max_placeholder_in("SELECT * FROM users WHERE id = $1 AND name = $2");
        assert_eq!(max, 2);
    }

    #[test]
    fn test_question_mark_placeholders() {
        let max = max_placeholder_in("SELECT * FROM users WHERE id = ? AND name = ?");
        assert_eq!(max, 2);
    }

    #[test]
    fn test_dialect_detection() {
        let cases = [
            ("SELECT * FROM users", SqlDialect::Standard),
            ("INSERT IGNORE INTO users VALUES (1)", SqlDialect::MySQL),
            ("INSERT OR IGNORE INTO users VALUES (1)", SqlDialect::SQLite),
        ];
        for (sql, expected) in cases {
            assert_eq!(SqlDialect::detect(sql), expected, "dialect of {:?}", sql);
        }
    }
}