Skip to main content

sqlx_gen/typemap/
postgres.rs

1use heck::ToUpperCamelCase;
2
3use super::RustType;
4use crate::introspect::SchemaInfo;
5
6pub fn map_type(udt_name: &str, schema_info: &SchemaInfo) -> RustType {
7    // Handle array types (prefixed with '_' in PG)
8    if let Some(inner) = udt_name.strip_prefix('_') {
9        let inner_type = map_type(inner, schema_info);
10        return inner_type.wrap_vec();
11    }
12
13    // Check if it's a known enum
14    if schema_info.enums.iter().any(|e| e.name == udt_name) {
15        let name = udt_name.to_upper_camel_case();
16        return RustType::with_import(&name, &format!("use super::types::{};", name));
17    }
18
19    // Check if it's a known composite type
20    if schema_info.composite_types.iter().any(|c| c.name == udt_name) {
21        let name = udt_name.to_upper_camel_case();
22        return RustType::with_import(&name, &format!("use super::types::{};", name));
23    }
24
25    // Check if it's a known domain
26    if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) {
27        // Map to the domain's base type
28        return map_type(&domain.base_type, schema_info);
29    }
30
31    match udt_name {
32        "bool" => RustType::simple("bool"),
33        "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
34        "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
35        "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
36        "float4" | "real" => RustType::simple("f32"),
37        "float8" | "double precision" => RustType::simple("f64"),
38        "numeric" | "decimal" => {
39            RustType::with_import("Decimal", "use rust_decimal::Decimal;")
40        }
41        "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
42        "bytea" => RustType::simple("Vec<u8>"),
43        "timestamp" | "timestamp without time zone" => {
44            RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;")
45        }
46        "timestamptz" | "timestamp with time zone" => {
47            RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};")
48        }
49        "date" => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
50        "time" | "time without time zone" => {
51            RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
52        }
53        "timetz" | "time with time zone" => {
54            RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
55        }
56        "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
57        "json" | "jsonb" => {
58            RustType::with_import("Value", "use serde_json::Value;")
59        }
60        "inet" | "cidr" => {
61            RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;")
62        }
63        "oid" => RustType::simple("u32"),
64        _ => RustType::simple("String"), // fallback
65    }
66}
67
#[cfg(test)]
mod tests {
    //! Tests for `map_type`: built-in scalars, `_`-prefixed array names,
    //! user-defined enums/composites/domains, and the `String` fallback for
    //! unrecognised types.

    use super::*;
    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};

    /// Schema with no user-defined types, so only the built-in table applies.
    fn empty_schema() -> SchemaInfo {
        SchemaInfo::default()
    }

    /// Schema holding a single one-variant enum called `name` in `public`.
    fn schema_with_enum(name: &str) -> SchemaInfo {
        let only_enum = EnumInfo {
            schema_name: String::from("public"),
            name: name.to_owned(),
            variants: vec![String::from("a")],
        };
        SchemaInfo {
            enums: vec![only_enum],
            ..Default::default()
        }
    }

    /// Schema holding a single field-less composite type called `name` in `public`.
    fn schema_with_composite(name: &str) -> SchemaInfo {
        let only_composite = CompositeTypeInfo {
            schema_name: String::from("public"),
            name: name.to_owned(),
            fields: Vec::new(),
        };
        SchemaInfo {
            composite_types: vec![only_composite],
            ..Default::default()
        }
    }

    /// Schema holding a single domain `name` in `public` whose base type is `base`.
    fn schema_with_domain(name: &str, base: &str) -> SchemaInfo {
        let only_domain = DomainInfo {
            schema_name: String::from("public"),
            name: name.to_owned(),
            base_type: base.to_owned(),
        };
        SchemaInfo {
            domains: vec![only_domain],
            ..Default::default()
        }
    }

    // Asserts that `$mapped` carries an import line containing `$needle`.
    macro_rules! assert_import_contains {
        ($mapped:expr, $needle:expr) => {
            assert!($mapped.needs_import.as_ref().unwrap().contains($needle))
        };
    }

    // Expands `test_name: "udt" => "path"` entries into one `#[test]` each,
    // checking only the mapped path against an empty schema.
    macro_rules! maps_to_path {
        ($($test:ident: $udt:literal => $path:literal),* $(,)?) => {
            $(
                #[test]
                fn $test() {
                    let mapped = map_type($udt, &empty_schema());
                    assert_eq!(mapped.path, $path);
                }
            )*
        };
    }

    // --- built-in scalars ---

    maps_to_path! {
        test_bool: "bool" => "bool",
        test_int2: "int2" => "i16",
        test_smallint: "smallint" => "i16",
        test_smallserial: "smallserial" => "i16",
        test_int4: "int4" => "i32",
        test_integer: "integer" => "i32",
        test_serial: "serial" => "i32",
        test_int8: "int8" => "i64",
        test_bigint: "bigint" => "i64",
        test_bigserial: "bigserial" => "i64",
        test_float4: "float4" => "f32",
        test_real: "real" => "f32",
        test_float8: "float8" => "f64",
        test_double_precision: "double precision" => "f64",
        test_decimal: "decimal" => "Decimal",
        test_varchar: "varchar" => "String",
        test_text: "text" => "String",
        test_bpchar: "bpchar" => "String",
        test_citext: "citext" => "String",
        test_name: "name" => "String",
        test_bytea: "bytea" => "Vec<u8>",
        test_jsonb: "jsonb" => "Value",
        test_date: "date" => "NaiveDate",
        test_time: "time" => "NaiveTime",
        test_timetz: "timetz" => "NaiveTime",
        test_cidr: "cidr" => "IpNetwork",
        test_oid: "oid" => "u32",
    }

    #[test]
    fn test_numeric() {
        let mapped = map_type("numeric", &empty_schema());
        assert_eq!(mapped.path, "Decimal");
        assert_import_contains!(mapped, "rust_decimal");
    }

    #[test]
    fn test_uuid() {
        let mapped = map_type("uuid", &empty_schema());
        assert_eq!(mapped.path, "Uuid");
        assert_import_contains!(mapped, "uuid::Uuid");
    }

    #[test]
    fn test_json() {
        let mapped = map_type("json", &empty_schema());
        assert_eq!(mapped.path, "Value");
        assert_import_contains!(mapped, "serde_json");
    }

    #[test]
    fn test_timestamp() {
        let mapped = map_type("timestamp", &empty_schema());
        assert_eq!(mapped.path, "NaiveDateTime");
        assert_import_contains!(mapped, "chrono");
    }

    #[test]
    fn test_timestamptz() {
        let mapped = map_type("timestamptz", &empty_schema());
        assert_eq!(mapped.path, "DateTime<Utc>");
        assert_import_contains!(mapped, "chrono");
    }

    #[test]
    fn test_inet() {
        let mapped = map_type("inet", &empty_schema());
        assert_eq!(mapped.path, "IpNetwork");
        assert_import_contains!(mapped, "ipnetwork");
    }

    // --- arrays of built-ins ---

    maps_to_path! {
        test_array_int4: "_int4" => "Vec<i32>",
        test_array_text: "_text" => "Vec<String>",
        test_array_bool: "_bool" => "Vec<bool>",
        test_array_bytea: "_bytea" => "Vec<Vec<u8>>",
    }

    #[test]
    fn test_array_uuid() {
        let mapped = map_type("_uuid", &empty_schema());
        assert_eq!(mapped.path, "Vec<Uuid>");
        assert!(mapped.needs_import.is_some());
    }

    #[test]
    fn test_array_jsonb() {
        let mapped = map_type("_jsonb", &empty_schema());
        assert_eq!(mapped.path, "Vec<Value>");
        assert!(mapped.needs_import.is_some());
    }

    // --- user-defined enums / composites / domains ---

    #[test]
    fn test_enum_status() {
        let mapped = map_type("status", &schema_with_enum("status"));
        assert_eq!(mapped.path, "Status");
        assert_import_contains!(mapped, "super::types::Status");
    }

    #[test]
    fn test_enum_user_role() {
        let mapped = map_type("user_role", &schema_with_enum("user_role"));
        assert_eq!(mapped.path, "UserRole");
    }

    #[test]
    fn test_composite_address() {
        let mapped = map_type("address", &schema_with_composite("address"));
        assert_eq!(mapped.path, "Address");
        assert_import_contains!(mapped, "super::types::Address");
    }

    #[test]
    fn test_composite_geo_point() {
        let mapped = map_type("geo_point", &schema_with_composite("geo_point"));
        assert_eq!(mapped.path, "GeoPoint");
    }

    #[test]
    fn test_domain_text() {
        let mapped = map_type("email", &schema_with_domain("email", "text"));
        assert_eq!(mapped.path, "String");
    }

    #[test]
    fn test_domain_int4() {
        let mapped = map_type("positive_int", &schema_with_domain("positive_int", "int4"));
        assert_eq!(mapped.path, "i32");
    }

    #[test]
    fn test_domain_uuid() {
        let mapped = map_type("my_uuid", &schema_with_domain("my_uuid", "uuid"));
        assert_eq!(mapped.path, "Uuid");
        assert!(mapped.needs_import.is_some());
    }

    // --- arrays of user-defined types ---

    #[test]
    fn test_array_enum() {
        let mapped = map_type("_status", &schema_with_enum("status"));
        assert_eq!(mapped.path, "Vec<Status>");
        assert!(mapped.needs_import.is_some());
    }

    #[test]
    fn test_array_composite() {
        let mapped = map_type("_address", &schema_with_composite("address"));
        assert_eq!(mapped.path, "Vec<Address>");
    }

    // --- fallback for unrecognised types ---

    maps_to_path! {
        test_geometry_fallback: "geometry" => "String",
        test_hstore_fallback: "hstore" => "String",
    }
}