sqlx_gen/codegen/
domain_gen.rs1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12 domain: &DomainInfo,
13 db_kind: DatabaseKind,
14 schema_info: &SchemaInfo,
15 type_overrides: &HashMap<String, String>,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18 let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
19
20 let doc = format!(
21 "Domain: {}.{} (base: {})",
22 domain.schema_name, domain.name, domain.base_type
23 );
24
25 let fake_col = crate::introspect::ColumnInfo {
27 name: String::new(),
28 data_type: domain.base_type.clone(),
29 udt_name: domain.base_type.clone(),
30 is_nullable: false,
31 ordinal_position: 0,
32 schema_name: domain.schema_name.clone(),
33 };
34
35 let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides);
36 if let Some(imp) = &rust_type.needs_import {
37 imports.insert(imp.clone());
38 }
39
40 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
41 let fallback = format_ident!("String");
42 quote! { #fallback }
43 });
44
45 let tokens = quote! {
46 #[doc = #doc]
47 pub type #alias_name = #type_tokens;
48 };
49
50 (tokens, imports)
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a domain in the `public` schema with the given name and base type.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".into(),
            name: name.into(),
            base_type: base.into(),
        }
    }

    /// Runs Postgres domain generation with the given type overrides.
    /// Returns the formatted source together with the collected imports.
    fn render_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (stream, needed) =
            generate_domain(domain, DatabaseKind::Postgres, &schema, overrides);
        (parse_and_format(&stream), needed)
    }

    /// Runs Postgres domain generation with no type overrides.
    fn render(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_with_overrides(domain, &HashMap::new())
    }

    #[test]
    fn test_domain_text() {
        let (src, _) = render(&make_domain("email", "text"));
        assert!(src.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let (src, _) = render(&make_domain("positive_int", "int4"));
        assert!(src.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let (src, needed) = render(&make_domain("my_uuid", "uuid"));
        assert!(src.contains("pub type MyUuid = Uuid"));
        assert!(needed.iter().any(|path| path.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let (src, _) = render(&make_domain("email", "text"));
        assert!(src.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let (_, needed) = render(&make_domain("my_uuid", "uuid"));
        assert!(!needed.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let (_, needed) = render(&make_domain("email", "text"));
        assert!(needed.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let (src, _) = render(&make_domain("email_address", "text"));
        assert!(src.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let overrides = HashMap::from([("jsonb".to_string(), "MyJson".to_string())]);
        let (src, _) = render_with_overrides(&make_domain("json_data", "jsonb"), &overrides);
        assert!(src.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let (src, needed) = render(&make_domain("data", "jsonb"));
        assert!(src.contains("Value"));
        assert!(needed.iter().any(|path| path.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let (_, needed) = render(&make_domain("created", "timestamptz"));
        assert!(needed.iter().any(|path| path.contains("chrono")));
    }
}