Skip to main content

sqlx_gen/typemap/
mod.rs

1pub mod mysql;
2pub mod postgres;
3pub mod sqlite;
4
5use std::collections::HashMap;
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{ColumnInfo, SchemaInfo};
9
/// A Rust type expression produced by the type mapper, together with the
/// `use` statement (if any) the generated code needs for it to compile.
#[derive(Debug, Clone)]
pub struct RustType {
    /// The type exactly as it is written into generated code,
    /// e.g. `i32`, `Option<Uuid>` or `Vec<String>`.
    pub path: String,
    /// The full import line required by `path`, or `None` when the type
    /// is available without an import.
    pub needs_import: Option<String>,
}

impl RustType {
    /// Creates a type that requires no import.
    pub fn simple(path: &str) -> Self {
        RustType {
            needs_import: None,
            path: String::from(path),
        }
    }

    /// Creates a type whose use requires emitting `import` as well.
    pub fn with_import(path: &str, import: &str) -> Self {
        RustType {
            needs_import: Some(String::from(import)),
            path: String::from(path),
        }
    }

    /// Returns the type wrapped in `Option<…>`; the import (if any) is kept.
    ///
    /// No check is made for an existing `Option`, so calling this twice
    /// yields `Option<Option<T>>`.
    pub fn wrap_option(self) -> Self {
        let RustType { path, needs_import } = self;
        RustType {
            path: format!("Option<{}>", path),
            needs_import,
        }
    }

    /// Returns the type wrapped in `Vec<…>`; the import (if any) is kept.
    pub fn wrap_vec(self) -> Self {
        let RustType { path, needs_import } = self;
        RustType {
            path: format!("Vec<{}>", path),
            needs_import,
        }
    }
}
46
47pub fn map_column(
48    col: &ColumnInfo,
49    db_kind: DatabaseKind,
50    schema_info: &SchemaInfo,
51    overrides: &HashMap<String, String>,
52) -> RustType {
53    // Check type overrides first
54    if let Some(override_type) = overrides.get(&col.udt_name) {
55        let rt = RustType::simple(override_type);
56        return if col.is_nullable { rt.wrap_option() } else { rt };
57    }
58
59    let base = match db_kind {
60        DatabaseKind::Postgres => postgres::map_type(&col.udt_name, schema_info),
61        DatabaseKind::Mysql => mysql::map_type(&col.data_type, &col.udt_name),
62        DatabaseKind::Sqlite => sqlite::map_type(&col.udt_name),
63    };
64
65    if col.is_nullable {
66        base.wrap_option()
67    } else {
68        base
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::SchemaInfo;
    use std::collections::HashMap;

    /// Builds a minimal `ColumnInfo` for type-mapping tests; only the type
    /// names and nullability vary between cases.
    fn make_col(udt_name: &str, data_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: String::from("test"),
            data_type: String::from(data_type),
            udt_name: String::from(udt_name),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("public"),
            column_default: None,
        }
    }

    /// Runs `map_column` for Postgres against an empty `SchemaInfo`.
    fn map_pg(col: &ColumnInfo, overrides: &HashMap<String, String>) -> RustType {
        map_column(col, DatabaseKind::Postgres, &SchemaInfo::default(), overrides)
    }

    /// Override table mapping the `uuid` udt to a custom `MyUuid` type.
    fn uuid_override() -> HashMap<String, String> {
        let mut table = HashMap::new();
        table.insert(String::from("uuid"), String::from("MyUuid"));
        table
    }

    // --- RustType::simple ---

    #[test]
    fn test_simple_creates_without_import() {
        let ty = RustType::simple("i32");
        assert_eq!(ty.path, "i32");
        assert!(ty.needs_import.is_none());
    }

    #[test]
    fn test_simple_path_correct() {
        assert_eq!(RustType::simple("String").path, "String");
    }

    #[test]
    fn test_simple_no_import() {
        assert_eq!(RustType::simple("bool").needs_import, None);
    }

    // --- RustType::with_import ---

    #[test]
    fn test_with_import_creates_with_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;");
        assert_eq!(ty.path, "Uuid");
        assert_eq!(ty.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    #[test]
    fn test_with_import_path_correct() {
        let ty = RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};");
        assert_eq!(ty.path, "DateTime<Utc>");
    }

    #[test]
    fn test_with_import_import_present() {
        let ty = RustType::with_import("Value", "use serde_json::Value;");
        assert!(ty.needs_import.is_some());
    }

    // --- RustType::wrap_option ---

    #[test]
    fn test_wrap_option_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_option().path, "Option<i32>");
    }

    #[test]
    fn test_wrap_option_preserves_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_option();
        assert_eq!(ty.path, "Option<Uuid>");
        assert_eq!(ty.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    #[test]
    fn test_wrap_option_double_wrap() {
        let ty = RustType::simple("i32").wrap_option().wrap_option();
        assert_eq!(ty.path, "Option<Option<i32>>");
    }

    // --- RustType::wrap_vec ---

    #[test]
    fn test_wrap_vec_wraps_path() {
        assert_eq!(RustType::simple("i32").wrap_vec().path, "Vec<i32>");
    }

    #[test]
    fn test_wrap_vec_preserves_import() {
        let ty = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_vec();
        assert_eq!(ty.path, "Vec<Uuid>");
        assert_eq!(ty.needs_import.as_deref(), Some("use uuid::Uuid;"));
    }

    // --- map_column ---

    #[test]
    fn test_override_takes_precedence() {
        let ty = map_pg(&make_col("uuid", "uuid", false), &uuid_override());
        assert_eq!(ty.path, "MyUuid");
        assert!(ty.needs_import.is_none());
    }

    #[test]
    fn test_override_with_nullable() {
        let ty = map_pg(&make_col("uuid", "uuid", true), &uuid_override());
        assert_eq!(ty.path, "Option<MyUuid>");
    }

    #[test]
    fn test_no_override_dispatches_postgres() {
        let ty = map_pg(&make_col("int4", "integer", false), &HashMap::new());
        assert_eq!(ty.path, "i32");
    }

    #[test]
    fn test_nullable_without_override() {
        let ty = map_pg(&make_col("int4", "integer", true), &HashMap::new());
        assert_eq!(ty.path, "Option<i32>");
    }
}
210