1#![allow(clippy::new_without_default)]
7
8use vibesql_ast::{ColumnConstraintKind, ColumnDef, Expression, TableConstraint, TableConstraintKind};
9use vibesql_catalog::{ColumnSchema, TableSchema};
10
11use crate::errors::ExecutorError;
12
13pub struct ConstraintResult {
15 pub primary_key: Option<Vec<String>>,
17 pub unique_constraints: Vec<Vec<String>>,
19 pub check_constraints: Vec<(String, Expression)>,
21 pub not_null_columns: Vec<String>,
23}
24
25impl ConstraintResult {
26 pub fn new() -> Self {
28 Self {
29 primary_key: None,
30 unique_constraints: Vec::new(),
31 check_constraints: Vec::new(),
32 not_null_columns: Vec::new(),
33 }
34 }
35}
36
37pub struct ConstraintValidator;
39
40impl ConstraintValidator {
41 pub fn process_constraints(
56 columns: &[ColumnDef],
57 table_constraints: &[TableConstraint],
58 ) -> Result<ConstraintResult, ExecutorError> {
59 let mut result = ConstraintResult::new();
60 let mut constraint_counter = 0;
61
62 let mut has_column_level_pk = false;
64
65 for col_def in columns {
67 for constraint in &col_def.constraints {
68 match &constraint.kind {
69 ColumnConstraintKind::PrimaryKey => {
70 if has_column_level_pk {
71 return Err(ExecutorError::MultiplePrimaryKeys);
72 }
73 if result.primary_key.is_some() {
74 return Err(ExecutorError::MultiplePrimaryKeys);
75 }
76 result.primary_key = Some(vec![col_def.name.clone()]);
77 result.not_null_columns.push(col_def.name.clone());
78 has_column_level_pk = true;
79 }
80 ColumnConstraintKind::Unique => {
81 result.unique_constraints.push(vec![col_def.name.clone()]);
82 }
83 ColumnConstraintKind::Check(expr) => {
84 let constraint_name = format!("check_{}", constraint_counter);
85 constraint_counter += 1;
86 result.check_constraints.push((constraint_name, (**expr).clone()));
87 }
88 ColumnConstraintKind::NotNull => {
89 result.not_null_columns.push(col_def.name.clone());
90 }
91 ColumnConstraintKind::References { .. } => {
92 }
95 ColumnConstraintKind::AutoIncrement => {
96 }
100 ColumnConstraintKind::Key => {
101 }
105 }
106 }
107 }
108
109 for table_constraint in table_constraints {
111 match &table_constraint.kind {
112 TableConstraintKind::PrimaryKey { columns: pk_cols } => {
113 if result.primary_key.is_some() {
115 return Err(ExecutorError::MultiplePrimaryKeys);
116 }
117 let column_names: Vec<String> = pk_cols.iter().map(|c| c.column_name.clone()).collect();
119 result.primary_key = Some(column_names.clone());
120 for col_name in &column_names {
122 if !result.not_null_columns.contains(col_name) {
123 result.not_null_columns.push(col_name.clone());
124 }
125 }
126 }
127 TableConstraintKind::Unique { columns } => {
128 let column_names: Vec<String> = columns.iter().map(|c| c.column_name.clone()).collect();
130 result.unique_constraints.push(column_names);
131 }
132 TableConstraintKind::Check { expr } => {
133 let constraint_name = format!("check_{}", constraint_counter);
134 constraint_counter += 1;
135 result.check_constraints.push((constraint_name, (**expr).clone()));
136 }
137 TableConstraintKind::ForeignKey { .. } => {
138 }
141 TableConstraintKind::Fulltext { .. } => {
142 }
146 }
147 }
148
149 Ok(result)
150 }
151
152 pub fn apply_to_columns(columns: &mut [ColumnSchema], constraint_result: &ConstraintResult) {
161 for col_name in &constraint_result.not_null_columns {
163 if let Some(col) = columns.iter_mut().find(|c| c.name == *col_name) {
164 col.nullable = false;
165 }
166 }
167 }
168
169 pub fn apply_to_schema(table_schema: &mut TableSchema, constraint_result: &ConstraintResult) {
178 if let Some(pk) = &constraint_result.primary_key {
180 table_schema.primary_key = Some(pk.clone());
181 }
182
183 table_schema.unique_constraints = constraint_result.unique_constraints.clone();
185
186 table_schema.check_constraints = constraint_result.check_constraints.clone();
188 }
189}
190
#[cfg(test)]
mod tests {
    use vibesql_ast::ColumnConstraint;
    use vibesql_types::DataType;

    use super::*;

    /// Builds a nullable INTEGER column definition that carries the given
    /// constraint kinds as unnamed constraints.
    fn make_column_def(name: &str, constraint_kinds: Vec<ColumnConstraintKind>) -> ColumnDef {
        let constraints =
            constraint_kinds.into_iter().map(|kind| ColumnConstraint { name: None, kind }).collect();

        ColumnDef {
            name: name.to_string(),
            data_type: DataType::Integer,
            nullable: true,
            constraints,
            default_value: None,
            comment: None,
        }
    }

    /// Builds an ascending index column without a prefix length.
    fn index_column(name: &str) -> vibesql_ast::IndexColumn {
        vibesql_ast::IndexColumn {
            column_name: name.to_string(),
            direction: vibesql_ast::OrderDirection::Asc,
            prefix_length: None,
        }
    }

    /// Wraps a constraint kind in an unnamed table constraint.
    fn unnamed(kind: TableConstraintKind) -> TableConstraint {
        TableConstraint { name: None, kind }
    }

    /// A column-level primary key is reported and makes its column NOT NULL.
    #[test]
    fn test_column_level_primary_key() {
        let columns = vec![make_column_def("id", vec![ColumnConstraintKind::PrimaryKey])];

        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        let expected_pk = vec![String::from("id")];
        assert_eq!(result.primary_key, Some(expected_pk));
        assert!(result.not_null_columns.iter().any(|c| c == "id"));
    }

    /// A composite table-level primary key is reported in declaration order
    /// and makes every member column NOT NULL.
    #[test]
    fn test_table_level_primary_key() {
        let columns = vec![make_column_def("id", vec![]), make_column_def("tenant_id", vec![])];
        let constraints = vec![unnamed(TableConstraintKind::PrimaryKey {
            columns: vec![index_column("id"), index_column("tenant_id")],
        })];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        let expected_pk = vec![String::from("id"), String::from("tenant_id")];
        assert_eq!(result.primary_key, Some(expected_pk));
        for expected in ["id", "tenant_id"] {
            assert!(result.not_null_columns.iter().any(|c| c == expected));
        }
    }

    /// Declaring a primary key at both column and table level is an error.
    #[test]
    fn test_multiple_primary_keys_fails() {
        let columns = vec![make_column_def("id", vec![ColumnConstraintKind::PrimaryKey])];
        let constraints =
            vec![unnamed(TableConstraintKind::PrimaryKey { columns: vec![index_column("id")] })];

        let outcome = ConstraintValidator::process_constraints(&columns, &constraints);

        assert!(matches!(outcome, Err(ExecutorError::MultiplePrimaryKeys)));
    }

    /// Column-level and table-level UNIQUE constraints are both collected.
    #[test]
    fn test_unique_constraints() {
        let columns = vec![
            make_column_def("email", vec![ColumnConstraintKind::Unique]),
            make_column_def("username", vec![]),
        ];
        let constraints =
            vec![unnamed(TableConstraintKind::Unique { columns: vec![index_column("username")] })];

        let result = ConstraintValidator::process_constraints(&columns, &constraints).unwrap();

        assert_eq!(result.unique_constraints.len(), 2);
        let email_only = vec![String::from("email")];
        let username_only = vec![String::from("username")];
        assert!(result.unique_constraints.contains(&email_only));
        assert!(result.unique_constraints.contains(&username_only));
    }

    /// A column-level CHECK constraint is stored with its expression intact.
    #[test]
    fn test_check_constraints() {
        use vibesql_types::SqlValue;

        let age_column = Expression::ColumnRef { table: None, column: "age".to_string() };
        let zero = Expression::Literal(SqlValue::Integer(0));
        let check_expr = Expression::BinaryOp {
            left: Box::new(age_column),
            op: vibesql_ast::BinaryOperator::GreaterThan,
            right: Box::new(zero),
        };

        let check_kind = ColumnConstraintKind::Check(Box::new(check_expr.clone()));
        let columns = vec![make_column_def("age", vec![check_kind])];

        let result = ConstraintValidator::process_constraints(&columns, &[]).unwrap();

        assert_eq!(result.check_constraints.len(), 1);
        let (_, stored_expr) = &result.check_constraints[0];
        assert_eq!(*stored_expr, check_expr);
    }

    /// Only the columns listed as NOT NULL lose their nullability.
    #[test]
    fn test_apply_to_columns() {
        let id_column = ColumnSchema::new("id".to_string(), DataType::Integer, true);
        let name_column = ColumnSchema::new(
            "name".to_string(),
            DataType::Varchar { max_length: Some(100) },
            true,
        );
        let mut columns = vec![id_column, name_column];

        let mut result = ConstraintResult::new();
        result.not_null_columns.push("id".to_string());

        ConstraintValidator::apply_to_columns(&mut columns, &result);

        assert!(!columns[0].nullable);
        assert!(columns[1].nullable);
    }
}