sqlx_gen/codegen/
domain_gen.rs1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12 domain: &DomainInfo,
13 db_kind: DatabaseKind,
14 schema_info: &SchemaInfo,
15 type_overrides: &HashMap<String, String>,
16 time_crate: TimeCrate,
17) -> (TokenStream, BTreeSet<String>) {
18 let mut imports = BTreeSet::new();
19 let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
20
21 let doc = format!(
22 "Domain: {}.{} (base: {})",
23 domain.schema_name, domain.name, domain.base_type
24 );
25
26 let fake_col = crate::introspect::ColumnInfo {
28 name: String::new(),
29 data_type: domain.base_type.clone(),
30 udt_name: domain.base_type.clone(),
31 is_nullable: false,
32 is_primary_key: false,
33 ordinal_position: 0,
34 schema_name: domain.schema_name.clone(),
35 column_default: None,
36 };
37
38 let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate);
39 if let Some(imp) = &rust_type.needs_import {
40 imports.insert(imp.clone());
41 }
42
43 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
44 let fallback = format_ident!("String");
45 quote! { #fallback }
46 });
47
48 let domain_doc = "sqlx_gen:kind=domain";
49 let tokens = quote! {
50 #[doc = #doc]
51 #[doc = #domain_doc]
52 pub type #alias_name = #type_tokens;
53 };
54
55 (tokens, imports)
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a `public`-schema domain with the given name and base type.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Runs `generate_domain` for Postgres with chrono and the given type
    /// overrides, returning the formatted source and the collected imports.
    fn generate_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_domain(
            domain,
            DatabaseKind::Postgres,
            &schema,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens), imports)
    }

    /// Same as [`generate_with_overrides`] with no type overrides.
    // Named `generate` rather than `gen`: `gen` is a reserved keyword in the
    // Rust 2024 edition.
    fn generate(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        generate_with_overrides(domain, &HashMap::new())
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = generate(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_domain_marker_doc() {
        let d = make_domain("email", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("sqlx_gen:kind=domain"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = generate(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = generate(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = generate_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = generate(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = generate(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }
}