sqlx_plus_core/
table_sql.rs

1use syn::{Field, Type};
2
3fn generate_field_definitions(fields: &[Field]) -> Vec<(String, String)> {
4    fields
5        .iter()
6        .map(|f| {
7            let field_name = f.ident.as_ref().unwrap().to_string();
8            let field_type = match &f.ty {
9                Type::Path(type_path) => {
10                    let segment = type_path.path.segments.last().unwrap();
11                    let sql_type = match segment.ident.to_string().as_str() {
12                        "i64" => {
13                            #[cfg(feature = "sqlite")]
14                            {
15                                "INTEGER"
16                            }
17                            #[cfg(any(feature = "postgres", feature = "mysql"))]
18                            {
19                                "BIGINT"
20                            }
21                        }
22                        "u64" => {
23                            #[cfg(feature = "mysql")]
24                            {
25                                "BIGINT UNSIGNED"
26                            }
27                            #[cfg(any(feature = "postgres", feature = "sqlite"))]
28                            {
29                                "BIGINT"
30                            }
31                        }
32                        "i32" | "u32" => {
33                            #[cfg(feature = "mysql")]
34                            {
35                                "INTEGER UNSIGNED"
36                            }
37                            #[cfg(any(feature = "postgres", feature = "sqlite"))]
38                            {
39                                "INTEGER"
40                            }
41                        }
42                        "i16" | "u16" => {
43                            #[cfg(feature = "mysql")]
44                            {
45                                "SMALLINT UNSIGNED"
46                            }
47                            #[cfg(any(feature = "postgres", feature = "sqlite"))]
48                            {
49                                "SMALLINT"
50                            }
51                        }
52                        "i8" | "u8" => {
53                            #[cfg(feature = "mysql")]
54                            {
55                                "TINYINT UNSIGNED"
56                            }
57                            #[cfg(any(feature = "postgres", feature = "sqlite"))]
58                            {
59                                "TINYINT"
60                            }
61                        }
62                        "f64" => "DOUBLE",
63                        "f32" => "FLOAT",
64                        "bool" => "BOOLEAN",
65                        "String" | "&str" => "TEXT",
66                        "Vec<u8>" => "BLOB",
67                        "chrono::NaiveDateTime" => {
68                            #[cfg(feature = "mysql")]
69                            {
70                                "DATETIME"
71                            }
72                            #[cfg(feature = "postgres")]
73                            {
74                                "TIMESTAMP"
75                            }
76                            #[cfg(feature = "sqlite")]
77                            {
78                                "TEXT"
79                            }
80                        }
81                        "chrono::NaiveDate" => "DATE",
82                        "chrono::NaiveTime" => "TIME",
83                        _ => panic!("Unsupported field type: {}", segment.ident),
84                    };
85                    sql_type
86                }
87                _ => panic!("Unsupported field type"),
88            };
89            (field_name, field_type.to_string())
90        })
91        .collect()
92}
93
94pub fn generate_create_table_sql(fields: &[Field], table_name: &str) -> String {
95    let field_definitions = generate_field_definitions(fields);
96    let create_table_sql = field_definitions
97        .iter()
98        .map(|(field_name, field_type)| format!("{} {}", field_name, field_type))
99        .collect::<Vec<_>>()
100        .join(", ");
101
102    format!(
103        "CREATE TABLE IF NOT EXISTS {} ({})",
104        table_name, create_table_sql
105    )
106}
107
108pub fn generate_alter_table_sql(fields: &[Field], table_name: &str) -> Vec<String> {
109    let field_definitions = generate_field_definitions(fields);
110    field_definitions
111        .iter()
112        .map(|(field_name, field_type)| {
113            format!(
114                "ALTER TABLE {} ADD COLUMN {} {}",
115                table_name, field_name, field_type
116            )
117        })
118        .collect()
119}