Skip to main content

sqlrite/sql/parser/
create.rs

1use sqlparser::ast::{
2    ColumnDef, ColumnOption, CreateTable, DataType, Expr, ObjectName, ObjectNamePart, Statement,
3    UnaryOperator, Value as AstValue,
4};
5
6use crate::error::{Result, SQLRiteError};
7use crate::sql::db::table::Value;
8
9/// True when an `ObjectName` resolves to a single identifier `VECTOR`
10/// (case-insensitive). Phase 7a adds the `VECTOR(N)` column type as a
11/// sqlparser `DataType::Custom` — the engine recognizes it via this
12/// helper so the regular DataType match arm above stays uncluttered.
13fn is_vector_type(name: &ObjectName) -> bool {
14    name.0.len() == 1
15        && match &name.0[0] {
16            ObjectNamePart::Identifier(ident) => ident.value.eq_ignore_ascii_case("VECTOR"),
17            // Function-form ObjectNamePart shouldn't appear in a CREATE TABLE
18            // column type position. If it ever does, treat it as not-a-vector
19            // and the outer match falls through to the "Invalid" arm.
20            _ => false,
21        }
22}
23
/// Extracts the dimension `N` from the type arguments of `VECTOR(N)`.
///
/// `args` holds the parenthesized type arguments as strings — for
/// `VECTOR(384)` that is `["384"]`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - no argument was supplied,
/// - more than one argument was supplied,
/// - the single argument (after trimming whitespace) does not parse as a
///   `usize`, or
/// - the parsed dimension is zero.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    // Arity checks first: exactly one argument is accepted.
    if args.is_empty() {
        return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string());
    }
    if args.len() > 1 {
        return Err(format!(
            "VECTOR takes exactly one dimension argument (got {})",
            args.len()
        ));
    }

    // Surrounding whitespace is tolerated; the trimmed text is what gets
    // echoed back in error messages.
    let raw = args[0].trim();
    match raw.parse::<usize>() {
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{raw}`)")),
        Ok(dim) => Ok(dim),
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{raw}`)"
        )),
    }
}
47
48/// The schema for each SQL column in every table is represented by
49/// the following structure after parsed and tokenized
50#[derive(PartialEq, Debug, Clone)]
51pub struct ParsedColumn {
52    /// Name of the column
53    pub name: String,
54    /// Datatype of the column in String format
55    pub datatype: String,
56    /// Value representing if column is PRIMARY KEY
57    pub is_pk: bool,
58    /// Value representing if column was declared with the NOT NULL Constraint
59    pub not_null: bool,
60    /// Value representing if column was declared with the UNIQUE Constraint
61    pub is_unique: bool,
62    /// Literal value to use when this column is omitted from an INSERT.
63    /// Restricted to literal expressions (integer, real, text, bool, NULL);
64    /// non-literal `DEFAULT` expressions are rejected at CREATE TABLE time.
65    pub default: Option<Value>,
66}
67
68/// The following structure represents a CREATE TABLE query already parsed
69/// and broken down into name and a Vector of `ParsedColumn` metadata
70///
71#[derive(Debug)]
72pub struct CreateQuery {
73    /// name of table after parking and tokenizing of query
74    pub table_name: String,
75    /// Vector of `ParsedColumn` type with column metadata information
76    pub columns: Vec<ParsedColumn>,
77}
78
79/// Parses a single sqlparser `ColumnDef` into our internal `ParsedColumn`
80/// representation. Extracted from `CreateQuery::new` so `ALTER TABLE ADD
81/// COLUMN` can reuse the same column-shape parsing without re-implementing
82/// the type / constraint / default plumbing.
83///
84/// Caller-side responsibilities not handled here:
85/// - duplicate column name detection (a multi-column invariant)
86/// - "more than one PRIMARY KEY" detection (a multi-column invariant)
87pub fn parse_one_column(col: &ColumnDef) -> Result<ParsedColumn> {
88    let name = col.name.to_string();
89
90    // Parsing each column for it data type
91    // For now only accepting basic data types
92    let datatype: String = match &col.data_type {
93        DataType::TinyInt(_)
94        | DataType::SmallInt(_)
95        | DataType::Int2(_)
96        | DataType::Int(_)
97        | DataType::Int4(_)
98        | DataType::Int8(_)
99        | DataType::Integer(_)
100        | DataType::BigInt(_) => "Integer".to_string(),
101        DataType::Boolean => "Bool".to_string(),
102        DataType::Text => "Text".to_string(),
103        DataType::Varchar(_bytes) => "Text".to_string(),
104        DataType::Real => "Real".to_string(),
105        DataType::Float(_precision) => "Real".to_string(),
106        DataType::Double(_) => "Real".to_string(),
107        DataType::Decimal(_) => "Real".to_string(),
108        // Phase 7e — `JSON` parses as a unit variant in
109        // sqlparser's DataType enum. JSONB is treated as
110        // an alias (matches PostgreSQL's permissive
111        // behaviour); both store as text under the hood.
112        DataType::JSON | DataType::JSONB => "Json".to_string(),
113        // Phase 7a — `VECTOR(N)` parses as Custom("VECTOR", ["N"]).
114        // sqlparser's SQLite dialect doesn't have a built-in
115        // Vector variant; Custom is what unrecognized type
116        // names + their parenthesized args fall through to.
117        DataType::Custom(name, args) if is_vector_type(name) => match parse_vector_dim(args) {
118            Ok(dim) => format!("vector({dim})"),
119            Err(e) => {
120                return Err(SQLRiteError::General(format!(
121                    "Invalid VECTOR column '{}': {e}",
122                    col.name
123                )));
124            }
125        },
126        other => {
127            eprintln!("not matched on custom type: {other:?}");
128            "Invalid".to_string()
129        }
130    };
131
132    let mut is_pk: bool = false;
133    let mut is_unique: bool = false;
134    let mut not_null: bool = false;
135    let mut default: Option<Value> = None;
136    for column_option in &col.options {
137        match &column_option.option {
138            ColumnOption::PrimaryKey(_) => {
139                // For now, only Integer and Text types can be PRIMARY KEY and Unique
140                // Therefore Indexed.
141                if datatype != "Real" && datatype != "Bool" {
142                    is_pk = true;
143                    is_unique = true;
144                    not_null = true;
145                }
146            }
147            ColumnOption::Unique(_) => {
148                // For now, only Integer and Text types can be UNIQUE
149                // Therefore Indexed.
150                if datatype != "Real" && datatype != "Bool" {
151                    is_unique = true;
152                }
153            }
154            ColumnOption::NotNull => {
155                not_null = true;
156            }
157            ColumnOption::Default(expr) => {
158                default = Some(eval_literal_default(expr, &datatype, &name)?);
159            }
160            _ => (),
161        };
162    }
163
164    Ok(ParsedColumn {
165        name,
166        datatype,
167        is_pk,
168        not_null,
169        is_unique,
170        default,
171    })
172}
173
174/// Evaluates a `DEFAULT <expr>` clause to a runtime `Value`. Restricted to
175/// literal expressions — anything else (function calls, column references,
176/// arithmetic on non-literals, `CURRENT_TIMESTAMP`, …) is rejected with a
177/// typed error so users see the limit at `CREATE TABLE` time rather than
178/// silently accepting a `DEFAULT` we can't honour at INSERT time.
179///
180/// Negative numeric literals come through sqlparser as `UnaryOp { Minus, Value(N) }`;
181/// we unwrap one level of leading `+`/`-` to support `DEFAULT -1` / `DEFAULT +3.14`.
182///
183/// Type-checks the literal against the column's declared datatype and
184/// rejects mismatches (e.g. `INTEGER ... DEFAULT 'foo'`).
185fn eval_literal_default(expr: &Expr, datatype: &str, col_name: &str) -> Result<Value> {
186    let value = match expr {
187        Expr::Value(v) => &v.value,
188        Expr::UnaryOp {
189            op: UnaryOperator::Minus,
190            expr: inner,
191        } => {
192            return match inner.as_ref() {
193                Expr::Value(v) => match &v.value {
194                    AstValue::Number(n, _) => {
195                        let neg = format!("-{n}");
196                        coerce_number_default(&neg, datatype, col_name)
197                    }
198                    _ => Err(SQLRiteError::General(format!(
199                        "DEFAULT for column '{col_name}' must be a literal value"
200                    ))),
201                },
202                _ => Err(SQLRiteError::General(format!(
203                    "DEFAULT for column '{col_name}' must be a literal value"
204                ))),
205            };
206        }
207        Expr::UnaryOp {
208            op: UnaryOperator::Plus,
209            expr: inner,
210        } => {
211            return eval_literal_default(inner, datatype, col_name);
212        }
213        _ => {
214            return Err(SQLRiteError::General(format!(
215                "DEFAULT for column '{col_name}' must be a literal value"
216            )));
217        }
218    };
219
220    match value {
221        AstValue::Null => Ok(Value::Null),
222        AstValue::Boolean(b) => {
223            if datatype == "Bool" {
224                Ok(Value::Bool(*b))
225            } else {
226                Err(SQLRiteError::General(format!(
227                    "DEFAULT type mismatch for column '{col_name}': boolean is not a {datatype}"
228                )))
229            }
230        }
231        AstValue::SingleQuotedString(s) => {
232            if datatype == "Text" {
233                Ok(Value::Text(s.clone()))
234            } else if datatype == "Json" {
235                // JSON columns accept text literals only if they parse as
236                // JSON — otherwise an ALTER TABLE ADD COLUMN ... JSON
237                // DEFAULT '<garbage>' would silently backfill every row
238                // with invalid JSON (insert_row's per-row JSON validation
239                // is bypassed during the backfill path).
240                serde_json::from_str::<serde_json::Value>(s).map_err(|e| {
241                    SQLRiteError::General(format!(
242                        "DEFAULT type mismatch for column '{col_name}': '{s}' is not valid JSON: {e}"
243                    ))
244                })?;
245                Ok(Value::Text(s.clone()))
246            } else {
247                Err(SQLRiteError::General(format!(
248                    "DEFAULT type mismatch for column '{col_name}': text is not a {datatype}"
249                )))
250            }
251        }
252        AstValue::Number(n, _) => coerce_number_default(n, datatype, col_name),
253        _ => Err(SQLRiteError::General(format!(
254            "DEFAULT for column '{col_name}' must be a literal value"
255        ))),
256    }
257}
258
259fn coerce_number_default(n: &str, datatype: &str, col_name: &str) -> Result<Value> {
260    match datatype {
261        "Integer" => n.parse::<i64>().map(Value::Integer).map_err(|_| {
262            SQLRiteError::General(format!(
263                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid INTEGER"
264            ))
265        }),
266        "Real" => n.parse::<f64>().map(Value::Real).map_err(|_| {
267            SQLRiteError::General(format!(
268                "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid REAL"
269            ))
270        }),
271        other => Err(SQLRiteError::General(format!(
272            "DEFAULT type mismatch for column '{col_name}': numeric literal is not a {other}"
273        ))),
274    }
275}
276
277impl CreateQuery {
278    pub fn new(statement: &Statement) -> Result<CreateQuery> {
279        match statement {
280            // Confirming the Statement is sqlparser::ast:Statement::CreateTable
281            Statement::CreateTable(CreateTable {
282                name,
283                columns,
284                constraints,
285                ..
286            }) => {
287                let table_name = name;
288                let mut parsed_columns: Vec<ParsedColumn> = vec![];
289
290                // Iterating over the columns returned form the Parser::parse:sql
291                // in the mod sql
292                for col in columns {
293                    // Checks if columm already added to parsed_columns, if so, returns an error
294                    let name = col.name.to_string();
295                    if parsed_columns.iter().any(|c| c.name == name) {
296                        return Err(SQLRiteError::Internal(format!(
297                            "Duplicate column name: {}",
298                            &name
299                        )));
300                    }
301
302                    let parsed = parse_one_column(col)?;
303
304                    // Multi-column invariant: only one PRIMARY KEY per table.
305                    if parsed.is_pk && parsed_columns.iter().any(|c| c.is_pk) {
306                        return Err(SQLRiteError::Internal(format!(
307                            "Table '{}' has more than one primary key",
308                            &table_name
309                        )));
310                    }
311
312                    parsed_columns.push(parsed);
313                }
314                // TODO: handle constraints + check constraints + ON DELETE /
315                // ON UPDATE referential actions properly. They're currently
316                // parsed by `sqlparser` and dropped on the floor here.
317                // (Previously we `println!`-ed them to stdout as a debug
318                // aid — removed in the engine-stdout-pollution cleanup;
319                // flip to a `tracing` span if we ever want them visible in
320                // dev builds.)
321                let _ = constraints;
322                Ok(CreateQuery {
323                    table_name: table_name.to_string(),
324                    columns: parsed_columns,
325                })
326            }
327
328            _ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
329        }
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::*;

    /// A well-formed `CREATE TABLE` must parse to exactly one statement,
    /// that statement must be a `CREATE TABLE`, and `CreateQuery::new`
    /// must report the declared table name.
    #[test]
    fn create_table_validate_tablename_test() {
        let sql_input = String::from(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULl,
            email TEXT NOT NULL UNIQUE
        );",
        );
        let expected_table_name = String::from("contacts");

        let dialect = SqlriteDialect::new();
        let mut ast = Parser::parse_sql(&dialect, &sql_input).unwrap();

        assert!(ast.len() == 1, "ast has more then one Statement");

        let query = ast.pop().unwrap();

        // Fail loudly if the parser produced some other statement kind.
        // Previously the assertions were wrapped in an `if let`, so the
        // test passed without checking anything in that case.
        assert!(
            matches!(query, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement, got: {query:?}"
        );

        match CreateQuery::new(&query) {
            Ok(payload) => assert_eq!(payload.table_name, expected_table_name),
            Err(_) => panic!("an error occured during parsing CREATE TABLE Statement"),
        }
    }
}
368}