Skip to main content

sqlx_gen/typemap/
mod.rs

1pub mod mysql;
2pub mod postgres;
3pub mod sqlite;
4
5use std::collections::HashMap;
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::introspect::{ColumnInfo, SchemaInfo};
9
/// Resolved Rust type with its required imports.
///
/// `path` is the type exactly as it should be spelled in generated code
/// (e.g. `"i32"` or `"Option<Uuid>"`). `needs_import`, when present, is the import
/// statement that must be emitted for `path` to resolve (e.g. `"use uuid::Uuid;"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustType {
    /// Type as written in generated source.
    pub path: String,
    /// Import statement the type requires, if any.
    pub needs_import: Option<String>,
}

impl RustType {
    /// Creates a type that needs no import (e.g. a primitive or a fully qualified path).
    #[must_use]
    pub fn simple(path: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: None,
        }
    }

    /// Creates a type whose `path` only resolves once `import` is emitted.
    #[must_use]
    pub fn with_import(path: &str, import: &str) -> Self {
        Self {
            path: path.to_string(),
            needs_import: Some(import.to_string()),
        }
    }

    /// Wraps the type in `Option<…>`, keeping its import requirement.
    #[must_use]
    pub fn wrap_option(self) -> Self {
        self.wrap_in("Option")
    }

    /// Wraps the type in `Vec<…>`, keeping its import requirement.
    #[must_use]
    pub fn wrap_vec(self) -> Self {
        self.wrap_in("Vec")
    }

    /// Shared implementation for the `wrap_*` builders: rewrites `path` as
    /// `container<path>`. Struct-update syntax carries every other field over,
    /// so adding a field to `RustType` cannot silently drop data here.
    fn wrap_in(self, container: &str) -> Self {
        Self {
            path: format!("{container}<{}>", self.path),
            ..self
        }
    }
}
46
47pub fn map_column(
48    col: &ColumnInfo,
49    db_kind: DatabaseKind,
50    schema_info: &SchemaInfo,
51    overrides: &HashMap<String, String>,
52    time_crate: TimeCrate,
53) -> RustType {
54    // Check type overrides first
55    if let Some(override_type) = overrides.get(&col.udt_name) {
56        let rt = RustType::simple(override_type);
57        return if col.is_nullable { rt.wrap_option() } else { rt };
58    }
59
60    let base = match db_kind {
61        DatabaseKind::Postgres => postgres::map_type(&col.udt_name, schema_info, time_crate),
62        DatabaseKind::Mysql => mysql::map_type(&col.data_type, &col.udt_name, time_crate),
63        DatabaseKind::Sqlite => sqlite::map_type(&col.udt_name, time_crate),
64    };
65
66    if col.is_nullable {
67        base.wrap_option()
68    } else {
69        base
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::SchemaInfo;
    use std::collections::HashMap;

    /// Builds a minimal `ColumnInfo` for mapping tests; only the type names and
    /// nullability vary between tests.
    fn make_col(udt_name: &str, data_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: "test".to_string(),
            data_type: data_type.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    // --- RustType::simple ---

    #[test]
    fn test_simple_creates_without_import() {
        let rt = RustType::simple("i32");
        assert_eq!(rt.path, "i32");
        assert!(rt.needs_import.is_none());
    }

    #[test]
    fn test_simple_path_correct() {
        let rt = RustType::simple("String");
        assert_eq!(rt.path, "String");
    }

    #[test]
    fn test_simple_no_import() {
        let rt = RustType::simple("bool");
        assert_eq!(rt.needs_import, None);
    }

    // --- RustType::with_import ---

    #[test]
    fn test_with_import_creates_with_import() {
        let rt = RustType::with_import("Uuid", "use uuid::Uuid;");
        assert_eq!(rt.path, "Uuid");
        assert_eq!(rt.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    #[test]
    fn test_with_import_path_correct() {
        let rt = RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};");
        assert_eq!(rt.path, "DateTime<Utc>");
    }

    #[test]
    fn test_with_import_import_present() {
        let rt = RustType::with_import("Value", "use serde_json::Value;");
        assert!(rt.needs_import.is_some());
    }

    // --- RustType::wrap_option ---

    #[test]
    fn test_wrap_option_wraps_path() {
        let rt = RustType::simple("i32").wrap_option();
        assert_eq!(rt.path, "Option<i32>");
    }

    #[test]
    fn test_wrap_option_preserves_import() {
        let rt = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_option();
        assert_eq!(rt.path, "Option<Uuid>");
        assert_eq!(rt.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    #[test]
    fn test_wrap_option_double_wrap() {
        let rt = RustType::simple("i32").wrap_option().wrap_option();
        assert_eq!(rt.path, "Option<Option<i32>>");
    }

    // --- RustType::wrap_vec ---

    #[test]
    fn test_wrap_vec_wraps_path() {
        let rt = RustType::simple("i32").wrap_vec();
        assert_eq!(rt.path, "Vec<i32>");
    }

    #[test]
    fn test_wrap_vec_preserves_import() {
        let rt = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_vec();
        assert_eq!(rt.path, "Vec<Uuid>");
        assert_eq!(rt.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    #[test]
    fn test_wrap_vec_then_option() {
        let rt = RustType::with_import("Uuid", "use uuid::Uuid;").wrap_vec().wrap_option();
        assert_eq!(rt.path, "Option<Vec<Uuid>>");
        assert_eq!(rt.needs_import, Some("use uuid::Uuid;".to_string()));
    }

    // --- map_column ---

    #[test]
    fn test_override_takes_precedence() {
        let col = make_col("uuid", "uuid", false);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("uuid".to_string(), "MyUuid".to_string());
        let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "MyUuid");
        assert!(rt.needs_import.is_none());
    }

    #[test]
    fn test_override_with_nullable() {
        let col = make_col("uuid", "uuid", true);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("uuid".to_string(), "MyUuid".to_string());
        let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "Option<MyUuid>");
    }

    // Overrides are checked before any backend dispatch, so they must win for
    // every database kind, not only Postgres.
    #[test]
    fn test_override_takes_precedence_mysql() {
        let col = make_col("json", "json", false);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("json".to_string(), "MyJson".to_string());
        let rt = map_column(&col, DatabaseKind::Mysql, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "MyJson");
        assert!(rt.needs_import.is_none());
    }

    #[test]
    fn test_override_takes_precedence_sqlite_nullable() {
        let col = make_col("text", "text", true);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("text".to_string(), "MyText".to_string());
        let rt = map_column(&col, DatabaseKind::Sqlite, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "Option<MyText>");
        assert!(rt.needs_import.is_none());
    }

    // Overrides are looked up by `udt_name`; a matching `data_type` alone must not
    // trigger them.
    #[test]
    fn test_override_keyed_on_udt_name_not_data_type() {
        let col = make_col("int4", "uuid", false);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("uuid".to_string(), "MyUuid".to_string());
        let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "i32");
    }

    #[test]
    fn test_no_override_dispatches_postgres() {
        let col = make_col("int4", "integer", false);
        let schema = SchemaInfo::default();
        let overrides = HashMap::new();
        let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "i32");
    }

    #[test]
    fn test_nullable_without_override() {
        let col = make_col("int4", "integer", true);
        let schema = SchemaInfo::default();
        let overrides = HashMap::new();
        let rt = map_column(&col, DatabaseKind::Postgres, &schema, &overrides, TimeCrate::Chrono);
        assert_eq!(rt.path, "Option<i32>");
    }
}
211