silent_db/core/tables.rs

1use crate::core::dsl::SqlStatement;
2use crate::core::fields::Field;
3use crate::core::indices::IndexTrait;
4use anyhow::Result;
5use std::path::Path;
6use std::rc::Rc;
7
8pub trait TableUtil {
9    fn get_name(&self) -> String;
10    fn get_all_tables(&self) -> String;
11    fn get_table(&self, table: &str) -> String;
12    fn transform(&self, table: &SqlStatement) -> Result<Box<dyn Table>>;
13    fn generate_models(&self, tables: Vec<SqlStatement>, models_path: &Path) -> Result<()>;
14
15    /// 从字段字符串中检测字段类型和长度
16    fn detect_fields(&self, field_str: &str) -> DetectField {
17        let field_str = field_str.to_lowercase();
18        // 利用正则取出字段后的包含括号的长度数值
19        // 如 int(11) -> 11
20        // 如 varchar(255) -> 255
21        // 如 decimal(10, 2) -> 10, 2
22        let re = regex::Regex::new(r"\((\d+)(?:, (\d+))?\)").unwrap();
23        let length = if let Some(caps) = re.captures(&field_str) {
24            if caps.len() == 3 {
25                match (caps.get(1), caps.get(2)) {
26                    (Some(max_digits), Some(decimal_places)) => Some(
27                        (
28                            max_digits.as_str().parse::<u8>().unwrap_or(0),
29                            decimal_places.as_str().parse::<u8>().unwrap_or(0),
30                        )
31                            .into(),
32                    ),
33                    (Some(max_length), None) => {
34                        Some(max_length.as_str().parse::<u16>().unwrap_or(0).into())
35                    }
36                    (_, _) => None,
37                }
38            } else {
39                None
40            }
41        } else {
42            None
43        };
44        // 利用正则表达式取出首个左括号之前的字符串作为字段类型
45        // 如 int(11) -> int
46        // 如 varchar(255) -> varchar
47        // 如 decimal(10, 2) -> decimal
48        let re = regex::Regex::new(r"(\w+)(?:\(\d+(?:, \d+)?\))?$").unwrap();
49        let field_type = re
50            .captures(&field_str)
51            .unwrap()
52            .get(1)
53            .unwrap()
54            .as_str()
55            .to_string();
56        DetectField { field_type, length }
57    }
58
59    /// 从DetectField中获取字段类型和结构体类型
60    fn get_field_type(&self, detect_field: &DetectField) -> (&str, &str);
61}
62
/// A column type name plus an optional size parsed from its parentheses.
///
/// This is the value returned by `TableUtil::detect_fields`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct DetectField {
    /// Column type name, e.g. `int`, `varchar`, `decimal`.
    pub field_type: String,
    /// Parenthesised size, if the type string had one.
    pub length: Option<DetectFieldLength>,
}

/// Size information parsed from a column type's parentheses.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum DetectFieldLength {
    /// A single number, e.g. the `255` in `varchar(255)` or the `11` in `int(11)`.
    MaxLength(u16),
    /// Two numbers `(max_digits, decimal_places)`, e.g. the `10, 2` in `decimal(10, 2)`.
    MaxDigits(u8, u8),
}

impl From<u16> for DetectFieldLength {
    /// Wraps a single length value as [`DetectFieldLength::MaxLength`].
    fn from(max_length: u16) -> Self {
        Self::MaxLength(max_length)
    }
}

impl From<(u8, u8)> for DetectFieldLength {
    /// Wraps a `(max_digits, decimal_places)` pair as [`DetectFieldLength::MaxDigits`].
    fn from((max_digits, decimal_places): (u8, u8)) -> Self {
        Self::MaxDigits(max_digits, decimal_places)
    }
}
86
87pub trait Table {
88    fn get_name(&self) -> String;
89    fn get_fields(&self) -> Vec<Rc<dyn Field>>;
90    fn get_indices(&self) -> Vec<Rc<dyn IndexTrait>> {
91        vec![]
92    }
93    fn get_comment(&self) -> Option<String> {
94        None
95    }
96    fn get_create_sql(&self) -> String {
97        let mut sql = format!("CREATE TABLE `{}` (", self.get_name());
98        let fields: Vec<String> = self
99            .get_fields()
100            .iter()
101            .map(|field| field.get_create_sql())
102            .collect();
103        sql.push_str(&fields.join(", "));
104        if !self.get_indices().is_empty() {
105            sql.push_str(", ");
106            let indices: Vec<String> = self
107                .get_indices()
108                .iter()
109                .map(|index| index.get_create_sql())
110                .collect();
111            sql.push_str(&indices.join(", "));
112        }
113        sql.push(')');
114        if let Some(comment) = self.get_comment() {
115            sql.push_str(&format!(" COMMENT='{}'", comment));
116        }
117        sql.push(';');
118        sql
119    }
120    fn get_drop_sql(&self) -> String {
121        format!("DROP TABLE `{}`;", self.get_name())
122    }
123}
124
125pub trait TableManage {
126    fn get_manager(&self) -> Box<dyn Table> {
127        Self::manager()
128    }
129    fn manager() -> Box<dyn Table>;
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::fields::{Field, FieldType};
    use serde::{Deserialize, Serialize};

    /// Column type fixture that always renders as `INT`.
    struct IntType;

    impl FieldType for IntType {
        fn get_type_str(&self) -> String {
            String::from("INT")
        }
    }

    /// Integer column fixture. Each `Field` getter reports the matching struct field.
    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
    struct IntField {
        name: String,
        default: Option<String>,
        nullable: bool,
        primary_key: bool,
        unique: bool,
        auto_increment: bool,
        comment: Option<String>,
    }

    impl Field for IntField {
        fn get_name(&self) -> String {
            self.name.to_owned()
        }
        fn get_type(&self) -> Box<dyn FieldType> {
            let column_type: Box<dyn FieldType> = Box::new(IntType);
            column_type
        }
        fn get_default(&self) -> Option<String> {
            self.default.as_ref().cloned()
        }
        fn get_nullable(&self) -> bool {
            self.nullable
        }
        fn get_primary_key(&self) -> bool {
            self.primary_key
        }
        fn get_unique(&self) -> bool {
            self.unique
        }
        fn get_auto_increment(&self) -> bool {
            self.auto_increment
        }
        fn get_comment(&self) -> Option<String> {
            self.comment.as_ref().cloned()
        }
    }

    /// Table fixture: `test_table`, holding one auto-increment primary-key `id` column.
    struct TestTable;

    impl Table for TestTable {
        fn get_name(&self) -> String {
            String::from("test_table")
        }
        fn get_fields(&self) -> Vec<Rc<dyn Field>> {
            let id_column: Rc<dyn Field> = Rc::new(IntField {
                name: String::from("id"),
                default: None,
                nullable: false,
                primary_key: true,
                unique: false,
                auto_increment: true,
                comment: None,
            });
            vec![id_column]
        }
        fn get_comment(&self) -> Option<String> {
            Some(String::from("Test Table"))
        }
    }

    #[test]
    fn test_get_create_sql() {
        let expected = "CREATE TABLE `test_table` (`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT) COMMENT='Test Table';";
        let actual = TestTable.get_create_sql();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_get_drop_sql() {
        let actual = TestTable.get_drop_sql();
        assert_eq!(actual, "DROP TABLE `test_table`;");
    }
}