1#![allow(clippy::new_without_default)]
7
8use vibesql_ast::{
9 pretty_print::ToSql, ColumnConstraintKind, ColumnDef, Expression, TableConstraint,
10 TableConstraintKind,
11};
12use vibesql_catalog::{ColumnSchema, TableSchema};
13use vibesql_types::DataType;
14
15use crate::errors::ExecutorError;
16
17pub struct ConstraintResult {
19 pub primary_key: Option<Vec<String>>,
21 pub unique_constraints: Vec<Vec<String>>,
23 pub check_constraints: Vec<(String, Expression)>,
25 pub not_null_columns: Vec<String>,
27}
28
29impl ConstraintResult {
30 pub fn new() -> Self {
32 Self {
33 primary_key: None,
34 unique_constraints: Vec::new(),
35 check_constraints: Vec::new(),
36 not_null_columns: Vec::new(),
37 }
38 }
39}
40
/// Stateless namespace for turning parsed column/table constraints into a
/// `ConstraintResult` and applying that result to catalog schemas.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstraintValidator;
43
44impl ConstraintValidator {
45 pub fn process_constraints(
60 columns: &[ColumnDef],
61 table_constraints: &[TableConstraint],
62 ) -> Result<ConstraintResult, ExecutorError> {
63 let mut result = ConstraintResult::new();
64
65 let mut has_column_level_pk = false;
67
68 for col_def in columns {
70 for constraint in &col_def.constraints {
71 match &constraint.kind {
72 ColumnConstraintKind::PrimaryKey { .. } => {
73 if has_column_level_pk {
74 return Err(ExecutorError::MultiplePrimaryKeys);
75 }
76 if result.primary_key.is_some() {
77 return Err(ExecutorError::MultiplePrimaryKeys);
78 }
79 result.primary_key = Some(vec![col_def.name.clone()]);
80 if col_def.data_type == DataType::Integer {
83 result.not_null_columns.push(col_def.name.clone());
84 }
85 has_column_level_pk = true;
86 }
87 ColumnConstraintKind::Unique { .. } => {
88 result.unique_constraints.push(vec![col_def.name.clone()]);
89 }
90 ColumnConstraintKind::Check(expr) => {
91 let constraint_name = constraint
94 .name
95 .clone()
96 .unwrap_or_else(|| expr.to_sql());
97 result.check_constraints.push((constraint_name, (**expr).clone()));
98 }
99 ColumnConstraintKind::NotNull
100 | ColumnConstraintKind::NotNullWithConflict { .. } => {
101 result.not_null_columns.push(col_def.name.clone());
102 }
103 ColumnConstraintKind::References { .. } => {
104 }
107 ColumnConstraintKind::AutoIncrement => {
108 }
112 ColumnConstraintKind::Key => {
113 }
117 ColumnConstraintKind::Collate(_) => {
118 }
122 }
123 }
124 }
125
126 for table_constraint in table_constraints {
128 match &table_constraint.kind {
129 TableConstraintKind::PrimaryKey { columns: pk_cols, .. } => {
130 if result.primary_key.is_some() {
132 return Err(ExecutorError::MultiplePrimaryKeys);
133 }
134 let column_names: Vec<String> =
136 pk_cols.iter().map(|c| c.expect_column_name().to_string()).collect();
137 result.primary_key = Some(column_names.clone());
138 for col_name in &column_names {
141 if let Some(col_def) = columns.iter().find(|c| &c.name == col_name) {
142 if col_def.data_type == DataType::Integer
143 && !result.not_null_columns.contains(col_name)
144 {
145 result.not_null_columns.push(col_name.to_string());
146 }
147 }
148 }
149 }
150 TableConstraintKind::Unique { columns, .. } => {
151 let column_names: Vec<String> =
153 columns.iter().map(|c| c.expect_column_name().to_string()).collect();
154 result.unique_constraints.push(column_names);
155 }
156 TableConstraintKind::Check { expr } => {
157 let constraint_name = table_constraint
160 .name
161 .clone()
162 .unwrap_or_else(|| expr.to_sql());
163 result.check_constraints.push((constraint_name, (**expr).clone()));
164 }
165 TableConstraintKind::ForeignKey { .. } => {
166 }
169 TableConstraintKind::Fulltext { .. } => {
170 }
174 }
175 }
176
177 Ok(result)
178 }
179
180 pub fn apply_to_columns(columns: &mut [ColumnSchema], constraint_result: &ConstraintResult) {
189 for col_name in &constraint_result.not_null_columns {
191 if let Some(col) = columns.iter_mut().find(|c| c.name == *col_name) {
192 col.nullable = false;
193 }
194 }
195 }
196
197 pub fn apply_to_schema(table_schema: &mut TableSchema, constraint_result: &ConstraintResult) {
206 if let Some(pk) = &constraint_result.primary_key {
208 table_schema.primary_key = Some(pk.clone());
209 }
210
211 table_schema.unique_constraints = constraint_result.unique_constraints.clone();
213
214 table_schema.check_constraints = constraint_result.check_constraints.clone();
216 }
217}
218
#[cfg(test)]
mod tests {
    use vibesql_ast::ColumnConstraint;
    use vibesql_types::DataType;

    use super::*;

    /// Builds an INTEGER column carrying the given unnamed constraints.
    fn make_column_def(name: &str, constraint_kinds: Vec<ColumnConstraintKind>) -> ColumnDef {
        make_column_def_with_type(name, DataType::Integer, constraint_kinds)
    }

    /// Builds a nullable column of `data_type` carrying the given unnamed constraints.
    fn make_column_def_with_type(
        name: &str,
        data_type: DataType,
        constraint_kinds: Vec<ColumnConstraintKind>,
    ) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            data_type,
            nullable: true,
            constraints: constraint_kinds
                .into_iter()
                .map(|kind| ColumnConstraint { name: None, kind })
                .collect(),
            default_value: None,
            comment: None,
            generated_expr: None,
            is_exact_integer_type: false,
        }
    }

    /// Builds an ascending, non-prefixed index column reference.
    fn index_column(name: &str) -> vibesql_ast::IndexColumn {
        vibesql_ast::IndexColumn::Column {
            column_name: name.to_string(),
            direction: vibesql_ast::OrderDirection::Asc,
            prefix_length: None,
        }
    }

    /// Builds an unnamed table-level `PRIMARY KEY (names...)` constraint.
    fn table_primary_key(names: &[&str]) -> TableConstraint {
        TableConstraint {
            name: None,
            kind: TableConstraintKind::PrimaryKey {
                columns: names.iter().copied().map(index_column).collect(),
                on_conflict: None,
            },
        }
    }

    /// Builds the expression `age > 0`.
    fn age_positive() -> Expression {
        use vibesql_types::SqlValue;

        Expression::BinaryOp {
            left: Box::new(Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
                "age", false,
            ))),
            op: vibesql_ast::BinaryOperator::GreaterThan,
            right: Box::new(Expression::Literal(SqlValue::Integer(0))),
        }
    }

    #[test]
    fn test_column_level_primary_key() {
        let columns = vec![make_column_def(
            "id",
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["id".to_string()]));
        assert!(result.not_null_columns.contains(&"id".to_string()));
    }

    #[test]
    fn test_table_level_primary_key() {
        let columns = vec![make_column_def("id", vec![]), make_column_def("tenant_id", vec![])];
        let constraints = vec![table_primary_key(&["id", "tenant_id"])];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        assert_eq!(result.primary_key, Some(vec!["id".to_string(), "tenant_id".to_string()]));
        assert!(result.not_null_columns.contains(&"id".to_string()));
        assert!(result.not_null_columns.contains(&"tenant_id".to_string()));
    }

    #[test]
    fn test_multiple_primary_keys_fails() {
        let columns = vec![make_column_def(
            "id",
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let constraints = vec![table_primary_key(&["id"])];

        let result = ConstraintValidator::process_constraints(&columns, &constraints);
        assert!(matches!(result, Err(ExecutorError::MultiplePrimaryKeys)));
    }

    #[test]
    fn test_two_column_level_primary_keys_fail() {
        let columns = vec![
            make_column_def("a", vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }]),
            make_column_def("b", vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }]),
        ];

        let result = ConstraintValidator::process_constraints(&columns, &[]);
        assert!(matches!(result, Err(ExecutorError::MultiplePrimaryKeys)));
    }

    #[test]
    fn test_two_table_level_primary_keys_fail() {
        let columns = vec![make_column_def("a", vec![]), make_column_def("b", vec![])];
        let constraints = vec![table_primary_key(&["a"]), table_primary_key(&["b"])];

        let result = ConstraintValidator::process_constraints(&columns, &constraints);
        assert!(matches!(result, Err(ExecutorError::MultiplePrimaryKeys)));
    }

    #[test]
    fn test_unique_constraints() {
        let columns = vec![
            make_column_def("email", vec![ColumnConstraintKind::Unique { on_conflict: None }]),
            make_column_def("username", vec![]),
        ];
        let constraints = vec![TableConstraint {
            name: None,
            kind: TableConstraintKind::Unique {
                columns: vec![index_column("username")],
                on_conflict: None,
            },
        }];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        assert_eq!(result.unique_constraints.len(), 2);
        assert!(result.unique_constraints.contains(&vec!["email".to_string()]));
        assert!(result.unique_constraints.contains(&vec!["username".to_string()]));
    }

    #[test]
    fn test_check_constraints() {
        let check_expr = age_positive();

        let columns = vec![make_column_def(
            "age",
            vec![ColumnConstraintKind::Check(Box::new(check_expr.clone()))],
        )];

        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.check_constraints.len(), 1);
        assert_eq!(result.check_constraints[0].1, check_expr);
    }

    #[test]
    fn test_named_column_check_constraint_keeps_name() {
        let mut column =
            make_column_def("age", vec![ColumnConstraintKind::Check(Box::new(age_positive()))]);
        column.constraints[0].name = Some("age_positive".to_string());

        let result = ConstraintValidator::process_constraints(&[column], &[]).unwrap();

        assert_eq!(result.check_constraints, vec![("age_positive".to_string(), age_positive())]);
    }

    #[test]
    fn test_named_table_check_constraint_keeps_name() {
        let columns = vec![make_column_def("age", vec![])];
        let constraints = vec![TableConstraint {
            name: Some("positive_age".to_string()),
            kind: TableConstraintKind::Check { expr: Box::new(age_positive()) },
        }];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        assert_eq!(result.check_constraints, vec![("positive_age".to_string(), age_positive())]);
    }

    #[test]
    fn test_not_null_column_constraint() {
        let columns = vec![make_column_def_with_type(
            "name",
            DataType::Varchar { max_length: None },
            vec![ColumnConstraintKind::NotNull],
        )];

        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, None);
        assert!(result.not_null_columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_apply_to_columns() {
        let mut columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, true),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];

        let mut result = ConstraintResult::new();
        result.not_null_columns.push("id".to_string());

        ConstraintValidator::apply_to_columns(&mut columns, &result);

        assert!(!columns[0].nullable);
        assert!(columns[1].nullable);
    }

    #[test]
    fn test_apply_to_columns_ignores_unknown_columns() {
        let mut columns = vec![ColumnSchema::new("id".to_string(), DataType::Integer, true)];

        let mut result = ConstraintResult::new();
        result.not_null_columns.push("missing".to_string());

        ConstraintValidator::apply_to_columns(&mut columns, &result);

        assert!(columns[0].nullable);
    }

    #[test]
    fn test_text_primary_key_allows_null() {
        let columns = vec![make_column_def_with_type(
            "name",
            DataType::Varchar { max_length: None },
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["name".to_string()]));
        assert!(!result.not_null_columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_typeless_primary_key_allows_null() {
        // NOTE(review): despite the name, this builds an explicit VARCHAR column; confirm
        // how typeless columns are represented in the AST and update the type if needed.
        let columns = vec![make_column_def_with_type(
            "c",
            DataType::Varchar { max_length: None },
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["c".to_string()]));
        assert!(!result.not_null_columns.contains(&"c".to_string()));
    }

    #[test]
    fn test_integer_primary_key_has_not_null() {
        let columns = vec![make_column_def_with_type(
            "id",
            DataType::Integer,
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["id".to_string()]));
        assert!(result.not_null_columns.contains(&"id".to_string()));
    }

    #[test]
    fn test_table_level_pk_with_mixed_types() {
        let columns = vec![
            make_column_def_with_type("id", DataType::Integer, vec![]),
            make_column_def_with_type("code", DataType::Varchar { max_length: None }, vec![]),
        ];
        let constraints = vec![table_primary_key(&["id", "code"])];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        assert_eq!(result.primary_key, Some(vec!["id".to_string(), "code".to_string()]));
        assert!(result.not_null_columns.contains(&"id".to_string()));
        assert!(!result.not_null_columns.contains(&"code".to_string()));
    }

    #[test]
    fn test_real_primary_key_allows_null() {
        let columns = vec![make_column_def_with_type(
            "value",
            DataType::Real,
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["value".to_string()]));
        assert!(!result.not_null_columns.contains(&"value".to_string()));
    }

    #[test]
    fn test_bigint_primary_key_allows_null() {
        let columns = vec![make_column_def_with_type(
            "big_id",
            DataType::Bigint,
            vec![ColumnConstraintKind::PrimaryKey { on_conflict: None }],
        )];
        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.primary_key, Some(vec!["big_id".to_string()]));
        assert!(!result.not_null_columns.contains(&"big_id".to_string()));
    }
}
487}