Skip to main content

sqlx_gen/typemap/
postgres.rs

1use heck::ToUpperCamelCase;
2
3use super::RustType;
4use crate::cli::TimeCrate;
5use crate::introspect::SchemaInfo;
6
/// Reports whether `udt_name` is a PostgreSQL builtin type name that
/// `map_type` handles with an explicit arm, i.e. one that does not end up in
/// the catch-all `String` fallback.
///
/// The comparison is exact and case-sensitive. It covers both the short
/// catalog spellings (`int4`, `timestamptz`, ...) and the longer spellings
/// (`integer`, `timestamp with time zone`, ...). Array names such as `_int4`,
/// and any user-defined or extension type, are not listed and yield `false`.
///
/// NOTE: this table duplicates the builtin arms of `map_type`; keep the two in sync.
pub fn is_builtin(udt_name: &str) -> bool {
    const BUILTIN_TYPE_NAMES: &[&str] = &[
        // boolean
        "bool",
        // integers
        "int2", "smallint", "smallserial",
        "int4", "int", "integer", "serial",
        "int8", "bigint", "bigserial",
        // floating point / arbitrary precision
        "float4", "real",
        "float8", "double precision",
        "numeric", "decimal",
        // text and binary
        "varchar", "text", "bpchar", "char", "name", "citext",
        "bytea",
        // date / time
        "timestamp", "timestamp without time zone",
        "timestamptz", "timestamp with time zone",
        "date",
        "time", "time without time zone",
        "timetz", "time with time zone",
        // misc
        "uuid",
        "json", "jsonb",
        "inet", "cidr",
        "oid",
    ];

    BUILTIN_TYPE_NAMES.iter().any(|&known| known == udt_name)
}
32
33pub fn map_type(udt_name: &str, schema_info: &SchemaInfo, time_crate: TimeCrate) -> RustType {
34    // Handle array types (prefixed with '_' in PG)
35    if let Some(inner) = udt_name.strip_prefix('_') {
36        let inner_type = map_type(inner, schema_info, time_crate);
37        return inner_type.wrap_vec();
38    }
39
40    // Check if it's a known enum
41    if schema_info.enums.iter().any(|e| e.name == udt_name) {
42        let name = udt_name.to_upper_camel_case();
43        return RustType::with_import(&name, &format!("use super::types::{};", name));
44    }
45
46    // Check if it's a known composite type
47    if schema_info.composite_types.iter().any(|c| c.name == udt_name) {
48        let name = udt_name.to_upper_camel_case();
49        return RustType::with_import(&name, &format!("use super::types::{};", name));
50    }
51
52    // Check if it's a known domain
53    if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) {
54        // Map to the domain's base type
55        return map_type(&domain.base_type, schema_info, time_crate);
56    }
57
58    match udt_name {
59        "bool" => RustType::simple("bool"),
60        "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
61        "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
62        "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
63        "float4" | "real" => RustType::simple("f32"),
64        "float8" | "double precision" => RustType::simple("f64"),
65        "numeric" | "decimal" => {
66            RustType::with_import("Decimal", "use rust_decimal::Decimal;")
67        }
68        "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
69        "bytea" => RustType::simple("Vec<u8>"),
70        "timestamp" | "timestamp without time zone" => match time_crate {
71            TimeCrate::Chrono => RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;"),
72            TimeCrate::Time => RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;"),
73        },
74        "timestamptz" | "timestamp with time zone" => match time_crate {
75            TimeCrate::Chrono => RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};"),
76            TimeCrate::Time => RustType::with_import("OffsetDateTime", "use time::OffsetDateTime;"),
77        },
78        "date" => match time_crate {
79            TimeCrate::Chrono => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
80            TimeCrate::Time => RustType::with_import("Date", "use time::Date;"),
81        },
82        "time" | "time without time zone" => match time_crate {
83            TimeCrate::Chrono => RustType::with_import("NaiveTime", "use chrono::NaiveTime;"),
84            TimeCrate::Time => RustType::with_import("Time", "use time::Time;"),
85        },
86        "timetz" | "time with time zone" => match time_crate {
87            TimeCrate::Chrono => RustType::with_import("NaiveTime", "use chrono::NaiveTime;"),
88            TimeCrate::Time => RustType::with_import("Time", "use time::Time;"),
89        },
90        "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
91        "json" | "jsonb" => {
92            RustType::with_import("Value", "use serde_json::Value;")
93        }
94        "inet" | "cidr" => {
95            RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;")
96        }
97        "oid" => RustType::simple("u32"),
98        _ => RustType::simple("String"), // fallback
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::TimeCrate;
    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};

    fn empty_schema() -> SchemaInfo {
        SchemaInfo::default()
    }

    fn schema_with_enum(name: &str) -> SchemaInfo {
        SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                variants: vec!["a".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        }
    }

    fn schema_with_composite(name: &str) -> SchemaInfo {
        SchemaInfo {
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                fields: vec![],
            }],
            ..Default::default()
        }
    }

    fn schema_with_domain(name: &str, base: &str) -> SchemaInfo {
        SchemaInfo {
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                base_type: base.to_string(),
            }],
            ..Default::default()
        }
    }

    /// Maps `udt` against `schema`, asserts the resulting Rust path, and
    /// returns the mapped type for further checks.
    fn assert_path(udt: &str, schema: &SchemaInfo, time_crate: TimeCrate, expected: &str) -> RustType {
        let rt = map_type(udt, schema, time_crate);
        assert_eq!(rt.path, expected, "unexpected Rust type for `{}`", udt);
        rt
    }

    /// Asserts that `rt` carries an import mentioning `needle`.
    fn assert_import(rt: &RustType, needle: &str) {
        let import = rt.needs_import.as_ref().expect("expected an import");
        assert!(import.contains(needle), "import {:?} should mention `{}`", import, needle);
    }

    /// Builtin names and their expected paths under `TimeCrate::Chrono`.
    const CHRONO_BUILTINS: &[(&str, &str)] = &[
        ("bool", "bool"),
        ("int2", "i16"),
        ("smallint", "i16"),
        ("smallserial", "i16"),
        ("int4", "i32"),
        ("int", "i32"),
        ("integer", "i32"),
        ("serial", "i32"),
        ("int8", "i64"),
        ("bigint", "i64"),
        ("bigserial", "i64"),
        ("float4", "f32"),
        ("real", "f32"),
        ("float8", "f64"),
        ("double precision", "f64"),
        ("numeric", "Decimal"),
        ("decimal", "Decimal"),
        ("varchar", "String"),
        ("text", "String"),
        ("bpchar", "String"),
        ("char", "String"),
        ("citext", "String"),
        ("name", "String"),
        ("bytea", "Vec<u8>"),
        ("uuid", "Uuid"),
        ("json", "Value"),
        ("jsonb", "Value"),
        ("timestamp", "NaiveDateTime"),
        ("timestamp without time zone", "NaiveDateTime"),
        ("timestamptz", "DateTime<Utc>"),
        ("timestamp with time zone", "DateTime<Utc>"),
        ("date", "NaiveDate"),
        ("time", "NaiveTime"),
        ("time without time zone", "NaiveTime"),
        ("timetz", "NaiveTime"),
        ("time with time zone", "NaiveTime"),
        ("inet", "IpNetwork"),
        ("cidr", "IpNetwork"),
        ("oid", "u32"),
    ];

    // --- builtins ---

    #[test]
    fn builtins_chrono() {
        let schema = empty_schema();
        for &(udt, expected) in CHRONO_BUILTINS {
            assert_path(udt, &schema, TimeCrate::Chrono, expected);
            // `is_builtin` keeps its own copy of this list; make sure it agrees.
            assert!(is_builtin(udt), "`{}` should be reported as builtin", udt);
        }
    }

    #[test]
    fn builtin_imports_chrono() {
        let schema = empty_schema();
        for &(udt, needle) in &[
            ("numeric", "rust_decimal"),
            ("uuid", "uuid::Uuid"),
            ("json", "serde_json"),
            ("timestamp", "chrono"),
            ("timestamptz", "chrono"),
            ("inet", "ipnetwork"),
        ] {
            let rt = map_type(udt, &schema, TimeCrate::Chrono);
            assert_import(&rt, needle);
        }
    }

    // --- time crate ---

    #[test]
    fn builtins_time_crate() {
        let schema = empty_schema();
        for &(udt, expected, needle) in &[
            ("timestamptz", "OffsetDateTime", "time::OffsetDateTime"),
            ("timestamp", "PrimitiveDateTime", "time::PrimitiveDateTime"),
            ("date", "Date", "time::Date"),
            ("time", "Time", "time::Time"),
        ] {
            let rt = assert_path(udt, &schema, TimeCrate::Time, expected);
            assert_import(&rt, needle);
        }
    }

    // --- arrays ---

    #[test]
    fn arrays_of_builtins() {
        let schema = empty_schema();
        for &(udt, expected) in &[
            ("_int4", "Vec<i32>"),
            ("_text", "Vec<String>"),
            ("_uuid", "Vec<Uuid>"),
            ("_bool", "Vec<bool>"),
            ("_jsonb", "Vec<Value>"),
            ("_bytea", "Vec<Vec<u8>>"),
        ] {
            assert_path(udt, &schema, TimeCrate::Chrono, expected);
            assert!(!is_builtin(udt), "array `{}` is not a builtin name", udt);
        }
        for &udt in &["_uuid", "_jsonb"] {
            let rt = map_type(udt, &schema, TimeCrate::Chrono);
            assert!(rt.needs_import.is_some(), "`{}` should carry an import", udt);
        }
    }

    // --- enums/composites/domains ---

    #[test]
    fn enums() {
        let rt = assert_path("status", &schema_with_enum("status"), TimeCrate::Chrono, "Status");
        assert_import(&rt, "super::types::Status");
        assert_path("user_role", &schema_with_enum("user_role"), TimeCrate::Chrono, "UserRole");
    }

    #[test]
    fn composites() {
        let rt = assert_path("address", &schema_with_composite("address"), TimeCrate::Chrono, "Address");
        assert_import(&rt, "super::types::Address");
        assert_path("geo_point", &schema_with_composite("geo_point"), TimeCrate::Chrono, "GeoPoint");
    }

    #[test]
    fn domains() {
        assert_path("email", &schema_with_domain("email", "text"), TimeCrate::Chrono, "String");
        assert_path("positive_int", &schema_with_domain("positive_int", "int4"), TimeCrate::Chrono, "i32");
        let rt = assert_path("my_uuid", &schema_with_domain("my_uuid", "uuid"), TimeCrate::Chrono, "Uuid");
        assert!(rt.needs_import.is_some());
    }

    // --- arrays of custom types ---

    #[test]
    fn arrays_of_custom_types() {
        let rt = assert_path("_status", &schema_with_enum("status"), TimeCrate::Chrono, "Vec<Status>");
        assert!(rt.needs_import.is_some());
        assert_path("_address", &schema_with_composite("address"), TimeCrate::Chrono, "Vec<Address>");
    }

    // --- fallback ---

    #[test]
    fn unknown_types_fall_back_to_string() {
        let schema = empty_schema();
        for &udt in &["geometry", "hstore"] {
            assert_path(udt, &schema, TimeCrate::Chrono, "String");
            assert!(!is_builtin(udt), "`{}` should not be reported as builtin", udt);
        }
    }
}