1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword, rust_type_name_for};
10use crate::introspect::{CompositeTypeInfo, SchemaInfo};
11use crate::typemap;
12
13pub fn generate_composite(
14 composite: &CompositeTypeInfo,
15 db_kind: DatabaseKind,
16 schema_info: &SchemaInfo,
17 extra_derives: &[String],
18 type_overrides: &HashMap<String, String>,
19 time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21 let mut imports = BTreeSet::new();
22 for imp in imports_for_derives(extra_derives) {
23 imports.insert(imp);
24 }
25 let rust_name = rust_type_name_for(schema_info, &composite.schema_name, &composite.name);
26 let struct_name = format_ident!("{}", rust_name);
27 let search_path_doc = if db_kind == DatabaseKind::Postgres
28 && !crate::codegen::is_default_schema(&composite.schema_name)
29 {
30 Some(format!(
31 "Lives in PostgreSQL schema `{schema}`. The sqlx connection must \
32 include `{schema}` in its search_path so PG resolves the \
33 unqualified `type_name = \"{name}\"` to this composite.",
34 schema = composite.schema_name,
35 name = composite.name,
36 ))
37 } else {
38 None
39 };
40
41 let doc = format!(
42 "Composite type: {}.{}",
43 composite.schema_name, composite.name
44 );
45
46 imports.insert("use serde::{Serialize, Deserialize};".to_string());
47 imports.insert("use sqlx_gen::SqlxGen;".to_string());
48 let mut derive_tokens = vec![
49 quote! { Debug },
50 quote! { Clone },
51 quote! { PartialEq },
52 quote! { Eq },
53 quote! { Serialize },
54 quote! { Deserialize },
55 quote! { sqlx::Type },
56 quote! { SqlxGen },
57 ];
58 for d in extra_derives {
59 let ident = format_ident!("{}", d);
60 derive_tokens.push(quote! { #ident });
61 }
62
63 let pg_name = &composite.name;
67 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
68
69 let fields: Vec<TokenStream> = composite
70 .fields
71 .iter()
72 .map(|col| {
73 let rust_type =
74 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
75 if let Some(imp) = &rust_type.needs_import {
76 imports.insert(imp.clone());
77 }
78
79 let field_name_snake = col.name.to_snake_case();
80 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
81 let prefix = singularize(&composite.name).to_snake_case();
82 let prefixed = format!("{}_{}", prefix, field_name_snake);
83 (prefixed, true)
84 } else {
85 let changed = field_name_snake != col.name;
86 (field_name_snake, changed)
87 };
88
89 let field_ident = format_ident!("{}", effective_name);
90 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
91 let fallback = format_ident!("String");
92 quote! { #fallback }
93 });
94
95 let rename = if needs_rename {
96 let original = &col.name;
97 quote! { #[sqlx(rename = #original)] }
98 } else {
99 quote! {}
100 };
101
102 quote! {
103 #rename
104 pub #field_ident: #type_tokens,
105 }
106 })
107 .collect();
108
109 let _ = db_kind;
113
114 let search_path_doc_tokens = match &search_path_doc {
115 Some(m) => quote! { #[doc = #m] },
116 None => quote! {},
117 };
118 let tokens = quote! {
119 #[doc = #doc]
120 #search_path_doc_tokens
121 #[derive(#(#derive_tokens),*)]
122 #[sqlx_gen(kind = "composite")]
123 #type_attr
124 pub struct #struct_name {
125 #(#fields)*
126 }
127 };
128
129 (tokens, imports)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite in the `public` schema with the given fields.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Generates and pretty-prints a Postgres composite with default options.
    ///
    /// Named `render` rather than `gen`, which is a reserved keyword in Rust 2024.
    fn render(composite: &CompositeTypeInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens).unwrap()
    }

    /// Like [`render`] but with caller-supplied derives and type overrides.
    ///
    /// Also returns the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_sqlx_gen_kind_attr() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx_gen(kind = \"composite\")"));
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        let c = make_composite("address", vec![make_field("street", "text", false)]);
        let code = render(&c);
        assert!(
            !code.contains("PgHasArrayType"),
            "must not emit a manual PgHasArrayType impl, got:\n{}",
            code
        );
    }

    #[test]
    fn test_non_public_schema_type_name_is_unqualified() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = render(&c);
        assert!(
            code.contains("sqlx(type_name = \"point\")"),
            "type_name must be unqualified for sqlx 0.8, got:\n{}",
            code
        );
        assert!(!code.contains("\"geo.point\""));
        // The unqualified name only resolves via search_path, which the docs must mention.
        assert!(
            code.contains("search_path"),
            "non-default schema must document the search_path requirement, got:\n{}",
            code
        );
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        assert!(!code.contains("type_name = \"public.address\""));
        assert!(!code.contains("search_path"));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        // `Hash` is not among the defaults, so its presence proves extras are emitted.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"), "extra derive missing, got:\n{}", code);
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}
370}