1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19 let mut imports = BTreeSet::new();
20 for imp in imports_for_derives(extra_derives) {
21 imports.insert(imp);
22 }
23 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
24
25 imports.insert("use serde::{Serialize, Deserialize};".to_string());
27 let mut derive_tokens = vec![
28 quote! { Debug },
29 quote! { Clone },
30 quote! { PartialEq },
31 quote! { Eq },
32 quote! { Serialize },
33 quote! { Deserialize },
34 quote! { sqlx::FromRow },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let fields: Vec<TokenStream> = table
43 .columns
44 .iter()
45 .map(|col| {
46 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
47 if let Some(imp) = &rust_type.needs_import {
48 imports.insert(imp.clone());
49 }
50
51 let field_name_snake = col.name.to_snake_case();
52 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
55 let prefixed = format!(
56 "{}_{}",
57 table.name.to_snake_case(),
58 field_name_snake
59 );
60 (prefixed, true)
61 } else {
62 let changed = field_name_snake != col.name;
63 (field_name_snake, changed)
64 };
65
66 let field_ident = format_ident!("{}", effective_name);
67 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
68 let fallback = format_ident!("String");
69 quote! { #fallback }
70 });
71
72 let rename = if needs_rename {
73 let original = &col.name;
74 quote! { #[sqlx(rename = #original)] }
75 } else {
76 quote! {}
77 };
78
79 quote! {
80 #rename
81 pub #field_ident: #type_tokens,
82 }
83 })
84 .collect();
85
86 let tokens = quote! {
87 #[derive(#(#derive_tokens),*)]
88 pub struct #struct_name {
89 #(#fields)*
90 }
91 };
92
93 (tokens, imports)
94}
95
96fn resolve_column_type(
97 col: &crate::introspect::ColumnInfo,
98 db_kind: DatabaseKind,
99 table: &TableInfo,
100 schema_info: &SchemaInfo,
101 type_overrides: &HashMap<String, String>,
102) -> typemap::RustType {
103 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
105 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
106 let rt = typemap::RustType::with_import(
107 &enum_type_name,
108 &format!("use super::types::{};", enum_type_name),
109 );
110 return if col.is_nullable {
111 rt.wrap_option()
112 } else {
113 rt
114 };
115 }
116
117 typemap::map_column(col, db_kind, schema_info, type_overrides)
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a table in the `public` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            name: name.to_string(),
            schema_name: "public".to_string(),
            columns,
        }
    }

    /// Builds a `public` column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        let udt = udt_name.to_string();
        ColumnInfo {
            name: name.to_string(),
            data_type: udt.clone(),
            udt_name: udt,
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Builds a MySQL enum column (`data_type` "enum") in the `test_db` schema.
    fn mysql_enum_col(name: &str, column_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "enum".to_string(),
            udt_name: column_type.to_string(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }
    }

    /// Runs the generator and returns the formatted code with the collected imports.
    fn render_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides);
        let code = parse_and_format(&tokens);
        (code, imports)
    }

    /// Formatted Postgres code with no extra derives and no overrides.
    fn render(table: &TableInfo) -> String {
        render_with(table, &SchemaInfo::default(), DatabaseKind::Postgres, &[], &HashMap::new()).0
    }

    /// Postgres imports with no extra derives and no overrides.
    fn pg_imports(table: &TableInfo) -> BTreeSet<String> {
        render_with(table, &SchemaInfo::default(), DatabaseKind::Postgres, &[], &HashMap::new()).1
    }

    /// True when any collected import line contains `needle`.
    fn has_import(imports: &BTreeSet<String>, needle: &str) -> bool {
        imports.iter().any(|line| line.contains(needle))
    }

    /// Overrides mapping `jsonb` to a custom `MyJson` type.
    fn jsonb_override() -> HashMap<String, String> {
        HashMap::from([("jsonb".to_string(), "MyJson".to_string())])
    }

    // --- basic field mapping ---

    #[test]
    fn test_simple_table() {
        let code = render(&make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("name", "text", false)],
        ));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let code = render(&make_table("user_roles", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct Users"));
    }

    // --- nullability ---

    #[test]
    fn test_nullable_column() {
        let code = render(&make_table("users", vec![make_col("email", "text", true)]));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let code = render(&make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        ));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword columns ---

    #[test]
    fn test_keyword_type_renamed() {
        let code = render(&make_table("connector", vec![make_col("type", "text", false)]));
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = render(&make_table("item", vec![make_col("fn", "text", false)]));
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = render(&make_table("game", vec![make_col("match", "text", false)]));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- case conversion ---

    #[test]
    fn test_camel_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("CreatedAt", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("firstName", "text", false)]));
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = render(&make_table("users", vec![make_col("created_at", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = ["Serialize".to_string()];
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = ["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(has_import(&imports, "serde"));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "uuid", false)]));
        assert!(has_import(&imports, "uuid::Uuid"));
    }

    #[test]
    fn test_timestamptz_import() {
        let imports = pg_imports(&make_table(
            "users",
            vec![make_col("created_at", "timestamptz", false)],
        ));
        assert!(has_import(&imports, "chrono"));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "int4", false)]));
        assert_eq!(imports.len(), 1);
        assert!(has_import(&imports, "serde"));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let imports = pg_imports(&make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        ));
        assert!(has_import(&imports, "uuid"));
        assert!(has_import(&imports, "chrono"));
    }

    // --- MySQL enums ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![mysql_enum_col("status", "enum('active','inactive')", false)],
        );
        let (code, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("UsersStatus"));
        assert!(has_import(&imports, "super::types::"));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![mysql_enum_col("status", "enum('a','b')", true)]);
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let code = render(&make_table("users", vec![make_col("data", "jsonb", false)]));
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("Option<MyJson>"));
    }
}