1use sqlparser::ast::{
2 ColumnDef, ColumnOption, CreateTable, DataType, Expr, ObjectName, ObjectNamePart, Statement,
3 UnaryOperator, Value as AstValue,
4};
5
6use crate::error::{Result, SQLRiteError};
7use crate::sql::db::table::Value;
8
9fn is_vector_type(name: &ObjectName) -> bool {
14 name.0.len() == 1
15 && match &name.0[0] {
16 ObjectNamePart::Identifier(ident) => ident.value.eq_ignore_ascii_case("VECTOR"),
17 _ => false,
21 }
22}
23
/// Extracts the dimension from the argument list of a `VECTOR(<dim>)` type.
///
/// `args` holds the raw modifier strings captured between the parentheses.
/// Exactly one argument is allowed. After trimming surrounding whitespace it
/// must parse as a `usize` that is at least 1.
///
/// # Errors
/// Returns a human-readable message when there are no arguments, more than one
/// argument, or the single argument is zero or not a non-negative integer.
fn parse_vector_dim(args: &[String]) -> std::result::Result<usize, String> {
    if args.is_empty() {
        return Err("VECTOR requires a dimension, e.g. `VECTOR(384)`".to_string());
    }
    if args.len() > 1 {
        return Err(format!(
            "VECTOR takes exactly one dimension argument (got {})",
            args.len()
        ));
    }

    // Exactly one argument remains after the two guards above.
    let trimmed = args[0].trim();
    match trimmed.parse::<usize>() {
        Err(_) => Err(format!(
            "VECTOR dimension must be a positive integer (got `{trimmed}`)"
        )),
        Ok(0) => Err(format!("VECTOR dimension must be ≥ 1 (got `{trimmed}`)")),
        Ok(dim) => Ok(dim),
    }
}
47
48#[derive(PartialEq, Debug, Clone)]
51pub struct ParsedColumn {
52 pub name: String,
54 pub datatype: String,
56 pub is_pk: bool,
58 pub not_null: bool,
60 pub is_unique: bool,
62 pub default: Option<Value>,
66}
67
68#[derive(Debug)]
72pub struct CreateQuery {
73 pub table_name: String,
75 pub columns: Vec<ParsedColumn>,
77}
78
79pub fn parse_one_column(col: &ColumnDef) -> Result<ParsedColumn> {
88 let name = col.name.to_string();
89
90 let datatype: String = match &col.data_type {
93 DataType::TinyInt(_)
94 | DataType::SmallInt(_)
95 | DataType::Int2(_)
96 | DataType::Int(_)
97 | DataType::Int4(_)
98 | DataType::Int8(_)
99 | DataType::Integer(_)
100 | DataType::BigInt(_) => "Integer".to_string(),
101 DataType::Boolean => "Bool".to_string(),
102 DataType::Text => "Text".to_string(),
103 DataType::Varchar(_bytes) => "Text".to_string(),
104 DataType::Real => "Real".to_string(),
105 DataType::Float(_precision) => "Real".to_string(),
106 DataType::Double(_) => "Real".to_string(),
107 DataType::Decimal(_) => "Real".to_string(),
108 DataType::JSON | DataType::JSONB => "Json".to_string(),
113 DataType::Custom(name, args) if is_vector_type(name) => match parse_vector_dim(args) {
118 Ok(dim) => format!("vector({dim})"),
119 Err(e) => {
120 return Err(SQLRiteError::General(format!(
121 "Invalid VECTOR column '{}': {e}",
122 col.name
123 )));
124 }
125 },
126 other => {
127 eprintln!("not matched on custom type: {other:?}");
128 "Invalid".to_string()
129 }
130 };
131
132 let mut is_pk: bool = false;
133 let mut is_unique: bool = false;
134 let mut not_null: bool = false;
135 let mut default: Option<Value> = None;
136 for column_option in &col.options {
137 match &column_option.option {
138 ColumnOption::PrimaryKey(_) => {
139 if datatype != "Real" && datatype != "Bool" {
142 is_pk = true;
143 is_unique = true;
144 not_null = true;
145 }
146 }
147 ColumnOption::Unique(_) => {
148 if datatype != "Real" && datatype != "Bool" {
151 is_unique = true;
152 }
153 }
154 ColumnOption::NotNull => {
155 not_null = true;
156 }
157 ColumnOption::Default(expr) => {
158 default = Some(eval_literal_default(expr, &datatype, &name)?);
159 }
160 _ => (),
161 };
162 }
163
164 Ok(ParsedColumn {
165 name,
166 datatype,
167 is_pk,
168 not_null,
169 is_unique,
170 default,
171 })
172}
173
174fn eval_literal_default(expr: &Expr, datatype: &str, col_name: &str) -> Result<Value> {
186 let value = match expr {
187 Expr::Value(v) => &v.value,
188 Expr::UnaryOp {
189 op: UnaryOperator::Minus,
190 expr: inner,
191 } => {
192 return match inner.as_ref() {
193 Expr::Value(v) => match &v.value {
194 AstValue::Number(n, _) => {
195 let neg = format!("-{n}");
196 coerce_number_default(&neg, datatype, col_name)
197 }
198 _ => Err(SQLRiteError::General(format!(
199 "DEFAULT for column '{col_name}' must be a literal value"
200 ))),
201 },
202 _ => Err(SQLRiteError::General(format!(
203 "DEFAULT for column '{col_name}' must be a literal value"
204 ))),
205 };
206 }
207 Expr::UnaryOp {
208 op: UnaryOperator::Plus,
209 expr: inner,
210 } => {
211 return eval_literal_default(inner, datatype, col_name);
212 }
213 _ => {
214 return Err(SQLRiteError::General(format!(
215 "DEFAULT for column '{col_name}' must be a literal value"
216 )));
217 }
218 };
219
220 match value {
221 AstValue::Null => Ok(Value::Null),
222 AstValue::Boolean(b) => {
223 if datatype == "Bool" {
224 Ok(Value::Bool(*b))
225 } else {
226 Err(SQLRiteError::General(format!(
227 "DEFAULT type mismatch for column '{col_name}': boolean is not a {datatype}"
228 )))
229 }
230 }
231 AstValue::SingleQuotedString(s) => {
232 if datatype == "Text" {
233 Ok(Value::Text(s.clone()))
234 } else if datatype == "Json" {
235 serde_json::from_str::<serde_json::Value>(s).map_err(|e| {
241 SQLRiteError::General(format!(
242 "DEFAULT type mismatch for column '{col_name}': '{s}' is not valid JSON: {e}"
243 ))
244 })?;
245 Ok(Value::Text(s.clone()))
246 } else {
247 Err(SQLRiteError::General(format!(
248 "DEFAULT type mismatch for column '{col_name}': text is not a {datatype}"
249 )))
250 }
251 }
252 AstValue::Number(n, _) => coerce_number_default(n, datatype, col_name),
253 _ => Err(SQLRiteError::General(format!(
254 "DEFAULT for column '{col_name}' must be a literal value"
255 ))),
256 }
257}
258
259fn coerce_number_default(n: &str, datatype: &str, col_name: &str) -> Result<Value> {
260 match datatype {
261 "Integer" => n.parse::<i64>().map(Value::Integer).map_err(|_| {
262 SQLRiteError::General(format!(
263 "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid INTEGER"
264 ))
265 }),
266 "Real" => n.parse::<f64>().map(Value::Real).map_err(|_| {
267 SQLRiteError::General(format!(
268 "DEFAULT type mismatch for column '{col_name}': '{n}' is not a valid REAL"
269 ))
270 }),
271 other => Err(SQLRiteError::General(format!(
272 "DEFAULT type mismatch for column '{col_name}': numeric literal is not a {other}"
273 ))),
274 }
275}
276
277impl CreateQuery {
278 pub fn new(statement: &Statement) -> Result<CreateQuery> {
279 match statement {
280 Statement::CreateTable(CreateTable {
282 name,
283 columns,
284 constraints,
285 ..
286 }) => {
287 let table_name = name;
288 let mut parsed_columns: Vec<ParsedColumn> = vec![];
289
290 for col in columns {
293 let name = col.name.to_string();
295 if parsed_columns.iter().any(|c| c.name == name) {
296 return Err(SQLRiteError::Internal(format!(
297 "Duplicate column name: {}",
298 &name
299 )));
300 }
301
302 let parsed = parse_one_column(col)?;
303
304 if parsed.is_pk && parsed_columns.iter().any(|c| c.is_pk) {
306 return Err(SQLRiteError::Internal(format!(
307 "Table '{}' has more than one primary key",
308 &table_name
309 )));
310 }
311
312 parsed_columns.push(parsed);
313 }
314 let _ = constraints;
322 Ok(CreateQuery {
323 table_name: table_name.to_string(),
324 columns: parsed_columns,
325 })
326 }
327
328 _ => Err(SQLRiteError::Internal("Error parsing query".to_string())),
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::db::table::Value;
    use crate::sql::*;

    /// Parses `sql`, which must hold exactly one `CREATE TABLE` statement, and
    /// feeds it to [`CreateQuery::new`].
    fn parse_create(sql: &str) -> crate::error::Result<CreateQuery> {
        let dialect = SqlriteDialect::new();
        let mut ast = Parser::parse_sql(&dialect, sql).expect("test SQL should parse");
        assert_eq!(ast.len(), 1, "expected exactly one statement");
        let statement = ast.pop().expect("length checked above");
        // Fail loudly, rather than silently passing, if the parser ever yields a
        // different statement kind.
        assert!(
            matches!(&statement, Statement::CreateTable(_)),
            "expected a CREATE TABLE statement, got {statement:?}"
        );
        CreateQuery::new(&statement)
    }

    #[test]
    fn create_table_validate_tablename_test() {
        let query = parse_create(
            "CREATE TABLE contacts (
                id INTEGER PRIMARY KEY,
                first_name TEXT NOT NULL,
                last_name TEXT NOT NULl,
                email TEXT NOT NULL UNIQUE
            );",
        )
        .expect("CREATE TABLE should be accepted");
        assert_eq!(query.table_name, "contacts");
        assert_eq!(query.columns.len(), 4);
    }

    #[test]
    fn column_constraints_are_recorded() {
        let query = parse_create(
            "CREATE TABLE t (id INTEGER PRIMARY KEY, email TEXT NOT NULL UNIQUE, note TEXT);",
        )
        .expect("CREATE TABLE should be accepted");

        let id = &query.columns[0];
        assert_eq!(id.datatype, "Integer");
        assert!(id.is_pk && id.is_unique && id.not_null);

        let email = &query.columns[1];
        assert_eq!(email.datatype, "Text");
        assert!(!email.is_pk && email.is_unique && email.not_null);

        let note = &query.columns[2];
        assert!(!note.is_pk && !note.is_unique && !note.not_null);
        assert_eq!(note.default, None);
    }

    #[test]
    fn duplicate_column_is_rejected() {
        assert!(parse_create("CREATE TABLE t (a INTEGER, a TEXT);").is_err());
    }

    #[test]
    fn second_primary_key_is_rejected() {
        assert!(
            parse_create("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT PRIMARY KEY);").is_err()
        );
    }

    #[test]
    fn literal_defaults_are_typed() {
        let query = parse_create(
            "CREATE TABLE t (a INTEGER DEFAULT -5, b TEXT DEFAULT 'x', c REAL DEFAULT 1);",
        )
        .expect("CREATE TABLE should be accepted");
        let defaults: Vec<Option<Value>> =
            query.columns.iter().map(|c| c.default.clone()).collect();
        assert_eq!(
            defaults,
            vec![
                Some(Value::Integer(-5)),
                Some(Value::Text("x".to_string())),
                Some(Value::Real(1.0)),
            ]
        );
    }

    #[test]
    fn mismatched_default_is_rejected() {
        assert!(parse_create("CREATE TABLE t (a INTEGER DEFAULT 'x');").is_err());
    }

    #[test]
    fn vector_dimension_is_validated() {
        assert_eq!(parse_vector_dim(&["384".to_string()]), Ok(384));
        assert!(parse_vector_dim(&[]).is_err());
        assert!(parse_vector_dim(&["0".to_string()]).is_err());
        assert!(parse_vector_dim(&["1".to_string(), "2".to_string()]).is_err());
    }
}