1use sqlparser::ast::{
2 ColumnDef, ColumnOption, CreateTable, DataType, Expr, ObjectName, ObjectNamePart, Statement,
3 UnaryOperator, Value as AstValue,
4};
5
6use crate::error::{Result, SQLRiteError};
7use crate::sql::db::table::Value;
8
9fn is_vector_type(name: &ObjectName) -> bool {
14 name.0.len() == 1
15 && match &name.0[0] {
16 ObjectNamePart::Identifier(ident) => ident.value.eq_ignore_ascii_case("VECTOR"),
17 _ => false,
21 }
22}
23
/// Extracts the dimension from the type modifiers of a `VECTOR(...)` column.
///
/// Exactly one argument is required. Surrounding whitespace is trimmed, and
/// the argument must parse as a `usize` of at least 1.
///
/// # Errors
/// Returns a human-readable message (without the column name) when there are
/// zero or several arguments, when the argument is not an unsigned integer, or
/// when it is zero.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    // First check the argument count, then validate the single value.
    let raw = match args {
        [only] => only.trim(),
        [] => return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string()),
        _ => {
            return Err(format!(
                "VECTOR takes exactly one dimension argument (got {})",
                args.len()
            ))
        }
    };

    match raw.parse::<usize>() {
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{raw}`)"
        )),
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{raw}`)")),
        Ok(dim) => Ok(dim),
    }
}
47
48#[derive(PartialEq, Debug, Clone)]
51pub struct ParsedColumn {
52 pub name: String,
54 pub datatype: String,
56 pub is_pk: bool,
58 pub not_null: bool,
60 pub is_unique: bool,
62 pub default: Option<Value>,
66}
67
68#[derive(Debug)]
72pub struct CreateQuery {
73 pub table_name: String,
75 pub columns: Vec<ParsedColumn>,
77 pub if_not_exists: bool,
81}
82
83pub fn parse_one_column(col: &ColumnDef) -> Result<ParsedColumn> {
92 let name = col.name.to_string();
93
94 let datatype: String = match &col.data_type {
97 DataType::TinyInt(_)
98 | DataType::SmallInt(_)
99 | DataType::Int2(_)
100 | DataType::Int(_)
101 | DataType::Int4(_)
102 | DataType::Int8(_)
103 | DataType::Integer(_)
104 | DataType::BigInt(_) => "Integer".to_string(),
105 DataType::Boolean => "Bool".to_string(),
106 DataType::Text => "Text".to_string(),
107 DataType::Varchar(_bytes) => "Text".to_string(),
108 DataType::Real => "Real".to_string(),
109 DataType::Float(_precision) => "Real".to_string(),
110 DataType::Double(_) => "Real".to_string(),
111 DataType::Decimal(_) => "Real".to_string(),
112 DataType::JSON | DataType::JSONB => "Json".to_string(),
117 DataType::Custom(name, args) if is_vector_type(name) => match parse_vector_dim(args) {
122 Ok(dim) => format!("vector({dim})"),
123 Err(e) => {
124 return Err(SQLRiteError::General(format!(
125 "Invalid VECTOR column '{}': {e}",
126 col.name
127 )));
128 }
129 },
130 other => {
131 eprintln!("not matched on custom type: {other:?}");
132 "Invalid".to_string()
133 }
134 };
135
136 let mut is_pk: bool = false;
137 let mut is_unique: bool = false;
138 let mut not_null: bool = false;
139 let mut default: Option<Value> = None;
140 for column_option in &col.options {
141 match &column_option.option {
142 ColumnOption::PrimaryKey(_) => {
143 if datatype != "Real" && datatype != "Bool" {
146 is_pk = true;
147 is_unique = true;
148 not_null = true;
149 }
150 }
151 ColumnOption::Unique(_) => {
152 if datatype != "Real" && datatype != "Bool" {
155 is_unique = true;
156 }
157 }
158 ColumnOption::NotNull => {
159 not_null = true;
160 }
161 ColumnOption::Default(expr) => {
162 default = Some(eval_literal_default(expr, &datatype, &name)?);
163 }
164 _ => (),
165 };
166 }
167
168 Ok(ParsedColumn {
169 name,
170 datatype,
171 is_pk,
172 not_null,
173 is_unique,
174 default,
175 })
176}
177
178fn eval_literal_default(expr: &Expr, datatype: &str, col_name: &str) -> Result<Value> {
190 let value = match expr {
191 Expr::Value(v) => &v.value,
192 Expr::UnaryOp {
193 op: UnaryOperator::Minus,
194 expr: inner,
195 } => {
196 return match inner.as_ref() {
197 Expr::Value(v) => match &v.value {
198 AstValue::Number(n, _) => {
199 let neg = format!("-{n}");
200 coerce_number_default(&neg, datatype, col_name)
201 }
202 _ => Err(SQLRiteError::General(format!(
203 "DEFAULT for column '{col_name}' must be a literal value"
204 ))),
205 },
206 _ => Err(SQLRiteError::General(format!(
207 "DEFAULT for column '{col_name}' must be a literal value"
208 ))),
209 };
210 }
211 Expr::UnaryOp {
212 op: UnaryOperator::Plus,
213 expr: inner,
214 } => {
215 return eval_literal_default(inner, datatype, col_name);
216 }
217 _ => {
218 return Err(SQLRiteError::General(format!(
219 "DEFAULT for column '{col_name}' must be a literal value"
220 )));
221 }
222 };
223
224 match value {
225 AstValue::Null => Ok(Value::Null),
226 AstValue::Boolean(b) => {
227 if datatype == "Bool" {
228 Ok(Value::Bool(*b))
229 } else {
230 Err(SQLRiteError::General(format!(
231 "DEFAULT type mismatch for column '{col_name}': boolean is not a {datatype}"
232 )))
233 }
234 }
235 AstValue::SingleQuotedString(s) => {
236 if datatype == "Text" {
237 Ok(Value::Text(s.clone()))
238 } else if datatype == "Json" {
239 serde_json::from_str::<serde_json::Value>(s).map_err(|e| {
245 SQLRiteError::General(format!(
246 "DEFAULT type mismatch for column '{col_name}': '{s}' is not valid JSON: {e}"
247 ))
248 })?;
249 Ok(Value::Text(s.clone()))
250 } else {
251 Err(SQLRiteError::General(format!(
252 "DEFAULT type mismatch for column '{col_name}': text is not a {datatype}"
253 )))
254 }
255 }
256 AstValue::Number(n, _) => coerce_number_default(n, datatype, col_name),
257 _ => Err(SQLRiteError::General(format!(
258 "DEFAULT for column '{col_name}' must be a literal value"
259 ))),
260 }
261}
262
263fn coerce_number_default(n: &str, datatype: &str, col_name: &str) -> Result<Value> {
264 match datatype {
265 "Integer" => n.parse::<i64>().map(Value::Integer).map_err(|_| {
266 SQLRiteError::General(format!(
267 "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid INTEGER"
268 ))
269 }),
270 "Real" => n.parse::<f64>().map(Value::Real).map_err(|_| {
271 SQLRiteError::General(format!(
272 "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid REAL"
273 ))
274 }),
275 other => Err(SQLRiteError::General(format!(
276 "DEFAULT type mismatch for column '{col_name}': numeric literal is not a {other}"
277 ))),
278 }
279}
280
281impl CreateQuery {
282 pub fn new(statement: &Statement) -> Result<CreateQuery> {
283 match statement {
284 Statement::CreateTable(CreateTable {
286 name,
287 columns,
288 constraints,
289 if_not_exists,
290 ..
291 }) => {
292 let table_name = name;
293 let mut parsed_columns: Vec<ParsedColumn> = vec![];
294
295 for col in columns {
298 let name = col.name.to_string();
300 if parsed_columns.iter().any(|c| c.name == name) {
301 return Err(SQLRiteError::Internal(format!(
302 "Duplicate column name: {}",
303 &name
304 )));
305 }
306
307 let parsed = parse_one_column(col)?;
308
309 if parsed.is_pk && parsed_columns.iter().any(|c| c.is_pk) {
311 return Err(SQLRiteError::Internal(format!(
312 "Table '{}' has more than one primary key",
313 &table_name
314 )));
315 }
316
317 parsed_columns.push(parsed);
318 }
319 let _ = constraints;
327 Ok(CreateQuery {
328 table_name: table_name.to_string(),
329 columns: parsed_columns,
330 if_not_exists: *if_not_exists,
331 })
332 }
333
334 _ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::db::table::Value;
    use crate::sql::*;

    /// Parses `sql` with the project dialect, asserts it contains exactly one
    /// statement, and returns that statement.
    fn parse_single(sql: &str) -> Statement {
        let dialect = SqlriteDialect::new();
        let mut ast = Parser::parse_sql(&dialect, sql).expect("test SQL should parse");
        assert_eq!(ast.len(), 1, "ast has more than one Statement");
        ast.pop().expect("length checked above")
    }

    #[test]
    fn create_table_validate_tablename_test() {
        let query = parse_single(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULL,
            email TEXT NOT NULL UNIQUE
        );",
        );
        // Fail loudly instead of passing vacuously if the parser produced a
        // different kind of statement.
        assert!(
            matches!(query, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement"
        );

        let payload = CreateQuery::new(&query)
            .expect("an error occurred during parsing CREATE TABLE Statement");
        assert_eq!(payload.table_name, "contacts");

        let names: Vec<&str> = payload.columns.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["id", "first_name", "last_name", "email"]);

        let email = &payload.columns[3];
        assert!(email.is_unique && email.not_null && !email.is_pk);
    }

    #[test]
    fn create_query_captures_if_not_exists_flag() {
        let q = parse_single("CREATE TABLE t (id INTEGER PRIMARY KEY);");
        assert!(!CreateQuery::new(&q).unwrap().if_not_exists);

        let q = parse_single("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY);");
        assert!(CreateQuery::new(&q).unwrap().if_not_exists);
    }

    #[test]
    fn primary_key_implies_unique_and_not_null() {
        let q = parse_single("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT);");
        let cols = CreateQuery::new(&q).unwrap().columns;
        assert!(cols[0].is_pk && cols[0].is_unique && cols[0].not_null);
        assert!(!cols[1].is_pk && !cols[1].is_unique && !cols[1].not_null);
    }

    #[test]
    fn primary_key_is_ignored_on_real_columns() {
        let q = parse_single("CREATE TABLE t (r REAL PRIMARY KEY);");
        let cols = CreateQuery::new(&q).unwrap().columns;
        assert!(!cols[0].is_pk && !cols[0].is_unique && !cols[0].not_null);
    }

    #[test]
    fn create_query_rejects_duplicate_column_names() {
        let q = parse_single("CREATE TABLE t (a INTEGER, a TEXT);");
        assert!(CreateQuery::new(&q).is_err());
    }

    #[test]
    fn create_query_rejects_multiple_primary_keys() {
        let q = parse_single("CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER PRIMARY KEY);");
        assert!(CreateQuery::new(&q).is_err());
    }

    #[test]
    fn create_query_evaluates_literal_defaults() {
        let q = parse_single(
            "CREATE TABLE t (
                a INTEGER DEFAULT -5,
                b TEXT DEFAULT 'x',
                c REAL DEFAULT 2,
                d BOOLEAN DEFAULT TRUE,
                e INTEGER DEFAULT NULL
            );",
        );
        let defaults: Vec<Option<Value>> = CreateQuery::new(&q)
            .unwrap()
            .columns
            .into_iter()
            .map(|c| c.default)
            .collect();
        assert_eq!(
            defaults,
            vec![
                Some(Value::Integer(-5)),
                Some(Value::Text("x".to_string())),
                Some(Value::Real(2.0)),
                Some(Value::Bool(true)),
                Some(Value::Null),
            ]
        );
    }

    #[test]
    fn create_query_rejects_mismatched_defaults() {
        for sql in [
            "CREATE TABLE t (a INTEGER DEFAULT 1.5);",
            "CREATE TABLE t (a TEXT DEFAULT 5);",
            "CREATE TABLE t (a BOOLEAN DEFAULT 'yes');",
        ]
        .iter()
        {
            let q = parse_single(sql);
            assert!(CreateQuery::new(&q).is_err(), "expected an error for: {sql}");
        }
    }

    #[test]
    fn create_query_parses_vector_columns() {
        let q = parse_single("CREATE TABLE t (v VECTOR(3));");
        assert_eq!(CreateQuery::new(&q).unwrap().columns[0].datatype, "vector(3)");

        let q = parse_single("CREATE TABLE t (v VECTOR(0));");
        assert!(CreateQuery::new(&q).is_err());
    }
}