Skip to main content

sqlrite/sql/parser/
create.rs

1use sqlparser::ast::{ColumnOption, CreateTable, DataType, ObjectName, ObjectNamePart, Statement};
2
3use crate::error::{Result, SQLRiteError};
4
5/// True when an `ObjectName` resolves to a single identifier `VECTOR`
6/// (case-insensitive). Phase 7a adds the `VECTOR(N)` column type as a
7/// sqlparser `DataType::Custom` — the engine recognizes it via this
8/// helper so the regular DataType match arm above stays uncluttered.
9fn is_vector_type(name: &ObjectName) -> bool {
10    name.0.len() == 1
11        && match &name.0[0] {
12            ObjectNamePart::Identifier(ident) => ident.value.eq_ignore_ascii_case("VECTOR"),
13            // Function-form ObjectNamePart shouldn't appear in a CREATE TABLE
14            // column type position. If it ever does, treat it as not-a-vector
15            // and the outer match falls through to the "Invalid" arm.
16            _ => false,
17        }
18}
19
/// Extracts the dimension from the parenthesized arguments of `VECTOR(N)`.
///
/// sqlparser hands `Custom` type arguments back as raw strings, so
/// `VECTOR(384)` arrives here as `["384"]`. Exactly one argument is
/// accepted; it is trimmed and must parse as an integer greater than zero.
///
/// # Errors
///
/// Returns a human-readable message when no argument, more than one
/// argument, a non-integer, or zero is supplied.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    let raw = match args {
        [only] => only.trim(),
        [] => return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string()),
        _ => {
            return Err(format!(
                "VECTOR takes exactly one dimension argument (got {})",
                args.len()
            ));
        }
    };

    match raw.parse::<usize>() {
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{raw}`)")),
        Ok(dim) => Ok(dim),
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{raw}`)"
        )),
    }
}
43
/// Schema metadata for a single column of a `CREATE TABLE` statement,
/// produced by [`CreateQuery::new`] from the sqlparser AST.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedColumn {
    /// Column identifier, rendered through the sqlparser `Ident`'s `Display`.
    pub name: String,
    /// Engine-level type name. `CreateQuery::new` produces one of
    /// `"Integer"`, `"Text"`, `"Real"`, `"Bool"`, `"vector(N)"` (where `N`
    /// is the declared dimension), or `"Invalid"` for SQL types it does not
    /// support.
    pub datatype: String,
    /// `true` when the column was declared `PRIMARY KEY` and its type is
    /// allowed to be a key (see `CreateQuery::new`).
    pub is_pk: bool,
    /// `true` when the column was declared `NOT NULL`, or is the primary key.
    pub not_null: bool,
    /// `true` when the column was declared `UNIQUE` (subject to the same type
    /// restriction as primary keys), or is the primary key.
    pub is_unique: bool,
}
59
60/// The following structure represents a CREATE TABLE query already parsed
61/// and broken down into name and a Vector of `ParsedColumn` metadata
62///
63#[derive(Debug)]
64pub struct CreateQuery {
65    /// name of table after parking and tokenizing of query
66    pub table_name: String,
67    /// Vector of `ParsedColumn` type with column metadata information
68    pub columns: Vec<ParsedColumn>,
69}
70
71impl CreateQuery {
72    pub fn new(statement: &Statement) -> Result<CreateQuery> {
73        match statement {
74            // Confirming the Statement is sqlparser::ast:Statement::CreateTable
75            Statement::CreateTable(CreateTable {
76                name,
77                columns,
78                constraints,
79                ..
80            }) => {
81                let table_name = name;
82                let mut parsed_columns: Vec<ParsedColumn> = vec![];
83
84                // Iterating over the columns returned form the Parser::parse:sql
85                // in the mod sql
86                for col in columns {
87                    let name = col.name.to_string();
88
89                    // Checks if columm already added to parsed_columns, if so, returns an error
90                    if parsed_columns.iter().any(|col| col.name == name) {
91                        return Err(SQLRiteError::Internal(format!(
92                            "Duplicate column name: {}",
93                            &name
94                        )));
95                    }
96
97                    // Parsing each column for it data type
98                    // For now only accepting basic data types
99                    let datatype: String = match &col.data_type {
100                        DataType::TinyInt(_)
101                        | DataType::SmallInt(_)
102                        | DataType::Int2(_)
103                        | DataType::Int(_)
104                        | DataType::Int4(_)
105                        | DataType::Int8(_)
106                        | DataType::Integer(_)
107                        | DataType::BigInt(_) => "Integer".to_string(),
108                        DataType::Boolean => "Bool".to_string(),
109                        DataType::Text => "Text".to_string(),
110                        DataType::Varchar(_bytes) => "Text".to_string(),
111                        DataType::Real => "Real".to_string(),
112                        DataType::Float(_precision) => "Real".to_string(),
113                        DataType::Double(_) => "Real".to_string(),
114                        DataType::Decimal(_) => "Real".to_string(),
115                        // Phase 7a — `VECTOR(N)` parses as Custom("VECTOR", ["N"]).
116                        // sqlparser's SQLite dialect doesn't have a built-in
117                        // Vector variant; Custom is what unrecognized type
118                        // names + their parenthesized args fall through to.
119                        DataType::Custom(name, args) if is_vector_type(name) => {
120                            match parse_vector_dim(args) {
121                                Ok(dim) => format!("vector({dim})"),
122                                Err(e) => {
123                                    return Err(SQLRiteError::General(format!(
124                                        "Invalid VECTOR column '{}': {e}",
125                                        col.name
126                                    )));
127                                }
128                            }
129                        }
130                        other => {
131                            eprintln!("not matched on custom type: {other:?}");
132                            "Invalid".to_string()
133                        }
134                    };
135
136                    // checking if column is PRIMARY KEY
137                    let mut is_pk: bool = false;
138                    // chekcing if column is UNIQUE
139                    let mut is_unique: bool = false;
140                    // chekcing if column is NULLABLE
141                    let mut not_null: bool = false;
142                    for column_option in &col.options {
143                        match &column_option.option {
144                            ColumnOption::PrimaryKey(_) => {
145                                // For now, only Integer and Text types can be PRIMARY KEY and Unique
146                                // Therefore Indexed.
147                                if datatype != "Real" && datatype != "Bool" {
148                                    // Checks if table being created already has a PRIMARY KEY, if so, returns an error
149                                    if parsed_columns.iter().any(|col| col.is_pk) {
150                                        return Err(SQLRiteError::Internal(format!(
151                                            "Table '{}' has more than one primary key",
152                                            &table_name
153                                        )));
154                                    }
155                                    is_pk = true;
156                                    is_unique = true;
157                                    not_null = true;
158                                }
159                            }
160                            ColumnOption::Unique(_) => {
161                                // For now, only Integer and Text types can be UNIQUE
162                                // Therefore Indexed.
163                                if datatype != "Real" && datatype != "Bool" {
164                                    is_unique = true;
165                                }
166                            }
167                            ColumnOption::NotNull => {
168                                not_null = true;
169                            }
170                            _ => (),
171                        };
172                    }
173
174                    parsed_columns.push(ParsedColumn {
175                        name,
176                        datatype: datatype.to_string(),
177                        is_pk,
178                        not_null,
179                        is_unique,
180                    });
181                }
182                // TODO: Handle constraints,
183                // Default value and others.
184                for constraint in constraints {
185                    println!("{constraint:?}");
186                }
187                Ok(CreateQuery {
188                    table_name: table_name.to_string(),
189                    columns: parsed_columns,
190                })
191            }
192
193            _ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::*;

    /// Parses `sql` with the SQLite dialect and returns its only statement,
    /// failing the test if there is not exactly one.
    fn parse_single(sql: &str) -> Statement {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).unwrap();
        assert_eq!(ast.len(), 1, "ast has more than one Statement");
        ast.pop().unwrap()
    }

    #[test]
    fn create_table_validate_tablename_test() {
        let query = parse_single(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULL,
            email TEXT NOT NULL UNIQUE
        );",
        );
        // Fail loudly instead of passing vacuously if the parser produced
        // something other than a CREATE TABLE.
        assert!(
            matches!(query, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement"
        );

        let payload = match CreateQuery::new(&query) {
            Ok(payload) => payload,
            Err(_) => panic!("an error occurred during parsing CREATE TABLE Statement"),
        };
        assert_eq!(payload.table_name, "contacts");

        let names: Vec<&str> = payload.columns.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, ["id", "first_name", "last_name", "email"]);

        let id = &payload.columns[0];
        assert_eq!(id.datatype, "Integer");
        assert!(id.is_pk && id.is_unique && id.not_null);

        let first_name = &payload.columns[1];
        assert_eq!(first_name.datatype, "Text");
        assert!(first_name.not_null && !first_name.is_pk && !first_name.is_unique);

        let email = &payload.columns[3];
        assert_eq!(email.datatype, "Text");
        assert!(email.is_unique && email.not_null && !email.is_pk);
    }

    #[test]
    fn create_table_rejects_duplicate_column_names() {
        let query = parse_single("CREATE TABLE t (a INTEGER, a TEXT);");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn create_table_rejects_multiple_primary_keys() {
        let query = parse_single("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT PRIMARY KEY);");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn create_query_rejects_non_create_statements() {
        let query = parse_single("SELECT 1;");
        assert!(CreateQuery::new(&query).is_err());
    }

    #[test]
    fn parse_vector_dim_validates_arguments() {
        assert_eq!(parse_vector_dim(&["384".to_string()]), Ok(384));
        assert!(parse_vector_dim(&[]).is_err());
        assert!(parse_vector_dim(&["0".to_string()]).is_err());
        assert!(parse_vector_dim(&["abc".to_string()]).is_err());
        assert!(parse_vector_dim(&["1".to_string(), "2".to_string()]).is_err());
    }

    #[test]
    fn is_vector_type_matches_single_identifier_case_insensitively() {
        let single = |s: &str| {
            ObjectName(vec![ObjectNamePart::Identifier(sqlparser::ast::Ident::new(s))])
        };
        assert!(is_vector_type(&single("vector")));
        assert!(is_vector_type(&single("VECTOR")));
        assert!(!is_vector_type(&single("vec")));

        let qualified = ObjectName(vec![
            ObjectNamePart::Identifier(sqlparser::ast::Ident::new("main")),
            ObjectNamePart::Identifier(sqlparser::ast::Ident::new("vector")),
        ]);
        assert!(!is_vector_type(&qualified));
    }
}