1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13 composite: &CompositeTypeInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19 let mut imports = BTreeSet::new();
20 for imp in imports_for_derives(extra_derives) {
21 imports.insert(imp);
22 }
23 let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25 let doc = format!(
26 "Composite type: {}.{}",
27 composite.schema_name, composite.name
28 );
29
30 imports.insert("use serde::{Serialize, Deserialize};".to_string());
31 imports.insert("use sqlx_gen::SqlxGen;".to_string());
32 let mut derive_tokens = vec![
33 quote! { Debug },
34 quote! { Clone },
35 quote! { PartialEq },
36 quote! { Eq },
37 quote! { Serialize },
38 quote! { Deserialize },
39 quote! { sqlx::Type },
40 quote! { SqlxGen },
41 ];
42 for d in extra_derives {
43 let ident = format_ident!("{}", d);
44 derive_tokens.push(quote! { #ident });
45 }
46
47 let pg_name = if composite.schema_name != "public" {
49 format!("{}.{}", composite.schema_name, composite.name)
50 } else {
51 composite.name.clone()
52 };
53 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
54
55 let fields: Vec<TokenStream> = composite
56 .fields
57 .iter()
58 .map(|col| {
59 let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
60 if let Some(imp) = &rust_type.needs_import {
61 imports.insert(imp.clone());
62 }
63
64 let field_name_snake = col.name.to_snake_case();
65 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
66 let prefixed = format!(
67 "{}_{}",
68 composite.name.to_snake_case(),
69 field_name_snake
70 );
71 (prefixed, true)
72 } else {
73 let changed = field_name_snake != col.name;
74 (field_name_snake, changed)
75 };
76
77 let field_ident = format_ident!("{}", effective_name);
78 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
79 let fallback = format_ident!("String");
80 quote! { #fallback }
81 });
82
83 let rename = if needs_rename {
84 let original = &col.name;
85 quote! { #[sqlx(rename = #original)] }
86 } else {
87 quote! {}
88 };
89
90 quote! {
91 #rename
92 pub #field_ident: #type_tokens,
93 }
94 })
95 .collect();
96
97 let tokens = quote! {
98 #[doc = #doc]
99 #[derive(#(#derive_tokens),*)]
100 #[sqlx_gen(kind = "composite")]
101 #type_attr
102 pub struct #struct_name {
103 #(#fields)*
104 }
105 };
106
107 (tokens, imports)
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key field whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates Postgres code with no extra derives or overrides, pretty-printed.
    // Not named `gen`: that is a reserved keyword from Rust 2024 onwards.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    /// Generates Postgres code with the given derives/overrides; returns code and imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_non_public_schema_qualified_type_name() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo.point\")"));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        assert!(!code.contains("type_name = \"public.address\""));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        for derive in ["Debug", "Clone", "PartialEq", "Serialize", "Deserialize", "SqlxGen"].iter() {
            assert!(code.contains(derive), "missing default derive {}", derive);
        }
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
        assert!(code.contains("sqlx_gen(kind = \"composite\")"));
    }

    #[test]
    fn test_default_imports() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    #[test]
    fn test_extra_derive() {
        // `Hash` is not a default derive. The old check for `Serialize` passed
        // trivially because `Serialize` is always emitted.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert_eq!(code.matches("Hash").count(), 1);
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}