Skip to main content

schema_sql_generator/common/column_constraint_generator.rs

1use crate::common::generator_context::GeneratorContext;
2use schema_model::model::column::Column;
3use schema_model::model::column_type::ColumnType;
4use schema_model::model::table::Table;
5use schema_model::model::types::BooleanMode;
6use std::hash::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
9const CK_PREFIX: &str = "ck_";
10
11pub trait ColumnConstraintGenerator {
12    fn column_check_constraints(&self, table: &Table) -> Vec<String>;
13}
14
15pub struct DefaultColumnConstraintGenerator {
16    context: GeneratorContext,
17}
18
19impl DefaultColumnConstraintGenerator {
20    pub fn new(context: GeneratorContext) -> Self {
21        Self { context }
22    }
23
24    pub fn context(&self) -> &GeneratorContext {
25        &self.context
26    }
27
28    pub fn generate_constraint(&self, table: &Table, column: &Column) -> String {
29        let constraint_sql = self.check_constraint_sql(column);
30
31        if constraint_sql.is_some() {
32            return format!(
33                "   constraint {} {}",
34                self.constraint_name(table.name(), column.name()),
35                constraint_sql.unwrap()
36            );
37        }
38
39        String::new()
40    }
41
42    fn constraint_name(&self, table_name: &str, column_name: &str) -> String {
43        let table_name = table_name.to_lowercase();
44        let column_name = column_name.to_lowercase();
45        let hash = self.combined_hash(&table_name, &column_name);
46        let table_name = self.truncate_lower(&table_name, 9);
47        let column_name = self.truncate_lower(&column_name, 9);
48
49        format!("{}{}_{}_{}", CK_PREFIX, table_name, column_name, hash)
50    }
51
52    fn truncate_lower(&self, s: &str, max_len: usize) -> String {
53        s.to_lowercase().chars().take(max_len).collect()
54    }
55
56    fn combined_hash(&self, table_name: &str, column_name: &str) -> String {
57        let combined_name = format!("{}_{}", table_name, column_name);
58        let mut hasher = DefaultHasher::new();
59        combined_name.hash(&mut hasher);
60        format!("{:X}", hasher.finish())
61    }
62
63    fn check_constraint_sql(&self, column: &Column) -> Option<String> {
64        if column.column_type() == ColumnType::Boolean {
65            self.boolean_check_constraint(column)
66        } else if let Some(constraint) = column.check_constraint() {
67            Some(constraint.to_string())
68        } else if column.column_type() == ColumnType::Enum {
69            self.enum_check_constraint_sql(column)
70        } else if column.has_min_or_max_value() {
71            self.min_max_constraint_sql(column)
72        } else {
73            None
74        }
75    }
76
77    fn boolean_check_constraint(&self, column: &Column) -> Option<String> {
78        match self.context.settings().boolean_mode() {
79            BooleanMode::YesNo => Some(format!("check({} in ('Yes','No'))", column.name())),
80            BooleanMode::YN => Some(format!("check({} in ('Y','N'))", column.name())),
81            BooleanMode::Native => None,
82        }
83    }
84
85    fn enum_check_constraint_sql(&self, column: &Column) -> Option<String> {
86        let schema_name = column.schema_name();
87        let schema = self
88            .context
89            .settings()
90            .database_model()
91            .find_schema(schema_name);
92        let enum_type = column.enum_type();
93        let enum_values = schema.get_enum_type(enum_type?).values().clone();
94
95        let joined_values = enum_values
96            .iter()
97            .map(|value| format!("'{}'", value.code()))
98            .collect::<Vec<_>>()
99            .join(", ");
100
101        Some(format!("check({} in ({}))", column.name(), joined_values))
102    }
103
104    fn min_max_constraint_sql(&self, column: &Column) -> Option<String> {
105        let min_value = column.min_value();
106        let max_value = column.max_value();
107        let mut sql = String::from("check(");
108
109        if min_value.is_some() {
110            sql.push_str(column.name());
111            sql.push_str(" >= ");
112            sql.push_str(min_value.unwrap().to_string().as_str());
113        }
114
115        if min_value.is_some() && max_value.is_some() {
116            sql.push_str(" and ");
117        }
118
119        if max_value.is_some() {
120            sql.push_str(column.name());
121            sql.push_str(" <= ");
122            sql.push_str(max_value.unwrap().to_string().as_str());
123        }
124
125        sql.push_str(")");
126
127        Some(sql)
128    }
129}
130
131impl ColumnConstraintGenerator for DefaultColumnConstraintGenerator {
132    fn column_check_constraints(&self, table: &Table) -> Vec<String> {
133        let columns = table.columns_with_check_constraints(self.context.settings().boolean_mode());
134
135        columns
136            .iter()
137            .map(|column| self.generate_constraint(table, column))
138            .collect()
139    }
140}