1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword, rust_type_name_for};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13 composite: &CompositeTypeInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 time_crate: TimeCrate,
19) -> (TokenStream, BTreeSet<String>) {
20 let mut imports = BTreeSet::new();
21 for imp in imports_for_derives(extra_derives) {
22 imports.insert(imp);
23 }
24 let rust_name = rust_type_name_for(schema_info, &composite.schema_name, &composite.name);
25 let struct_name = format_ident!("{}", rust_name);
26 let search_path_doc = if db_kind == DatabaseKind::Postgres
27 && !crate::codegen::is_default_schema(&composite.schema_name)
28 {
29 Some(format!(
30 "Lives in PostgreSQL schema `{schema}`. The sqlx connection must \
31 include `{schema}` in its search_path so PG resolves the \
32 unqualified `type_name = \"{name}\"` to this composite.",
33 schema = composite.schema_name,
34 name = composite.name,
35 ))
36 } else {
37 None
38 };
39
40 let doc = format!(
41 "Composite type: {}.{}",
42 composite.schema_name, composite.name
43 );
44
45 imports.insert("use serde::{Serialize, Deserialize};".to_string());
46 imports.insert("use sqlx_gen::SqlxGen;".to_string());
47 let mut derive_tokens = vec![
48 quote! { Debug },
49 quote! { Clone },
50 quote! { PartialEq },
51 quote! { Eq },
52 quote! { Serialize },
53 quote! { Deserialize },
54 quote! { sqlx::Type },
55 quote! { SqlxGen },
56 ];
57 for d in extra_derives {
58 let ident = format_ident!("{}", d);
59 derive_tokens.push(quote! { #ident });
60 }
61
62 let pg_name = &composite.name;
66 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
67
68 let fields: Vec<TokenStream> = composite
69 .fields
70 .iter()
71 .map(|col| {
72 let rust_type =
73 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
74 if let Some(imp) = &rust_type.needs_import {
75 imports.insert(imp.clone());
76 }
77
78 let field_name_snake = col.name.to_snake_case();
79 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
80 let prefixed = format!("{}_{}", composite.name.to_snake_case(), field_name_snake);
81 (prefixed, true)
82 } else {
83 let changed = field_name_snake != col.name;
84 (field_name_snake, changed)
85 };
86
87 let field_ident = format_ident!("{}", effective_name);
88 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
89 let fallback = format_ident!("String");
90 quote! { #fallback }
91 });
92
93 let rename = if needs_rename {
94 let original = &col.name;
95 quote! { #[sqlx(rename = #original)] }
96 } else {
97 quote! {}
98 };
99
100 quote! {
101 #rename
102 pub #field_ident: #type_tokens,
103 }
104 })
105 .collect();
106
107 let _ = db_kind;
111
112 let search_path_doc_tokens = match &search_path_doc {
113 Some(m) => quote! { #[doc = #m] },
114 None => quote! {},
115 };
116 let tokens = quote! {
117 #[doc = #doc]
118 #search_path_doc_tokens
119 #[derive(#(#derive_tokens),*)]
120 #[sqlx_gen(kind = "composite")]
121 #type_attr
122 pub struct #struct_name {
123 #(#fields)*
124 }
125 };
126
127 (tokens, imports)
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        make_composite_in("public", name, fields)
    }

    /// Builds a composite in an arbitrary schema.
    fn make_composite_in(schema: &str, name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key attribute in `public` whose data type and UDT name are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Renders a composite for PostgreSQL with no extra derives and no type overrides.
    // Named `render` rather than `gen`: `gen` is a reserved keyword in the Rust 2024 edition.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    /// Renders a composite for PostgreSQL and returns the formatted code plus its imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        let c = make_composite("address", vec![make_field("street", "text", false)]);
        let code = render(&c);
        assert!(
            !code.contains("PgHasArrayType"),
            "must not emit a manual PgHasArrayType impl, got:\n{}",
            code
        );
    }

    #[test]
    fn test_non_public_schema_type_name_is_unqualified() {
        let c = make_composite_in("geo", "point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(
            code.contains("sqlx(type_name = \"point\")"),
            "type_name must be unqualified for sqlx 0.8, got:\n{}",
            code
        );
        assert!(!code.contains("\"geo.point\""));
        // The unqualified name only resolves via search_path, which the struct must document.
        assert!(
            code.contains("search_path"),
            "non-default schema must carry a search_path note, got:\n{}",
            code
        );
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        assert!(!code.contains("type_name = \"public.address\""));
        assert!(!code.contains("search_path"));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        // `Hash` is not a built-in derive, so this only passes if extra derives are emitted.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}