1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13 composite: &CompositeTypeInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 time_crate: TimeCrate,
19) -> (TokenStream, BTreeSet<String>) {
20 let mut imports = BTreeSet::new();
21 for imp in imports_for_derives(extra_derives) {
22 imports.insert(imp);
23 }
24 let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
25
26 let doc = format!(
27 "Composite type: {}.{}",
28 composite.schema_name, composite.name
29 );
30
31 imports.insert("use serde::{Serialize, Deserialize};".to_string());
32 imports.insert("use sqlx_gen::SqlxGen;".to_string());
33 let mut derive_tokens = vec![
34 quote! { Debug },
35 quote! { Clone },
36 quote! { PartialEq },
37 quote! { Eq },
38 quote! { Serialize },
39 quote! { Deserialize },
40 quote! { sqlx::Type },
41 quote! { SqlxGen },
42 ];
43 for d in extra_derives {
44 let ident = format_ident!("{}", d);
45 derive_tokens.push(quote! { #ident });
46 }
47
48 let pg_name = &composite.name;
49 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
50
51 let fields: Vec<TokenStream> = composite
52 .fields
53 .iter()
54 .map(|col| {
55 let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
56 if let Some(imp) = &rust_type.needs_import {
57 imports.insert(imp.clone());
58 }
59
60 let field_name_snake = col.name.to_snake_case();
61 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
62 let prefixed = format!(
63 "{}_{}",
64 composite.name.to_snake_case(),
65 field_name_snake
66 );
67 (prefixed, true)
68 } else {
69 let changed = field_name_snake != col.name;
70 (field_name_snake, changed)
71 };
72
73 let field_ident = format_ident!("{}", effective_name);
74 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
75 let fallback = format_ident!("String");
76 quote! { #fallback }
77 });
78
79 let rename = if needs_rename {
80 let original = &col.name;
81 quote! { #[sqlx(rename = #original)] }
82 } else {
83 quote! {}
84 };
85
86 quote! {
87 #rename
88 pub #field_ident: #type_tokens,
89 }
90 })
91 .collect();
92
93 let tokens = quote! {
94 #[doc = #doc]
95 #[derive(#(#derive_tokens),*)]
96 #[sqlx_gen(kind = "composite")]
97 #type_attr
98 pub struct #struct_name {
99 #(#fields)*
100 }
101 };
102
103 (tokens, imports)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type named `name` in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a field whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Renders `composite` for Postgres with the given derives and overrides,
    /// returning the formatted code and the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens), imports)
    }

    /// Renders `composite` with no extra derives and no type overrides.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// in the Rust 2024 edition.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    // Despite its name, this test pins the current behaviour: the `type_name`
    // attribute stays unqualified even for a composite outside `public`.
    #[test]
    fn test_non_public_schema_qualified_type_name() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"point\")"));
        assert!(!code.contains("type_name = \"geo.point\""));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        assert!(!code.contains("type_name = \"public.address\""));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        // `Hash` is not one of the default derives, so it can only show up in
        // the output if `extra_derives` is honoured.
        let derives = vec!["Hash".to_string()];
        let (without_extra, _) = render_with(&c, &[], &HashMap::new());
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(!without_extra.contains("Hash"));
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}