1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19 let mut imports = BTreeSet::new();
20 for imp in imports_for_derives(extra_derives) {
21 imports.insert(imp);
22 }
23 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
24
25 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { sqlx::FromRow },
30 ];
31 for d in extra_derives {
32 let ident = format_ident!("{}", d);
33 derive_tokens.push(quote! { #ident });
34 }
35
36 let fields: Vec<TokenStream> = table
38 .columns
39 .iter()
40 .map(|col| {
41 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
42 if let Some(imp) = &rust_type.needs_import {
43 imports.insert(imp.clone());
44 }
45
46 let field_name_snake = col.name.to_snake_case();
47 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
50 let prefixed = format!(
51 "{}_{}",
52 table.name.to_snake_case(),
53 field_name_snake
54 );
55 (prefixed, true)
56 } else {
57 let changed = field_name_snake != col.name;
58 (field_name_snake, changed)
59 };
60
61 let field_ident = format_ident!("{}", effective_name);
62 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
63 let fallback = format_ident!("String");
64 quote! { #fallback }
65 });
66
67 let rename = if needs_rename {
68 let original = &col.name;
69 quote! { #[sqlx(rename = #original)] }
70 } else {
71 quote! {}
72 };
73
74 quote! {
75 #rename
76 pub #field_ident: #type_tokens,
77 }
78 })
79 .collect();
80
81 let tokens = quote! {
82 #[derive(#(#derive_tokens),*)]
83 pub struct #struct_name {
84 #(#fields)*
85 }
86 };
87
88 (tokens, imports)
89}
90
91fn resolve_column_type(
92 col: &crate::introspect::ColumnInfo,
93 db_kind: DatabaseKind,
94 table: &TableInfo,
95 schema_info: &SchemaInfo,
96 type_overrides: &HashMap<String, String>,
97) -> typemap::RustType {
98 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
100 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
101 let rt = typemap::RustType::with_import(
102 &enum_type_name,
103 &format!("use super::types::{};", enum_type_name),
104 );
105 return if col.is_nullable {
106 rt.wrap_option()
107 } else {
108 rt
109 };
110 }
111
112 typemap::map_column(col, db_kind, schema_info, type_overrides)
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // ---- fixtures -------------------------------------------------------

    /// Builds a table in the `public` schema holding `columns`.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        let schema_name = String::from("public");
        let name = String::from(name);
        TableInfo {
            schema_name,
            name,
            columns,
        }
    }

    /// Builds a `public` column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_owned(),
            data_type: udt_name.to_owned(),
            udt_name: udt_name.to_owned(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "public".to_owned(),
        }
    }

    /// Builds a MySQL-style `status` enum column living in `test_db`.
    fn mysql_enum_col(udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_owned(),
            schema_name: "test_db".to_owned(),
            ..make_col("status", udt_name, nullable)
        }
    }

    /// Override map sending `jsonb` to `MyJson`.
    fn jsonb_override() -> HashMap<String, String> {
        HashMap::from([("jsonb".to_owned(), "MyJson".to_owned())])
    }

    /// Runs `generate_struct` and returns the formatted code plus the imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides);
        let formatted = parse_and_format(&tokens);
        (formatted, imports)
    }

    /// Formatted Postgres output with no extra derives and no overrides.
    fn render(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let no_overrides = HashMap::new();
        gen_with(table, &schema, DatabaseKind::Postgres, &[], &no_overrides).0
    }

    /// Imports collected for Postgres with no extra derives and no overrides.
    fn pg_imports(table: &TableInfo) -> BTreeSet<String> {
        let schema = SchemaInfo::default();
        let no_overrides = HashMap::new();
        gen_with(table, &schema, DatabaseKind::Postgres, &[], &no_overrides).1
    }

    /// Formatted Postgres output with the given extra derives.
    fn pg_with_derives(table: &TableInfo, derives: &[String]) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        gen_with(table, &schema, DatabaseKind::Postgres, derives, &HashMap::new())
    }

    /// Formatted Postgres output with the given type overrides.
    fn pg_with_overrides(table: &TableInfo, overrides: &HashMap<String, String>) -> String {
        let schema = SchemaInfo::default();
        gen_with(table, &schema, DatabaseKind::Postgres, &[], overrides).0
    }

    /// Formatted MySQL output and imports with defaults for everything else.
    fn mysql_output(table: &TableInfo) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        gen_with(table, &schema, DatabaseKind::Mysql, &[], &HashMap::new())
    }

    // ---- struct shape ---------------------------------------------------

    #[test]
    fn test_simple_table() {
        let columns = vec![make_col("id", "int4", false), make_col("name", "text", false)];
        let code = render(&make_table("users", columns));
        for expected in ["pub id: i32", "pub name: String"] {
            assert!(code.contains(expected));
        }
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let code = render(&make_table("user_roles", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct Users"));
    }

    // ---- nullability ----------------------------------------------------

    #[test]
    fn test_nullable_column() {
        let code = render(&make_table("users", vec![make_col("email", "text", true)]));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let columns = vec![make_col("id", "int4", false), make_col("bio", "text", true)];
        let code = render(&make_table("users", columns));
        for expected in ["pub id: i32", "pub bio: Option<String>"] {
            assert!(code.contains(expected));
        }
    }

    // ---- keyword columns ------------------------------------------------

    #[test]
    fn test_keyword_type_renamed() {
        let code = render(&make_table("connector", vec![make_col("type", "text", false)]));
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = render(&make_table("item", vec![make_col("fn", "text", false)]));
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = render(&make_table("game", vec![make_col("match", "text", false)]));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(!code.contains("sqlx(rename"));
    }

    // ---- case conversion ------------------------------------------------

    #[test]
    fn test_camel_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("CreatedAt", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("firstName", "text", false)]));
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = render(&make_table("users", vec![make_col("created_at", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // ---- derives --------------------------------------------------------

    #[test]
    fn test_default_derives() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        // Accept either spacing of the path separator in the formatted output.
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = [String::from("Serialize")];
        let (code, _) = pg_with_derives(&table, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = [String::from("Serialize"), String::from("Deserialize")];
        let (_, imports) = pg_with_derives(&table, &derives);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    // ---- imports --------------------------------------------------------

    #[test]
    fn test_uuid_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "uuid", false)]));
        assert!(imports.iter().any(|line| line.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let imports = pg_imports(&table);
        assert!(imports.iter().any(|line| line.contains("chrono")));
    }

    #[test]
    fn test_int4_no_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(imports.is_empty());
    }

    #[test]
    fn test_multiple_imports_collected() {
        let columns = vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ];
        let imports = pg_imports(&make_table("users", columns));
        for needle in ["uuid", "chrono"] {
            assert!(imports.iter().any(|line| line.contains(needle)));
        }
    }

    // ---- MySQL enums ----------------------------------------------------

    #[test]
    fn test_mysql_enum_column() {
        let column = mysql_enum_col("enum('active','inactive')", false);
        let (code, imports) = mysql_output(&make_table("users", vec![column]));
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|line| line.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let column = mysql_enum_col("enum('a','b')", true);
        let (code, _) = mysql_output(&make_table("users", vec![column]));
        assert!(code.contains("Option<UsersStatus>"));
    }

    // ---- type overrides -------------------------------------------------

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = pg_with_overrides(&table, &jsonb_override());
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = pg_with_overrides(&table, &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let code = pg_with_overrides(&table, &jsonb_override());
        assert!(code.contains("Option<MyJson>"));
    }
}