Skip to main content

sqlx_gen/typemap/
postgres.rs

1use heck::ToUpperCamelCase;
2
3use super::RustType;
4use crate::introspect::SchemaInfo;
5
6/// Returns true if the udt_name is a known PostgreSQL builtin type
7/// (i.e., not a fallback to String).
8pub fn is_builtin(udt_name: &str) -> bool {
9    matches!(
10        udt_name,
11        "bool"
12            | "int2" | "smallint" | "smallserial"
13            | "int4" | "int" | "integer" | "serial"
14            | "int8" | "bigint" | "bigserial"
15            | "float4" | "real"
16            | "float8" | "double precision"
17            | "numeric" | "decimal"
18            | "varchar" | "text" | "bpchar" | "char" | "name" | "citext"
19            | "bytea"
20            | "timestamp" | "timestamp without time zone"
21            | "timestamptz" | "timestamp with time zone"
22            | "date"
23            | "time" | "time without time zone"
24            | "timetz" | "time with time zone"
25            | "uuid"
26            | "json" | "jsonb"
27            | "inet" | "cidr"
28            | "oid"
29    )
30}
31
32pub fn map_type(udt_name: &str, schema_info: &SchemaInfo) -> RustType {
33    // Handle array types (prefixed with '_' in PG)
34    if let Some(inner) = udt_name.strip_prefix('_') {
35        let inner_type = map_type(inner, schema_info);
36        return inner_type.wrap_vec();
37    }
38
39    // Check if it's a known enum
40    if schema_info.enums.iter().any(|e| e.name == udt_name) {
41        let name = udt_name.to_upper_camel_case();
42        return RustType::with_import(&name, &format!("use super::types::{};", name));
43    }
44
45    // Check if it's a known composite type
46    if schema_info.composite_types.iter().any(|c| c.name == udt_name) {
47        let name = udt_name.to_upper_camel_case();
48        return RustType::with_import(&name, &format!("use super::types::{};", name));
49    }
50
51    // Check if it's a known domain
52    if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) {
53        // Map to the domain's base type
54        return map_type(&domain.base_type, schema_info);
55    }
56
57    match udt_name {
58        "bool" => RustType::simple("bool"),
59        "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
60        "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
61        "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
62        "float4" | "real" => RustType::simple("f32"),
63        "float8" | "double precision" => RustType::simple("f64"),
64        "numeric" | "decimal" => {
65            RustType::with_import("Decimal", "use rust_decimal::Decimal;")
66        }
67        "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
68        "bytea" => RustType::simple("Vec<u8>"),
69        "timestamp" | "timestamp without time zone" => {
70            RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;")
71        }
72        "timestamptz" | "timestamp with time zone" => {
73            RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};")
74        }
75        "date" => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
76        "time" | "time without time zone" => {
77            RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
78        }
79        "timetz" | "time with time zone" => {
80            RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
81        }
82        "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
83        "json" | "jsonb" => {
84            RustType::with_import("Value", "use serde_json::Value;")
85        }
86        "inet" | "cidr" => {
87            RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;")
88        }
89        "oid" => RustType::simple("u32"),
90        _ => RustType::simple("String"), // fallback
91    }
92}
93
94#[cfg(test)]
95mod tests {
96    use super::*;
97    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};
98
99    fn empty_schema() -> SchemaInfo {
100        SchemaInfo::default()
101    }
102
103    fn schema_with_enum(name: &str) -> SchemaInfo {
104        SchemaInfo {
105            enums: vec![EnumInfo {
106                schema_name: "public".to_string(),
107                name: name.to_string(),
108                variants: vec!["a".to_string()],
109                default_variant: None,
110            }],
111            ..Default::default()
112        }
113    }
114
115    fn schema_with_composite(name: &str) -> SchemaInfo {
116        SchemaInfo {
117            composite_types: vec![CompositeTypeInfo {
118                schema_name: "public".to_string(),
119                name: name.to_string(),
120                fields: vec![],
121            }],
122            ..Default::default()
123        }
124    }
125
126    fn schema_with_domain(name: &str, base: &str) -> SchemaInfo {
127        SchemaInfo {
128            domains: vec![DomainInfo {
129                schema_name: "public".to_string(),
130                name: name.to_string(),
131                base_type: base.to_string(),
132            }],
133            ..Default::default()
134        }
135    }
136
137    // --- builtins ---
138
139    #[test]
140    fn test_bool() {
141        assert_eq!(map_type("bool", &empty_schema()).path, "bool");
142    }
143
144    #[test]
145    fn test_int2() {
146        assert_eq!(map_type("int2", &empty_schema()).path, "i16");
147    }
148
149    #[test]
150    fn test_smallint() {
151        assert_eq!(map_type("smallint", &empty_schema()).path, "i16");
152    }
153
154    #[test]
155    fn test_smallserial() {
156        assert_eq!(map_type("smallserial", &empty_schema()).path, "i16");
157    }
158
159    #[test]
160    fn test_int4() {
161        assert_eq!(map_type("int4", &empty_schema()).path, "i32");
162    }
163
164    #[test]
165    fn test_integer() {
166        assert_eq!(map_type("integer", &empty_schema()).path, "i32");
167    }
168
169    #[test]
170    fn test_serial() {
171        assert_eq!(map_type("serial", &empty_schema()).path, "i32");
172    }
173
174    #[test]
175    fn test_int8() {
176        assert_eq!(map_type("int8", &empty_schema()).path, "i64");
177    }
178
179    #[test]
180    fn test_bigint() {
181        assert_eq!(map_type("bigint", &empty_schema()).path, "i64");
182    }
183
184    #[test]
185    fn test_bigserial() {
186        assert_eq!(map_type("bigserial", &empty_schema()).path, "i64");
187    }
188
189    #[test]
190    fn test_float4() {
191        assert_eq!(map_type("float4", &empty_schema()).path, "f32");
192    }
193
194    #[test]
195    fn test_real() {
196        assert_eq!(map_type("real", &empty_schema()).path, "f32");
197    }
198
199    #[test]
200    fn test_float8() {
201        assert_eq!(map_type("float8", &empty_schema()).path, "f64");
202    }
203
204    #[test]
205    fn test_double_precision() {
206        assert_eq!(map_type("double precision", &empty_schema()).path, "f64");
207    }
208
209    #[test]
210    fn test_numeric() {
211        let rt = map_type("numeric", &empty_schema());
212        assert_eq!(rt.path, "Decimal");
213        assert!(rt.needs_import.as_ref().unwrap().contains("rust_decimal"));
214    }
215
216    #[test]
217    fn test_decimal() {
218        let rt = map_type("decimal", &empty_schema());
219        assert_eq!(rt.path, "Decimal");
220    }
221
222    #[test]
223    fn test_varchar() {
224        assert_eq!(map_type("varchar", &empty_schema()).path, "String");
225    }
226
227    #[test]
228    fn test_text() {
229        assert_eq!(map_type("text", &empty_schema()).path, "String");
230    }
231
232    #[test]
233    fn test_bpchar() {
234        assert_eq!(map_type("bpchar", &empty_schema()).path, "String");
235    }
236
237    #[test]
238    fn test_citext() {
239        assert_eq!(map_type("citext", &empty_schema()).path, "String");
240    }
241
242    #[test]
243    fn test_name() {
244        assert_eq!(map_type("name", &empty_schema()).path, "String");
245    }
246
247    #[test]
248    fn test_bytea() {
249        assert_eq!(map_type("bytea", &empty_schema()).path, "Vec<u8>");
250    }
251
252    #[test]
253    fn test_uuid() {
254        let rt = map_type("uuid", &empty_schema());
255        assert_eq!(rt.path, "Uuid");
256        assert!(rt.needs_import.as_ref().unwrap().contains("uuid::Uuid"));
257    }
258
259    #[test]
260    fn test_json() {
261        let rt = map_type("json", &empty_schema());
262        assert_eq!(rt.path, "Value");
263        assert!(rt.needs_import.as_ref().unwrap().contains("serde_json"));
264    }
265
266    #[test]
267    fn test_jsonb() {
268        let rt = map_type("jsonb", &empty_schema());
269        assert_eq!(rt.path, "Value");
270    }
271
272    #[test]
273    fn test_timestamp() {
274        let rt = map_type("timestamp", &empty_schema());
275        assert_eq!(rt.path, "NaiveDateTime");
276        assert!(rt.needs_import.as_ref().unwrap().contains("chrono"));
277    }
278
279    #[test]
280    fn test_timestamptz() {
281        let rt = map_type("timestamptz", &empty_schema());
282        assert_eq!(rt.path, "DateTime<Utc>");
283        assert!(rt.needs_import.as_ref().unwrap().contains("chrono"));
284    }
285
286    #[test]
287    fn test_date() {
288        let rt = map_type("date", &empty_schema());
289        assert_eq!(rt.path, "NaiveDate");
290    }
291
292    #[test]
293    fn test_time() {
294        let rt = map_type("time", &empty_schema());
295        assert_eq!(rt.path, "NaiveTime");
296    }
297
298    #[test]
299    fn test_timetz() {
300        let rt = map_type("timetz", &empty_schema());
301        assert_eq!(rt.path, "NaiveTime");
302    }
303
304    #[test]
305    fn test_inet() {
306        let rt = map_type("inet", &empty_schema());
307        assert_eq!(rt.path, "IpNetwork");
308        assert!(rt.needs_import.as_ref().unwrap().contains("ipnetwork"));
309    }
310
311    #[test]
312    fn test_cidr() {
313        let rt = map_type("cidr", &empty_schema());
314        assert_eq!(rt.path, "IpNetwork");
315    }
316
317    #[test]
318    fn test_oid() {
319        assert_eq!(map_type("oid", &empty_schema()).path, "u32");
320    }
321
322    // --- arrays ---
323
324    #[test]
325    fn test_array_int4() {
326        assert_eq!(map_type("_int4", &empty_schema()).path, "Vec<i32>");
327    }
328
329    #[test]
330    fn test_array_text() {
331        assert_eq!(map_type("_text", &empty_schema()).path, "Vec<String>");
332    }
333
334    #[test]
335    fn test_array_uuid() {
336        let rt = map_type("_uuid", &empty_schema());
337        assert_eq!(rt.path, "Vec<Uuid>");
338        assert!(rt.needs_import.is_some());
339    }
340
341    #[test]
342    fn test_array_bool() {
343        assert_eq!(map_type("_bool", &empty_schema()).path, "Vec<bool>");
344    }
345
346    #[test]
347    fn test_array_jsonb() {
348        let rt = map_type("_jsonb", &empty_schema());
349        assert_eq!(rt.path, "Vec<Value>");
350        assert!(rt.needs_import.is_some());
351    }
352
353    #[test]
354    fn test_array_bytea() {
355        assert_eq!(map_type("_bytea", &empty_schema()).path, "Vec<Vec<u8>>");
356    }
357
358    // --- enums/composites/domains ---
359
360    #[test]
361    fn test_enum_status() {
362        let schema = schema_with_enum("status");
363        let rt = map_type("status", &schema);
364        assert_eq!(rt.path, "Status");
365        assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Status"));
366    }
367
368    #[test]
369    fn test_enum_user_role() {
370        let schema = schema_with_enum("user_role");
371        let rt = map_type("user_role", &schema);
372        assert_eq!(rt.path, "UserRole");
373    }
374
375    #[test]
376    fn test_composite_address() {
377        let schema = schema_with_composite("address");
378        let rt = map_type("address", &schema);
379        assert_eq!(rt.path, "Address");
380        assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Address"));
381    }
382
383    #[test]
384    fn test_composite_geo_point() {
385        let schema = schema_with_composite("geo_point");
386        let rt = map_type("geo_point", &schema);
387        assert_eq!(rt.path, "GeoPoint");
388    }
389
390    #[test]
391    fn test_domain_text() {
392        let schema = schema_with_domain("email", "text");
393        let rt = map_type("email", &schema);
394        assert_eq!(rt.path, "String");
395    }
396
397    #[test]
398    fn test_domain_int4() {
399        let schema = schema_with_domain("positive_int", "int4");
400        let rt = map_type("positive_int", &schema);
401        assert_eq!(rt.path, "i32");
402    }
403
404    #[test]
405    fn test_domain_uuid() {
406        let schema = schema_with_domain("my_uuid", "uuid");
407        let rt = map_type("my_uuid", &schema);
408        assert_eq!(rt.path, "Uuid");
409        assert!(rt.needs_import.is_some());
410    }
411
412    // --- arrays of custom types ---
413
414    #[test]
415    fn test_array_enum() {
416        let schema = schema_with_enum("status");
417        let rt = map_type("_status", &schema);
418        assert_eq!(rt.path, "Vec<Status>");
419        assert!(rt.needs_import.is_some());
420    }
421
422    #[test]
423    fn test_array_composite() {
424        let schema = schema_with_composite("address");
425        let rt = map_type("_address", &schema);
426        assert_eq!(rt.path, "Vec<Address>");
427    }
428
429    // --- fallback ---
430
431    #[test]
432    fn test_geometry_fallback() {
433        assert_eq!(map_type("geometry", &empty_schema()).path, "String");
434    }
435
436    #[test]
437    fn test_hstore_fallback() {
438        assert_eq!(map_type("hstore", &empty_schema()).path, "String");
439    }
440}