Skip to main content

sqlx_gen/typemap/
mod.rs

1pub mod mysql;
2pub mod postgres;
3pub mod sqlite;
4
5use std::collections::HashMap;
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::introspect::{ColumnInfo, SchemaInfo};
9
/// Resolved Rust type with its required imports.
///
/// `path` is the type expression written into generated code (e.g. `i32`,
/// `Uuid`, `Option<Vec<Uuid>>`); `needs_import` is the import line, if any,
/// that must be emitted alongside it for `path` to resolve.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustType {
    /// Type expression to emit in generated code.
    pub path: String,
    /// Import line required by `path`, or `None` when no import is needed.
    pub needs_import: Option<String>,
}

impl RustType {
    /// Creates a type that requires no accompanying import.
    #[must_use]
    pub fn simple(path: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: None,
        }
    }

    /// Creates a type whose `path` only resolves once `import` is emitted.
    #[must_use]
    pub fn with_import(path: &str, import: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: Some(import.to_string()),
        }
    }

    /// Wraps the type as `Option<path>`, carrying the import requirement over unchanged.
    #[must_use]
    pub fn wrap_option(self) -> Self {
        self.wrap_in("Option")
    }

    /// Wraps the type as `Vec<path>`, carrying the import requirement over unchanged.
    #[must_use]
    pub fn wrap_vec(self) -> Self {
        self.wrap_in("Vec")
    }

    /// Wraps `path` in the generic `container` (producing `container<path>`),
    /// preserving `needs_import`.
    fn wrap_in(self, container: &str) -> Self {
        Self {
            path: format!("{}<{}>", container, self.path),
            needs_import: self.needs_import,
        }
    }
}
46
47pub fn map_column(
48    col: &ColumnInfo,
49    db_kind: DatabaseKind,
50    schema_info: &SchemaInfo,
51    overrides: &HashMap<String, String>,
52    time_crate: TimeCrate,
53) -> RustType {
54    // Check type overrides first
55    if let Some(override_type) = overrides.get(&col.udt_name) {
56        let rt = RustType::simple(override_type);
57        return if col.is_nullable {
58            rt.wrap_option()
59        } else {
60            rt
61        };
62    }
63
64    let base = match db_kind {
65        DatabaseKind::Postgres => postgres::map_type_qualified(
66            &col.udt_name,
67            col.udt_schema.as_deref(),
68            schema_info,
69            time_crate,
70        ),
71        DatabaseKind::Mysql => mysql::map_type(&col.data_type, &col.udt_name, time_crate),
72        DatabaseKind::Sqlite => sqlite::map_type(&col.udt_name, time_crate),
73    };
74
75    if col.is_nullable {
76        base.wrap_option()
77    } else {
78        base
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::SchemaInfo;
    use std::collections::HashMap;

    /// Builds a minimal column fixture; only the type names and nullability vary.
    fn make_col(udt_name: &str, data_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: String::from("test"),
            data_type: String::from(data_type),
            udt_name: String::from(udt_name),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("public"),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Runs `map_column` for Postgres with an empty schema and chrono time types.
    fn map_pg(col: &ColumnInfo, overrides: &HashMap<String, String>) -> RustType {
        map_column(
            col,
            DatabaseKind::Postgres,
            &SchemaInfo::default(),
            overrides,
            TimeCrate::Chrono,
        )
    }

    /// Override table remapping the `uuid` UDT to a custom `MyUuid` type.
    fn uuid_override() -> HashMap<String, String> {
        let mut table = HashMap::new();
        table.insert("uuid".to_string(), "MyUuid".to_string());
        table
    }

    // --- RustType::simple ---

    #[test]
    fn test_simple_creates_without_import() {
        let ty = RustType::simple("i32");
        assert_eq!(ty.path, "i32");
        assert!(ty.needs_import.is_none());
    }

    #[test]
    fn test_simple_path_correct() {
        assert_eq!(RustType::simple("String").path, "String");
    }

    #[test]
    fn test_simple_no_import() {
        assert_eq!(RustType::simple("bool").needs_import, None);
    }

    // --- RustType::with_import ---

    #[test]
    fn test_with_import_creates_with_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;");
        assert_eq!(ty.path, "Uuid");
        assert_eq!(ty.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    #[test]
    fn test_with_import_path_correct() {
        let ty = RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};");
        assert_eq!(ty.path, "DateTime<Utc>");
    }

    #[test]
    fn test_with_import_import_present() {
        let ty = RustType::with_import("Value", "use serde_json::Value;");
        assert!(ty.needs_import.is_some());
    }

    // --- RustType::wrap_option ---

    #[test]
    fn test_wrap_option_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_option().path, "Option<i32>");
    }

    #[test]
    fn test_wrap_option_preserves_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_option();
        assert_eq!(ty.path, "Option<Uuid>");
        assert_eq!(ty.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    #[test]
    fn test_wrap_option_double_wrap() {
        let ty = RustType::simple("i32").wrap_option().wrap_option();
        assert_eq!(ty.path, "Option<Option<i32>>");
    }

    // --- RustType::wrap_vec ---

    #[test]
    fn test_wrap_vec_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_vec().path, "Vec<i32>");
    }

    #[test]
    fn test_wrap_vec_preserves_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_vec();
        assert_eq!(ty.path, "Vec<Uuid>");
        assert_eq!(ty.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    // --- map_column ---

    #[test]
    fn test_override_takes_precedence() {
        let mapped = map_pg(&make_col("uuid", "uuid", false), &uuid_override());
        assert_eq!(mapped.path, "MyUuid");
        assert!(mapped.needs_import.is_none());
    }

    #[test]
    fn test_override_with_nullable() {
        let mapped = map_pg(&make_col("uuid", "uuid", true), &uuid_override());
        assert_eq!(mapped.path, "Option<MyUuid>");
    }

    #[test]
    fn test_no_override_dispatches_postgres() {
        let mapped = map_pg(&make_col("int4", "integer", false), &HashMap::new());
        assert_eq!(mapped.path, "i32");
    }

    #[test]
    fn test_nullable_without_override() {
        let mapped = map_pg(&make_col("int4", "integer", true), &HashMap::new());
        assert_eq!(mapped.path, "Option<i32>");
    }
}