Skip to main content

sqlx_gen/typemap/
postgres.rs

1use super::RustType;
2use crate::cli::TimeCrate;
3use crate::introspect::SchemaInfo;
4
5/// Returns true if the udt_name is a known PostgreSQL builtin type
6/// (i.e., not a fallback to String).
7pub fn is_builtin(udt_name: &str) -> bool {
8    matches!(
9        udt_name,
10        "bool"
11            | "int2"
12            | "smallint"
13            | "smallserial"
14            | "int4"
15            | "int"
16            | "integer"
17            | "serial"
18            | "int8"
19            | "bigint"
20            | "bigserial"
21            | "float4"
22            | "real"
23            | "float8"
24            | "double precision"
25            | "numeric"
26            | "decimal"
27            | "varchar"
28            | "text"
29            | "bpchar"
30            | "char"
31            | "name"
32            | "citext"
33            | "bytea"
34            | "timestamp"
35            | "timestamp without time zone"
36            | "timestamptz"
37            | "timestamp with time zone"
38            | "date"
39            | "time"
40            | "time without time zone"
41            | "timetz"
42            | "time with time zone"
43            | "uuid"
44            | "json"
45            | "jsonb"
46            | "inet"
47            | "cidr"
48            | "interval"
49            | "oid"
50    )
51}
52
53pub fn map_type(udt_name: &str, schema_info: &SchemaInfo, time_crate: TimeCrate) -> RustType {
54    map_type_qualified(udt_name, None, schema_info, time_crate)
55}
56
57// Shared with the rest of codegen — see codegen::rust_type_name_for.
58use crate::codegen::rust_type_name_for as rust_type_name_inner;
59
60fn rust_type_name(schema: &str, name: &str, schema_info: &SchemaInfo) -> String {
61    rust_type_name_inner(schema_info, schema, name)
62}
63
64/// Map a PG type name to a Rust type, respecting `udt_schema` when present so
65/// that two schemas declaring the same name (e.g. `auth.role` vs
66/// `billing.role`) resolve to distinct Rust idents.
67pub fn map_type_qualified(
68    udt_name: &str,
69    udt_schema: Option<&str>,
70    schema_info: &SchemaInfo,
71    time_crate: TimeCrate,
72) -> RustType {
73    // Handle array types: PG's information_schema may report them either as
74    // `_int4` (information_schema.columns.udt_name) or `integer[]`
75    // (pg_catalog.format_type). Both should produce Vec<T>.
76    if let Some(inner) = udt_name.strip_prefix('_') {
77        let inner_type = map_type_qualified(inner, udt_schema, schema_info, time_crate);
78        return inner_type.wrap_vec();
79    }
80    if let Some(inner) = udt_name.strip_suffix("[]") {
81        let inner_type = map_type_qualified(inner.trim(), udt_schema, schema_info, time_crate);
82        return inner_type.wrap_vec();
83    }
84
85    // Schema-aware enum lookup. When udt_schema is provided we restrict to
86    // exact (schema, name) matches first; otherwise we fall back to the
87    // first name match so that legacy callers (and synthetic test fixtures)
88    // keep working.
89    let enum_match = schema_info
90        .enums
91        .iter()
92        .find(|e| e.name == udt_name && udt_schema.map(|s| s == e.schema_name).unwrap_or(true));
93    if let Some(e) = enum_match {
94        let name = rust_type_name(&e.schema_name, &e.name, schema_info);
95        return RustType::with_import(&name, &format!("use super::types::{};", name));
96    }
97
98    let composite_match = schema_info
99        .composite_types
100        .iter()
101        .find(|c| c.name == udt_name && udt_schema.map(|s| s == c.schema_name).unwrap_or(true));
102    if let Some(c) = composite_match {
103        let name = rust_type_name(&c.schema_name, &c.name, schema_info);
104        return RustType::with_import(&name, &format!("use super::types::{};", name));
105    }
106
107    let domain_match = schema_info
108        .domains
109        .iter()
110        .find(|d| d.name == udt_name && udt_schema.map(|s| s == d.schema_name).unwrap_or(true));
111    if let Some(domain) = domain_match {
112        // Map to the domain's base type — base type lives in pg_catalog so
113        // schema is irrelevant for the recursive lookup.
114        return map_type_qualified(&domain.base_type, None, schema_info, time_crate);
115    }
116
117    match udt_name {
118        "bool" => RustType::simple("bool"),
119        "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
120        "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
121        "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
122        "float4" | "real" => RustType::simple("f32"),
123        "float8" | "double precision" => RustType::simple("f64"),
124        "numeric" | "decimal" => RustType::with_import("Decimal", "use rust_decimal::Decimal;"),
125        "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
126        "bytea" => RustType::simple("Vec<u8>"),
127        "timestamp" | "timestamp without time zone" => match time_crate {
128            TimeCrate::Chrono => {
129                RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;")
130            }
131            TimeCrate::Time => {
132                RustType::with_import("PrimitiveDateTime", "use time::PrimitiveDateTime;")
133            }
134        },
135        "timestamptz" | "timestamp with time zone" => match time_crate {
136            TimeCrate::Chrono => {
137                RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};")
138            }
139            TimeCrate::Time => RustType::with_import("OffsetDateTime", "use time::OffsetDateTime;"),
140        },
141        "date" => match time_crate {
142            TimeCrate::Chrono => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
143            TimeCrate::Time => RustType::with_import("Date", "use time::Date;"),
144        },
145        "time" | "time without time zone" => match time_crate {
146            TimeCrate::Chrono => RustType::with_import("NaiveTime", "use chrono::NaiveTime;"),
147            TimeCrate::Time => RustType::with_import("Time", "use time::Time;"),
148        },
149        "timetz" | "time with time zone" => match time_crate {
150            TimeCrate::Chrono => RustType::with_import("NaiveTime", "use chrono::NaiveTime;"),
151            TimeCrate::Time => RustType::with_import("Time", "use time::Time;"),
152        },
153        "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
154        "json" | "jsonb" => RustType::with_import("Value", "use serde_json::Value;"),
155        "inet" | "cidr" => RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;"),
156        "interval" => RustType::with_import("PgInterval", "use sqlx::postgres::types::PgInterval;"),
157        "oid" => RustType::simple("u32"),
158        _ => RustType::simple("String"), // fallback
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::TimeCrate;
    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};

    // --- fixtures ---

    /// Schema whose only custom type is the enum `public.<name>`.
    fn enum_schema(name: &str) -> SchemaInfo {
        SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                variants: vec!["a".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        }
    }

    /// Schema whose only custom type is the field-less composite `public.<name>`.
    fn composite_schema(name: &str) -> SchemaInfo {
        SchemaInfo {
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                fields: vec![],
            }],
            ..Default::default()
        }
    }

    /// Schema whose only custom type is the domain `public.<name>` over `base`.
    fn domain_schema(name: &str, base: &str) -> SchemaInfo {
        SchemaInfo {
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                base_type: base.to_string(),
            }],
            ..Default::default()
        }
    }

    // --- helpers ---

    /// Resolve `udt_name` against `schema` using the chrono time crate.
    fn resolve_chrono(udt_name: &str, schema: &SchemaInfo) -> RustType {
        map_type(udt_name, schema, TimeCrate::Chrono)
    }

    /// Assert every `(udt_name, expected_path)` pair resolves as expected (chrono).
    fn assert_paths(schema: &SchemaInfo, cases: &[(&str, &str)]) {
        for &(udt_name, expected) in cases {
            assert_eq!(
                resolve_chrono(udt_name, schema).path,
                expected,
                "udt_name = {:?}",
                udt_name
            );
        }
    }

    /// Assert that `resolved` carries an import line mentioning `needle`.
    fn assert_import_mentions(resolved: &RustType, needle: &str) {
        let import = resolved
            .needs_import
            .as_ref()
            .unwrap_or_else(|| panic!("{:?} has no import", resolved.path));
        assert!(
            import.contains(needle),
            "import {:?} should mention {:?}",
            import,
            needle
        );
    }

    // --- builtins ---

    #[test]
    fn builtin_scalars_map_to_expected_paths() {
        assert_paths(
            &SchemaInfo::default(),
            &[
                ("bool", "bool"),
                ("int2", "i16"),
                ("smallint", "i16"),
                ("smallserial", "i16"),
                ("int4", "i32"),
                ("integer", "i32"),
                ("serial", "i32"),
                ("int8", "i64"),
                ("bigint", "i64"),
                ("bigserial", "i64"),
                ("float4", "f32"),
                ("real", "f32"),
                ("float8", "f64"),
                ("double precision", "f64"),
                ("numeric", "Decimal"),
                ("decimal", "Decimal"),
                ("varchar", "String"),
                ("text", "String"),
                ("bpchar", "String"),
                ("citext", "String"),
                ("name", "String"),
                ("bytea", "Vec<u8>"),
                ("uuid", "Uuid"),
                ("json", "Value"),
                ("jsonb", "Value"),
                ("timestamp", "NaiveDateTime"),
                ("timestamptz", "DateTime<Utc>"),
                ("date", "NaiveDate"),
                ("time", "NaiveTime"),
                ("timetz", "NaiveTime"),
                ("inet", "IpNetwork"),
                ("cidr", "IpNetwork"),
                ("oid", "u32"),
                ("interval", "PgInterval"),
            ],
        );
    }

    #[test]
    fn builtin_imports_name_their_source() {
        let schema = SchemaInfo::default();
        let cases: &[(&str, &str)] = &[
            ("numeric", "rust_decimal"),
            ("uuid", "uuid::Uuid"),
            ("json", "serde_json"),
            ("timestamp", "chrono"),
            ("timestamptz", "chrono"),
            ("inet", "ipnetwork"),
            ("interval", "PgInterval"),
        ];
        for &(udt_name, needle) in cases {
            assert_import_mentions(&resolve_chrono(udt_name, &schema), needle);
        }
    }

    // --- arrays ---

    #[test]
    fn array_spellings_wrap_the_element_in_vec() {
        let schema = SchemaInfo::default();
        assert_paths(
            &schema,
            &[
                ("_int4", "Vec<i32>"),
                ("integer[]", "Vec<i32>"),
                ("text[]", "Vec<String>"),
                ("_text", "Vec<String>"),
                ("_uuid", "Vec<Uuid>"),
                ("_bool", "Vec<bool>"),
                ("_jsonb", "Vec<Value>"),
                ("_bytea", "Vec<Vec<u8>>"),
            ],
        );
        for udt_name in ["_uuid", "_jsonb"] {
            assert!(
                resolve_chrono(udt_name, &schema).needs_import.is_some(),
                "udt_name = {:?}",
                udt_name
            );
        }
    }

    // --- enums/composites/domains ---

    #[test]
    fn enums_resolve_to_generated_types() {
        let status = resolve_chrono("status", &enum_schema("status"));
        assert_eq!(status.path, "Status");
        assert_import_mentions(&status, "super::types::Status");

        assert_eq!(resolve_chrono("user_role", &enum_schema("user_role")).path, "UserRole");
    }

    #[test]
    fn composites_resolve_to_generated_types() {
        let address = resolve_chrono("address", &composite_schema("address"));
        assert_eq!(address.path, "Address");
        assert_import_mentions(&address, "super::types::Address");

        assert_eq!(
            resolve_chrono("geo_point", &composite_schema("geo_point")).path,
            "GeoPoint"
        );
    }

    #[test]
    fn domains_resolve_to_their_base_type() {
        assert_eq!(resolve_chrono("email", &domain_schema("email", "text")).path, "String");
        assert_eq!(
            resolve_chrono("positive_int", &domain_schema("positive_int", "int4")).path,
            "i32"
        );

        let my_uuid = resolve_chrono("my_uuid", &domain_schema("my_uuid", "uuid"));
        assert_eq!(my_uuid.path, "Uuid");
        assert!(my_uuid.needs_import.is_some());
    }

    // --- arrays of custom types ---

    #[test]
    fn arrays_of_custom_types_wrap_the_generated_type() {
        let statuses = resolve_chrono("_status", &enum_schema("status"));
        assert_eq!(statuses.path, "Vec<Status>");
        assert!(statuses.needs_import.is_some());

        assert_eq!(
            resolve_chrono("_address", &composite_schema("address")).path,
            "Vec<Address>"
        );
    }

    // --- fallback ---

    #[test]
    fn unknown_types_fall_back_to_string() {
        assert_paths(
            &SchemaInfo::default(),
            &[("geometry", "String"), ("hstore", "String")],
        );
    }

    // --- time crate ---

    #[test]
    fn time_crate_types_and_imports() {
        let schema = SchemaInfo::default();
        let cases: &[(&str, &str, &str)] = &[
            ("timestamptz", "OffsetDateTime", "time::OffsetDateTime"),
            ("timestamp", "PrimitiveDateTime", "time::PrimitiveDateTime"),
            ("date", "Date", "time::Date"),
            ("time", "Time", "time::Time"),
        ];
        for &(udt_name, expected, needle) in cases {
            let resolved = map_type(udt_name, &schema, TimeCrate::Time);
            assert_eq!(resolved.path, expected, "udt_name = {:?}", udt_name);
            assert_import_mentions(&resolved, needle);
        }
    }
}