1#![allow(clippy::new_without_default)]
7
8use vibesql_ast::{
9 ColumnConstraintKind, ColumnDef, Expression, TableConstraint, TableConstraintKind,
10};
11use vibesql_catalog::{ColumnSchema, TableSchema};
12
13use crate::errors::ExecutorError;
14
15pub struct ConstraintResult {
17 pub primary_key: Option<Vec<String>>,
19 pub unique_constraints: Vec<Vec<String>>,
21 pub check_constraints: Vec<(String, Expression)>,
23 pub not_null_columns: Vec<String>,
25}
26
27impl ConstraintResult {
28 pub fn new() -> Self {
30 Self {
31 primary_key: None,
32 unique_constraints: Vec::new(),
33 check_constraints: Vec::new(),
34 not_null_columns: Vec::new(),
35 }
36 }
37}
38
/// Stateless namespace for turning parsed `CREATE TABLE` constraints into schema metadata.
///
/// All functionality is exposed as associated functions; the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstraintValidator;
41
42impl ConstraintValidator {
43 pub fn process_constraints(
58 columns: &[ColumnDef],
59 table_constraints: &[TableConstraint],
60 ) -> Result<ConstraintResult, ExecutorError> {
61 let mut result = ConstraintResult::new();
62 let mut constraint_counter = 0;
63
64 let mut has_column_level_pk = false;
66
67 for col_def in columns {
69 for constraint in &col_def.constraints {
70 match &constraint.kind {
71 ColumnConstraintKind::PrimaryKey => {
72 if has_column_level_pk {
73 return Err(ExecutorError::MultiplePrimaryKeys);
74 }
75 if result.primary_key.is_some() {
76 return Err(ExecutorError::MultiplePrimaryKeys);
77 }
78 result.primary_key = Some(vec![col_def.name.clone()]);
79 result.not_null_columns.push(col_def.name.clone());
80 has_column_level_pk = true;
81 }
82 ColumnConstraintKind::Unique => {
83 result.unique_constraints.push(vec![col_def.name.clone()]);
84 }
85 ColumnConstraintKind::Check(expr) => {
86 let constraint_name = format!("check_{}", constraint_counter);
87 constraint_counter += 1;
88 result.check_constraints.push((constraint_name, (**expr).clone()));
89 }
90 ColumnConstraintKind::NotNull => {
91 result.not_null_columns.push(col_def.name.clone());
92 }
93 ColumnConstraintKind::References { .. } => {
94 }
97 ColumnConstraintKind::AutoIncrement => {
98 }
102 ColumnConstraintKind::Key => {
103 }
107 }
108 }
109 }
110
111 for table_constraint in table_constraints {
113 match &table_constraint.kind {
114 TableConstraintKind::PrimaryKey { columns: pk_cols } => {
115 if result.primary_key.is_some() {
117 return Err(ExecutorError::MultiplePrimaryKeys);
118 }
119 let column_names: Vec<String> =
121 pk_cols.iter().map(|c| c.column_name.clone()).collect();
122 result.primary_key = Some(column_names.clone());
123 for col_name in &column_names {
125 if !result.not_null_columns.contains(col_name) {
126 result.not_null_columns.push(col_name.clone());
127 }
128 }
129 }
130 TableConstraintKind::Unique { columns } => {
131 let column_names: Vec<String> =
133 columns.iter().map(|c| c.column_name.clone()).collect();
134 result.unique_constraints.push(column_names);
135 }
136 TableConstraintKind::Check { expr } => {
137 let constraint_name = format!("check_{}", constraint_counter);
138 constraint_counter += 1;
139 result.check_constraints.push((constraint_name, (**expr).clone()));
140 }
141 TableConstraintKind::ForeignKey { .. } => {
142 }
145 TableConstraintKind::Fulltext { .. } => {
146 }
150 }
151 }
152
153 Ok(result)
154 }
155
156 pub fn apply_to_columns(columns: &mut [ColumnSchema], constraint_result: &ConstraintResult) {
165 for col_name in &constraint_result.not_null_columns {
167 if let Some(col) = columns.iter_mut().find(|c| c.name == *col_name) {
168 col.nullable = false;
169 }
170 }
171 }
172
173 pub fn apply_to_schema(table_schema: &mut TableSchema, constraint_result: &ConstraintResult) {
182 if let Some(pk) = &constraint_result.primary_key {
184 table_schema.primary_key = Some(pk.clone());
185 }
186
187 table_schema.unique_constraints = constraint_result.unique_constraints.clone();
189
190 table_schema.check_constraints = constraint_result.check_constraints.clone();
192 }
193}
194
#[cfg(test)]
mod tests {
    use vibesql_ast::{BinaryOperator, ColumnConstraint, IndexColumn, OrderDirection};
    use vibesql_types::{DataType, SqlValue};

    use super::*;

    /// Builds a nullable INTEGER column definition carrying the given unnamed constraints.
    fn column_with(name: &str, kinds: Vec<ColumnConstraintKind>) -> ColumnDef {
        let constraints = kinds.into_iter().map(|kind| ColumnConstraint { name: None, kind }).collect();
        ColumnDef {
            name: name.to_string(),
            data_type: DataType::Integer,
            nullable: true,
            constraints,
            default_value: None,
            comment: None,
        }
    }

    /// Builds an ascending index-column reference with no prefix length.
    fn asc_column(name: &str) -> IndexColumn {
        IndexColumn {
            column_name: name.to_string(),
            direction: OrderDirection::Asc,
            prefix_length: None,
        }
    }

    /// Wraps `kind` in an unnamed table-level constraint.
    fn unnamed(kind: TableConstraintKind) -> TableConstraint {
        TableConstraint { name: None, kind }
    }

    #[test]
    fn test_column_level_primary_key() {
        let defs = [column_with("id", vec![ColumnConstraintKind::PrimaryKey])];
        let outcome = ConstraintValidator::process_constraints(&defs, &[]).unwrap();

        assert_eq!(outcome.primary_key, Some(vec!["id".to_string()]));
        assert!(outcome.not_null_columns.iter().any(|c| c == "id"));
    }

    #[test]
    fn test_table_level_primary_key() {
        let defs = [column_with("id", vec![]), column_with("tenant_id", vec![])];
        let table_level = [unnamed(TableConstraintKind::PrimaryKey {
            columns: vec![asc_column("id"), asc_column("tenant_id")],
        })];

        let outcome = ConstraintValidator::process_constraints(&defs, &table_level).unwrap();

        assert_eq!(outcome.primary_key, Some(vec!["id".to_string(), "tenant_id".to_string()]));
        for key_col in ["id", "tenant_id"] {
            assert!(outcome.not_null_columns.iter().any(|c| c == key_col));
        }
    }

    #[test]
    fn test_multiple_primary_keys_fails() {
        let defs = [column_with("id", vec![ColumnConstraintKind::PrimaryKey])];
        let table_level =
            [unnamed(TableConstraintKind::PrimaryKey { columns: vec![asc_column("id")] })];

        let outcome = ConstraintValidator::process_constraints(&defs, &table_level);
        assert!(matches!(outcome, Err(ExecutorError::MultiplePrimaryKeys)));
    }

    #[test]
    fn test_unique_constraints() {
        let defs = [
            column_with("email", vec![ColumnConstraintKind::Unique]),
            column_with("username", vec![]),
        ];
        let table_level =
            [unnamed(TableConstraintKind::Unique { columns: vec![asc_column("username")] })];

        let outcome = ConstraintValidator::process_constraints(&defs, &table_level).unwrap();

        assert_eq!(outcome.unique_constraints.len(), 2);
        for col in ["email", "username"] {
            assert!(outcome.unique_constraints.contains(&vec![col.to_string()]));
        }
    }

    #[test]
    fn test_check_constraints() {
        let age_positive = Expression::BinaryOp {
            left: Box::new(Expression::ColumnRef { table: None, column: "age".to_string() }),
            op: BinaryOperator::GreaterThan,
            right: Box::new(Expression::Literal(SqlValue::Integer(0))),
        };
        let defs = [column_with(
            "age",
            vec![ColumnConstraintKind::Check(Box::new(age_positive.clone()))],
        )];

        let outcome = ConstraintValidator::process_constraints(&defs, &[]).unwrap();

        assert_eq!(outcome.check_constraints.len(), 1);
        assert_eq!(outcome.check_constraints[0].1, age_positive);
    }

    #[test]
    fn test_apply_to_columns() {
        let mut schema_cols = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, true),
            ColumnSchema::new("name".to_string(), DataType::Varchar { max_length: Some(100) }, true),
        ];
        let mut collected = ConstraintResult::new();
        collected.not_null_columns.push("id".to_string());

        ConstraintValidator::apply_to_columns(&mut schema_cols, &collected);

        // Only the listed column loses nullability.
        assert!(!schema_cols[0].nullable);
        assert!(schema_cols[1].nullable);
    }
}