1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20 let mut imports = BTreeSet::new();
21 for imp in imports_for_derives(extra_derives) {
22 imports.insert(imp);
23 }
24 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26 imports.insert("use serde::{Serialize, Deserialize};".to_string());
28 imports.insert("use sqlx_gen::SqlxGen;".to_string());
29 let mut derive_tokens = vec![
30 quote! { Debug },
31 quote! { Clone },
32 quote! { PartialEq },
33 quote! { Eq },
34 quote! { Serialize },
35 quote! { Deserialize },
36 quote! { sqlx::FromRow },
37 quote! { SqlxGen },
38 ];
39 for d in extra_derives {
40 let ident = format_ident!("{}", d);
41 derive_tokens.push(quote! { #ident });
42 }
43
44 let fields: Vec<TokenStream> = table
46 .columns
47 .iter()
48 .map(|col| {
49 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50 if let Some(imp) = &rust_type.needs_import {
51 imports.insert(imp.clone());
52 }
53
54 let field_name_snake = col.name.to_snake_case();
55 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58 let prefixed = format!(
59 "{}_{}",
60 table.name.to_snake_case(),
61 field_name_snake
62 );
63 (prefixed, true)
64 } else {
65 let changed = field_name_snake != col.name;
66 (field_name_snake, changed)
67 };
68
69 let field_ident = format_ident!("{}", effective_name);
70 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71 let fallback = format_ident!("String");
72 quote! { #fallback }
73 });
74
75 let rename = if needs_rename {
76 let original = &col.name;
77 quote! { #[sqlx(rename = #original)] }
78 } else {
79 quote! {}
80 };
81
82 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
84 let has_pk = col.is_primary_key;
85 let has_sql_type = sql_type.is_some();
86
87 let sqlx_gen_attr = if has_pk || has_sql_type {
88 let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
89 let sql_type_part = match &sql_type {
90 Some(t) => quote! { sql_type = #t, },
91 None => quote! {},
92 };
93 let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
94 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part)] }
95 } else {
96 quote! {}
97 };
98
99 quote! {
100 #sqlx_gen_attr
101 #rename
102 pub #field_ident: #type_tokens,
103 }
104 })
105 .collect();
106
107 let table_name_str = &table.name;
108 let schema_name_str = &table.schema_name;
109 let kind_str = if is_view { "view" } else { "table" };
110
111 let tokens = quote! {
112 #[derive(#(#derive_tokens),*)]
113 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
114 pub struct #struct_name {
115 #(#fields)*
116 }
117 };
118
119 (tokens, imports)
120}
121
122fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
127 let (base_name, is_array) = match udt_name.strip_prefix('_') {
128 Some(inner) => (inner, true),
129 None => (udt_name, false),
130 };
131
132 if let Some(e) = schema_info.enums.iter().find(|e| e.name == base_name) {
134 let qualified = if e.schema_name == "public" {
135 base_name.to_string()
136 } else {
137 format!("{}.{}", e.schema_name, base_name)
138 };
139 return (Some(qualified), is_array);
140 }
141
142 if let Some(c) = schema_info.composite_types.iter().find(|c| c.name == base_name) {
144 let qualified = if c.schema_name == "public" {
145 base_name.to_string()
146 } else {
147 format!("{}.{}", c.schema_name, base_name)
148 };
149 return (Some(qualified), is_array);
150 }
151
152 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
156 if !is_domain && !typemap::postgres::is_builtin(base_name) {
157 return (Some(base_name.to_string()), is_array);
158 }
159
160 (None, false)
161}
162
163fn resolve_column_type(
164 col: &crate::introspect::ColumnInfo,
165 db_kind: DatabaseKind,
166 table: &TableInfo,
167 schema_info: &SchemaInfo,
168 type_overrides: &HashMap<String, String>,
169) -> typemap::RustType {
170 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
172 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
173 let rt = typemap::RustType::with_import(
174 &enum_type_name,
175 &format!("use super::types::{};", enum_type_name),
176 );
177 return if col.is_nullable {
178 rt.wrap_option()
179 } else {
180 rt
181 };
182 }
183
184 typemap::map_column(col, db_kind, schema_info, type_overrides)
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a `public`-schema table with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-primary-key column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates and formats a Postgres table struct with default options.
    /// (Not named `gen`: that identifier is a reserved keyword in Rust 2024.)
    fn gen_code(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) =
            generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false);
        parse_and_format(&tokens)
    }

    /// Generates a table struct with explicit options, returning formatted code and imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    // --- basic fields ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullability ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword columns ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- case conversion ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_non_default_appended() {
        // `Serialize` is always derived, so the test above cannot tell whether extra
        // derives are emitted at all; `Hash` is not a default derive.
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
        assert!(code.contains("Debug"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enums ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    // --- sqlx_gen attributes ---

    #[test]
    fn test_primary_key_attribute() {
        let mut id = make_col("id", "int4", false);
        id.is_primary_key = true;
        let table = make_table("users", vec![id]);
        let code = gen_code(&table);
        assert!(code.contains("primary_key"));
    }

    #[test]
    fn test_no_primary_key_attribute_for_plain_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("primary_key"));
    }

    #[test]
    fn test_table_kind_and_names_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("\"table\""));
        assert!(code.contains("\"users\""));
        assert!(code.contains("\"public\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (tokens, _) =
            generate_struct(&table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), true);
        let code = parse_and_format(&tokens);
        assert!(code.contains("\"view\""));
        assert!(!code.contains("\"table\""));
    }
}