1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13 composite: &CompositeTypeInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19 let mut imports = BTreeSet::new();
20 for imp in imports_for_derives(extra_derives) {
21 imports.insert(imp);
22 }
23 let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25 let doc = format!(
26 "Composite type: {}.{}",
27 composite.schema_name, composite.name
28 );
29
30 let mut derive_tokens = vec![
31 quote! { Debug },
32 quote! { Clone },
33 quote! { sqlx::Type },
34 ];
35 for d in extra_derives {
36 let ident = format_ident!("{}", d);
37 derive_tokens.push(quote! { #ident });
38 }
39
40 let pg_name = &composite.name;
41 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
42
43 let fields: Vec<TokenStream> = composite
44 .fields
45 .iter()
46 .map(|col| {
47 let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
48 if let Some(imp) = &rust_type.needs_import {
49 imports.insert(imp.clone());
50 }
51
52 let field_name_snake = col.name.to_snake_case();
53 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
54 let prefixed = format!(
55 "{}_{}",
56 composite.name.to_snake_case(),
57 field_name_snake
58 );
59 (prefixed, true)
60 } else {
61 let changed = field_name_snake != col.name;
62 (field_name_snake, changed)
63 };
64
65 let field_ident = format_ident!("{}", effective_name);
66 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
67 let fallback = format_ident!("String");
68 quote! { #fallback }
69 });
70
71 let rename = if needs_rename {
72 let original = &col.name;
73 quote! { #[sqlx(rename = #original)] }
74 } else {
75 quote! {}
76 };
77
78 quote! {
79 #rename
80 pub #field_ident: #type_tokens,
81 }
82 })
83 .collect();
84
85 let tokens = quote! {
86 #[doc = #doc]
87 #[derive(#(#derive_tokens),*)]
88 #type_attr
89 pub struct #struct_name {
90 #(#fields)*
91 }
92 };
93
94 (tokens, imports)
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a composite type in the `public` schema with the given fields.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Build a field whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Render `composite` for Postgres with no extra derives or type overrides.
    ///
    /// Named `render` rather than `gen` because `gen` is a reserved keyword in
    /// the Rust 2024 edition.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    /// Render `composite` for Postgres with the given derives and overrides,
    /// returning the formatted code and the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_field_order_preserved() {
        // Composite values are exchanged as positional records, so the struct
        // fields must keep the order in which the fields were supplied.
        let c = make_composite("pair", vec![
            make_field("second", "text", false),
            make_field("first", "text", false),
        ]);
        let code = render(&c);
        let second = code.find("pub second: String").expect("`second` field missing");
        let first = code.find("pub first: String").expect("`first` field missing");
        assert!(second < first);
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}