1pub mod mysql;
2pub mod postgres;
3pub mod sqlite;
4
5use std::collections::HashMap;
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{ColumnInfo, SchemaInfo};
9
/// A Rust type expression to be emitted into generated code, paired with the
/// import line (if any) the generated file needs for that expression to resolve.
///
/// Values are built with [`RustType::simple`] / [`RustType::with_import`] and
/// then optionally wrapped with [`RustType::wrap_option`] / [`RustType::wrap_vec`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustType {
    /// The type exactly as it is written in generated source, e.g. `i32`,
    /// `Option<Uuid>` or `Vec<DateTime<Utc>>`.
    pub path: String,
    /// The import line required by `path`, e.g. `use uuid::Uuid;`, or `None`
    /// when the type needs no import.
    pub needs_import: Option<String>,
}

impl RustType {
    /// Creates a type that requires no import.
    #[must_use]
    pub fn simple(path: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: None,
        }
    }

    /// Creates a type whose use in generated code requires `import`
    /// (the full import line, e.g. `use uuid::Uuid;`).
    #[must_use]
    pub fn with_import(path: &str, import: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: Some(import.to_string()),
        }
    }

    /// Wraps the type in `Option<…>`.
    ///
    /// The import requirement is carried over unchanged; `Option` itself is in
    /// the prelude, so wrapping never adds a new import.
    #[must_use]
    pub fn wrap_option(self) -> Self {
        Self {
            path: format!("Option<{}>", self.path),
            ..self
        }
    }

    /// Wraps the type in `Vec<…>`.
    ///
    /// The import requirement is carried over unchanged; `Vec` itself is in
    /// the prelude, so wrapping never adds a new import.
    #[must_use]
    pub fn wrap_vec(self) -> Self {
        Self {
            path: format!("Vec<{}>", self.path),
            ..self
        }
    }
}
46
47pub fn map_column(
48 col: &ColumnInfo,
49 db_kind: DatabaseKind,
50 schema_info: &SchemaInfo,
51 overrides: &HashMap<String, String>,
52) -> RustType {
53 if let Some(override_type) = overrides.get(&col.udt_name) {
55 let rt = RustType::simple(override_type);
56 return if col.is_nullable { rt.wrap_option() } else { rt };
57 }
58
59 let base = match db_kind {
60 DatabaseKind::Postgres => postgres::map_type(&col.udt_name, schema_info),
61 DatabaseKind::Mysql => mysql::map_type(&col.data_type, &col.udt_name),
62 DatabaseKind::Sqlite => sqlite::map_type(&col.udt_name),
63 };
64
65 if col.is_nullable {
66 base.wrap_option()
67 } else {
68 base
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::SchemaInfo;
    use std::collections::HashMap;

    /// Builds a `ColumnInfo` for mapping tests. Only `udt_name`, `data_type`
    /// and `is_nullable` vary; every other field is a fixed placeholder.
    fn make_col(udt_name: &str, data_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: String::from("test"),
            data_type: String::from(data_type),
            udt_name: String::from(udt_name),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("public"),
            column_default: None,
        }
    }

    /// Runs `map_column` for a Postgres column against a default `SchemaInfo`.
    fn map_pg(col: &ColumnInfo, overrides: &HashMap<String, String>) -> RustType {
        let schema = SchemaInfo::default();
        map_column(col, DatabaseKind::Postgres, &schema, overrides)
    }

    /// Override table redirecting the `uuid` UDT to the custom type `MyUuid`.
    fn uuid_overrides() -> HashMap<String, String> {
        let mut table = HashMap::new();
        table.insert(String::from("uuid"), String::from("MyUuid"));
        table
    }

    // ---- RustType::simple -------------------------------------------------

    #[test]
    fn test_simple_creates_without_import() {
        let ty = RustType::simple("i32");
        assert_eq!(ty.path, "i32");
        assert_eq!(ty.needs_import, None);
    }

    #[test]
    fn test_simple_path_correct() {
        assert_eq!(RustType::simple("String").path, "String");
    }

    #[test]
    fn test_simple_no_import() {
        assert!(RustType::simple("bool").needs_import.is_none());
    }

    // ---- RustType::with_import --------------------------------------------

    #[test]
    fn test_with_import_creates_with_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;");
        assert_eq!(ty.path, "Uuid");
        assert_eq!(ty.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    #[test]
    fn test_with_import_path_correct() {
        let ty = RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};");
        assert_eq!(ty.path, "DateTime<Utc>");
    }

    #[test]
    fn test_with_import_import_present() {
        let ty = RustType::with_import("Value", "use serde_json::Value;");
        assert!(ty.needs_import.is_some());
    }

    // ---- RustType::wrap_option --------------------------------------------

    #[test]
    fn test_wrap_option_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_option().path, "Option<i32>");
    }

    #[test]
    fn test_wrap_option_preserves_import() {
        let wrapped = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_option();
        assert_eq!(wrapped.path, "Option<Uuid>");
        assert_eq!(wrapped.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    #[test]
    fn test_wrap_option_double_wrap() {
        let twice = RustType::simple("i32").wrap_option().wrap_option();
        assert_eq!(twice.path, "Option<Option<i32>>");
    }

    // ---- RustType::wrap_vec -----------------------------------------------

    #[test]
    fn test_wrap_vec_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_vec().path, "Vec<i32>");
    }

    #[test]
    fn test_wrap_vec_preserves_import() {
        let wrapped = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_vec();
        assert_eq!(wrapped.path, "Vec<Uuid>");
        assert_eq!(wrapped.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    // ---- map_column ---------------------------------------------------------

    #[test]
    fn test_override_takes_precedence() {
        let ty = map_pg(&make_col("uuid", "uuid", false), &uuid_overrides());
        assert_eq!(ty.path, "MyUuid");
        assert_eq!(ty.needs_import, None);
    }

    #[test]
    fn test_override_with_nullable() {
        let ty = map_pg(&make_col("uuid", "uuid", true), &uuid_overrides());
        assert_eq!(ty.path, "Option<MyUuid>");
    }

    #[test]
    fn test_no_override_dispatches_postgres() {
        let ty = map_pg(&make_col("int4", "integer", false), &HashMap::new());
        assert_eq!(ty.path, "i32");
    }

    #[test]
    fn test_nullable_without_override() {
        let ty = map_pg(&make_col("int4", "integer", true), &HashMap::new());
        assert_eq!(ty.path, "Option<i32>");
    }
}
210