// sqlx_gen/codegen/domain_gen.rs
use std::collections::{BTreeSet, HashMap};

use heck::ToUpperCamelCase;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};

use crate::cli::DatabaseKind;
use crate::introspect::{DomainInfo, SchemaInfo};
use crate::typemap;
10
11pub fn generate_domain(
12 domain: &DomainInfo,
13 db_kind: DatabaseKind,
14 schema_info: &SchemaInfo,
15 type_overrides: &HashMap<String, String>,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18 let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
19
20 let doc = format!(
21 "Domain: {}.{} (base: {})",
22 domain.schema_name, domain.name, domain.base_type
23 );
24
25 let fake_col = crate::introspect::ColumnInfo {
27 name: String::new(),
28 data_type: domain.base_type.clone(),
29 udt_name: domain.base_type.clone(),
30 is_nullable: false,
31 is_primary_key: false,
32 ordinal_position: 0,
33 schema_name: domain.schema_name.clone(),
34 };
35
36 let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides);
37 if let Some(imp) = &rust_type.needs_import {
38 imports.insert(imp.clone());
39 }
40
41 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
42 let fallback = format_ident!("String");
43 quote! { #fallback }
44 });
45
46 let domain_doc = "sqlx_gen:kind=domain";
47 let tokens = quote! {
48 #[doc = #doc]
49 #[doc = #domain_doc]
50 pub type #alias_name = #type_tokens;
51 };
52
53 (tokens, imports)
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a `DomainInfo` in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Generates and formats `domain` for Postgres with no type overrides.
    ///
    /// Named `generate` rather than `gen` because `gen` is a reserved keyword
    /// in the Rust 2024 edition.
    fn generate(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        gen_with_overrides(domain, &HashMap::new())
    }

    /// Generates and formats `domain` for Postgres using `overrides`.
    fn gen_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_domain(domain, DatabaseKind::Postgres, &schema, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = generate(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = generate(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = generate(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = generate(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = generate(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = generate(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }
}