1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13 composite: &CompositeTypeInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19 let mut imports = BTreeSet::new();
20 for imp in imports_for_derives(extra_derives) {
21 imports.insert(imp);
22 }
23 let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25 let doc = format!(
26 "Composite type: {}.{}",
27 composite.schema_name, composite.name
28 );
29
30 imports.insert("use serde::{Serialize, Deserialize};".to_string());
31 imports.insert("use sqlx_gen::SqlxGen;".to_string());
32 let mut derive_tokens = vec![
33 quote! { Debug },
34 quote! { Clone },
35 quote! { PartialEq },
36 quote! { Eq },
37 quote! { Serialize },
38 quote! { Deserialize },
39 quote! { sqlx::Type },
40 quote! { SqlxGen },
41 ];
42 for d in extra_derives {
43 let ident = format_ident!("{}", d);
44 derive_tokens.push(quote! { #ident });
45 }
46
47 let pg_name = &composite.name;
48 let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
49
50 let fields: Vec<TokenStream> = composite
51 .fields
52 .iter()
53 .map(|col| {
54 let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
55 if let Some(imp) = &rust_type.needs_import {
56 imports.insert(imp.clone());
57 }
58
59 let field_name_snake = col.name.to_snake_case();
60 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
61 let prefixed = format!(
62 "{}_{}",
63 composite.name.to_snake_case(),
64 field_name_snake
65 );
66 (prefixed, true)
67 } else {
68 let changed = field_name_snake != col.name;
69 (field_name_snake, changed)
70 };
71
72 let field_ident = format_ident!("{}", effective_name);
73 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
74 let fallback = format_ident!("String");
75 quote! { #fallback }
76 });
77
78 let rename = if needs_rename {
79 let original = &col.name;
80 quote! { #[sqlx(rename = #original)] }
81 } else {
82 quote! {}
83 };
84
85 quote! {
86 #rename
87 pub #field_ident: #type_tokens,
88 }
89 })
90 .collect();
91
92 let tokens = quote! {
93 #[doc = #doc]
94 #[derive(#(#derive_tokens),*)]
95 #[sqlx_gen(kind = "composite")]
96 #type_attr
97 pub struct #struct_name {
98 #(#fields)*
99 }
100 };
101
102 (tokens, imports)
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key column; `udt_name` is used for both `data_type` and `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates and formats Postgres code with no extra derives or overrides.
    ///
    /// Named `gen_code` rather than `gen`, which is a reserved keyword in Rust 2024.
    fn gen_code(composite: &CompositeTypeInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, &[], &HashMap::new());
        parse_and_format(&tokens)
    }

    /// Like `gen_code`, but with explicit derives/overrides; also returns the imports.
    fn gen_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = gen_code(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    /// The sqlx `type_name` is the bare composite name even outside `public`.
    #[test]
    fn test_non_public_schema_type_name_unqualified() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"point\")"));
        assert!(!code.contains("type_name = \"geo.point\""));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        assert!(!code.contains("type_name = \"public.address\""));
    }

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = gen_code(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    /// Keyword prefixes use the snake_case form of the composite name.
    #[test]
    fn test_keyword_field_prefix_uses_snake_composite_name() {
        let c = make_composite("LineItem", vec![make_field("type", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub line_item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = gen_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub label: String"));
    }

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
        assert!(code.contains("sqlx_gen(kind = \"composite\")"));
    }

    #[test]
    fn test_default_imports() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let (_, imports) = gen_with(&c, &[], &HashMap::new());
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    /// Uses a derive that is not in the default set, so the assertion is meaningful.
    #[test]
    fn test_extra_derive() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}