Skip to main content

sqlglot_rust/executor/
mod.rs

1//! In-memory SQL execution engine for testing and validation.
2//!
3//! Provides the ability to run SQL queries against Rust data structures,
4//! supporting SELECT with WHERE, GROUP BY, HAVING, ORDER BY, LIMIT/OFFSET,
5//! JOINs, aggregate functions, subqueries, CTEs, and set operations.
6//!
7//! # Example
8//!
9//! ```
10//! use std::collections::HashMap;
11//! use sqlglot_rust::executor::{execute, Table, Value};
12//!
13//! let mut tables = HashMap::new();
14//! tables.insert("t".to_string(), Table::new(
15//!     vec!["a".to_string(), "b".to_string()],
16//!     vec![
17//!         vec![Value::Int(1), Value::String("x".to_string())],
18//!         vec![Value::Int(2), Value::String("y".to_string())],
19//!     ],
20//! ));
21//! let result = execute("SELECT a, b FROM t WHERE a > 1", &tables).unwrap();
22//! assert_eq!(result.row_count(), 1);
23//! ```
24
25mod engine;
26mod eval;
27
28use std::collections::HashMap;
29use std::fmt;
30use std::hash::{Hash, Hasher};
31
32use crate::ast::Statement;
33use crate::dialects::Dialect;
34use crate::errors::Result;
35use crate::parser;
36
37// ═══════════════════════════════════════════════════════════════════════
38// Value
39// ═══════════════════════════════════════════════════════════════════════
40
/// A single SQL scalar.
///
/// Instances of this type populate table cells and are also what
/// expression evaluation yields.
#[derive(Debug, Clone)]
pub enum Value {
    /// The SQL `NULL` marker.
    Null,
    /// A `true` / `false` flag.
    Boolean(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// Owned text.
    String(String),
}
56
57impl PartialEq for Value {
58    fn eq(&self, other: &Self) -> bool {
59        match (self, other) {
60            (Value::Null, Value::Null) => true,
61            (Value::Boolean(a), Value::Boolean(b)) => a == b,
62            (Value::Int(a), Value::Int(b)) => a == b,
63            (Value::Float(a), Value::Float(b)) => a == b,
64            (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
65            (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
66            (Value::String(a), Value::String(b)) => a == b,
67            _ => false,
68        }
69    }
70}
71
72impl Eq for Value {}
73
74impl Hash for Value {
75    fn hash<H: Hasher>(&self, state: &mut H) {
76        core::mem::discriminant(self).hash(state);
77        match self {
78            Value::Null => {}
79            Value::Boolean(b) => b.hash(state),
80            Value::Int(i) => i.hash(state),
81            Value::Float(f) => f.to_bits().hash(state),
82            Value::String(s) => s.hash(state),
83        }
84    }
85}
86
87impl fmt::Display for Value {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            Value::Null => write!(f, "NULL"),
91            Value::Boolean(b) => write!(f, "{b}"),
92            Value::Int(i) => write!(f, "{i}"),
93            Value::Float(v) => {
94                if v.fract() == 0.0 && v.abs() < 1e15 {
95                    write!(f, "{v:.1}")
96                } else {
97                    write!(f, "{v}")
98                }
99            }
100            Value::String(s) => write!(f, "{s}"),
101        }
102    }
103}
104
105impl Value {
106    /// Returns `true` if this value is NULL.
107    #[must_use]
108    pub fn is_null(&self) -> bool {
109        matches!(self, Value::Null)
110    }
111
112    /// Returns `true` if this value is truthy (non-NULL, non-zero,
113    /// non-empty).
114    #[must_use]
115    pub fn is_truthy(&self) -> bool {
116        match self {
117            Value::Null => false,
118            Value::Boolean(b) => *b,
119            Value::Int(i) => *i != 0,
120            Value::Float(f) => *f != 0.0,
121            Value::String(s) => !s.is_empty(),
122        }
123    }
124
125    /// Try to convert to `f64`.
126    #[must_use]
127    pub fn to_f64(&self) -> Option<f64> {
128        match self {
129            Value::Int(i) => Some(*i as f64),
130            Value::Float(f) => Some(*f),
131            Value::String(s) => s.parse().ok(),
132            Value::Boolean(b) => Some(if *b { 1.0 } else { 0.0 }),
133            Value::Null => None,
134        }
135    }
136
137    /// Try to convert to `i64`.
138    #[must_use]
139    pub fn to_i64(&self) -> Option<i64> {
140        match self {
141            Value::Int(i) => Some(*i),
142            Value::Float(f) => Some(*f as i64),
143            Value::String(s) => s.parse().ok(),
144            Value::Boolean(b) => Some(i64::from(*b)),
145            Value::Null => None,
146        }
147    }
148
149    /// Convert to a `String` representation (empty string for NULL).
150    #[must_use]
151    pub fn to_string_val(&self) -> String {
152        match self {
153            Value::Null => String::new(),
154            other => other.to_string(),
155        }
156    }
157}
158
159impl PartialOrd for Value {
160    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
161        match (self, other) {
162            (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
163            (Value::Null, _) => Some(std::cmp::Ordering::Less),
164            (_, Value::Null) => Some(std::cmp::Ordering::Greater),
165            (Value::Int(a), Value::Int(b)) => a.partial_cmp(b),
166            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
167            (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
168            (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
169            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
170            (Value::Boolean(a), Value::Boolean(b)) => a.partial_cmp(b),
171            _ => None,
172        }
173    }
174}
175
176// ═══════════════════════════════════════════════════════════════════════
177// Table
178// ═══════════════════════════════════════════════════════════════════════
179
180/// An in-memory table with named columns and rows of values.
181#[derive(Debug, Clone)]
182pub struct Table {
183    /// Column names.
184    pub columns: Vec<String>,
185    /// Row data. Each inner `Vec` has one entry per column.
186    pub rows: Vec<Vec<Value>>,
187}
188
189impl Table {
190    /// Create a new table from owned column names and rows.
191    pub fn new(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
192        Self { columns, rows }
193    }
194
195    /// Create a table from string-slice column names.
196    pub fn from_rows(columns: Vec<&str>, rows: Vec<Vec<Value>>) -> Self {
197        Self {
198            columns: columns.into_iter().map(String::from).collect(),
199            rows,
200        }
201    }
202}
203
204/// A mapping of table names to in-memory tables.
205pub type Tables = HashMap<String, Table>;
206
207// ═══════════════════════════════════════════════════════════════════════
208// ResultSet
209// ═══════════════════════════════════════════════════════════════════════
210
211/// The result of executing a SQL query.
212#[derive(Debug, Clone)]
213pub struct ResultSet {
214    /// Column names in the result.
215    pub columns: Vec<String>,
216    /// Row data.
217    pub rows: Vec<Vec<Value>>,
218}
219
220impl ResultSet {
221    /// Create a new result set.
222    #[must_use]
223    pub fn new(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
224        Self { columns, rows }
225    }
226
227    /// Create an empty result set.
228    #[must_use]
229    pub fn empty() -> Self {
230        Self {
231            columns: vec![],
232            rows: vec![],
233        }
234    }
235
236    /// Number of rows.
237    #[must_use]
238    pub fn row_count(&self) -> usize {
239        self.rows.len()
240    }
241
242    /// Number of columns.
243    #[must_use]
244    pub fn column_count(&self) -> usize {
245        self.columns.len()
246    }
247}
248
249impl fmt::Display for ResultSet {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        if self.columns.is_empty() {
252            return write!(f, "(empty)");
253        }
254        // Compute column widths.
255        let mut widths: Vec<usize> = self.columns.iter().map(|c| c.len()).collect();
256        for row in &self.rows {
257            for (i, val) in row.iter().enumerate() {
258                if i < widths.len() {
259                    widths[i] = widths[i].max(val.to_string().len());
260                }
261            }
262        }
263        // Header.
264        let header: Vec<String> = self
265            .columns
266            .iter()
267            .enumerate()
268            .map(|(i, c)| format!("{:width$}", c, width = widths[i]))
269            .collect();
270        writeln!(f, "{}", header.join(" | "))?;
271        let sep: Vec<String> = widths.iter().map(|w| "-".repeat(*w)).collect();
272        writeln!(f, "{}", sep.join("-+-"))?;
273        // Rows.
274        for row in &self.rows {
275            let cells: Vec<String> = row
276                .iter()
277                .enumerate()
278                .map(|(i, v)| format!("{:width$}", v, width = widths.get(i).copied().unwrap_or(0)))
279                .collect();
280            writeln!(f, "{}", cells.join(" | "))?;
281        }
282        Ok(())
283    }
284}
285
286// ═══════════════════════════════════════════════════════════════════════
287// RowContext (internal)
288// ═══════════════════════════════════════════════════════════════════════
289
290/// Internal row context used for expression evaluation.
291///
292/// Columns are stored as `"table_alias.column_name"` (lowercased) so
293/// that both qualified and unqualified look-ups work.
294#[derive(Debug, Clone)]
295pub(crate) struct RowContext {
296    pub columns: Vec<String>,
297    pub values: Vec<Value>,
298}
299
300impl RowContext {
301    pub fn empty() -> Self {
302        Self {
303            columns: vec![],
304            values: vec![],
305        }
306    }
307
308    pub fn new(columns: Vec<String>, values: Vec<Value>) -> Self {
309        Self { columns, values }
310    }
311
312    /// Look up a value by unqualified column name.
313    pub fn get(&self, name: &str) -> Option<&Value> {
314        let name_lower = name.to_lowercase();
315        // Exact match first.
316        for (i, col) in self.columns.iter().enumerate() {
317            if col.to_lowercase() == name_lower {
318                return Some(&self.values[i]);
319            }
320        }
321        // Strip table qualifier.
322        for (i, col) in self.columns.iter().enumerate() {
323            let col_lower = col.to_lowercase();
324            if let Some((_, suffix)) = col_lower.rsplit_once('.') {
325                if suffix == name_lower {
326                    return Some(&self.values[i]);
327                }
328            }
329        }
330        None
331    }
332
333    /// Look up a value by qualified column name (`table.column`).
334    pub fn get_qualified(&self, table: &str, name: &str) -> Option<&Value> {
335        let qualified = format!("{}.{}", table, name).to_lowercase();
336        for (i, col) in self.columns.iter().enumerate() {
337            if col.to_lowercase() == qualified {
338                return Some(&self.values[i]);
339            }
340        }
341        self.get(name)
342    }
343
344    /// Merge two row contexts (used for JOINs).
345    pub fn merge(&self, other: &RowContext) -> RowContext {
346        let mut columns = self.columns.clone();
347        let mut values = self.values.clone();
348        columns.extend(other.columns.iter().cloned());
349        values.extend(other.values.iter().cloned());
350        RowContext { columns, values }
351    }
352
353    /// Create a NULL-filled context with the given column names.
354    pub fn null_row(columns: &[String]) -> RowContext {
355        RowContext {
356            columns: columns.to_vec(),
357            values: vec![Value::Null; columns.len()],
358        }
359    }
360}
361
362// ═══════════════════════════════════════════════════════════════════════
363// Public API
364// ═══════════════════════════════════════════════════════════════════════
365
366/// Execute a SQL query string against the provided tables.
367///
368/// Parses the query as ANSI SQL, then runs it in-memory against
369/// `tables`.
370pub fn execute(sql: &str, tables: &Tables) -> Result<ResultSet> {
371    let stmt = parser::parse(sql, Dialect::Ansi)?;
372    execute_statement(&stmt, tables)
373}
374
375/// Execute a pre-parsed [`Statement`] against the provided tables.
376pub fn execute_statement(stmt: &Statement, tables: &Tables) -> Result<ResultSet> {
377    let mut ctx = engine::ExecutionContext::new(tables);
378    ctx.execute(stmt)
379}