1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20 let mut imports = BTreeSet::new();
21 for imp in imports_for_derives(extra_derives) {
22 imports.insert(imp);
23 }
24 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26 imports.insert("use serde::{Serialize, Deserialize};".to_string());
28 imports.insert("use sqlx_gen::SqlxGen;".to_string());
29 let mut derive_tokens = vec![
30 quote! { Debug },
31 quote! { Clone },
32 quote! { PartialEq },
33 quote! { Eq },
34 quote! { Serialize },
35 quote! { Deserialize },
36 quote! { sqlx::FromRow },
37 quote! { SqlxGen },
38 ];
39 for d in extra_derives {
40 let ident = format_ident!("{}", d);
41 derive_tokens.push(quote! { #ident });
42 }
43
44 let fields: Vec<TokenStream> = table
46 .columns
47 .iter()
48 .map(|col| {
49 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50 if let Some(imp) = &rust_type.needs_import {
51 imports.insert(imp.clone());
52 }
53
54 let field_name_snake = col.name.to_snake_case();
55 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58 let prefixed = format!(
59 "{}_{}",
60 table.name.to_snake_case(),
61 field_name_snake
62 );
63 (prefixed, true)
64 } else {
65 let changed = field_name_snake != col.name;
66 (field_name_snake, changed)
67 };
68
69 let field_ident = format_ident!("{}", effective_name);
70 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71 let fallback = format_ident!("String");
72 quote! { #fallback }
73 });
74
75 let rename = if needs_rename {
76 let original = &col.name;
77 quote! { #[sqlx(rename = #original)] }
78 } else {
79 quote! {}
80 };
81
82 let pk_attr = if col.is_primary_key {
83 quote! { #[sqlx_gen(primary_key)] }
84 } else {
85 quote! {}
86 };
87
88 quote! {
89 #pk_attr
90 #rename
91 pub #field_ident: #type_tokens,
92 }
93 })
94 .collect();
95
96 let table_name_str = &table.name;
97 let schema_name_str = &table.schema_name;
98 let kind_str = if is_view { "view" } else { "table" };
99
100 let tokens = quote! {
101 #[derive(#(#derive_tokens),*)]
102 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
103 pub struct #struct_name {
104 #(#fields)*
105 }
106 };
107
108 (tokens, imports)
109}
110
111fn resolve_column_type(
112 col: &crate::introspect::ColumnInfo,
113 db_kind: DatabaseKind,
114 table: &TableInfo,
115 schema_info: &SchemaInfo,
116 type_overrides: &HashMap<String, String>,
117) -> typemap::RustType {
118 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
120 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
121 let rt = typemap::RustType::with_import(
122 &enum_type_name,
123 &format!("use super::types::{};", enum_type_name),
124 );
125 return if col.is_nullable {
126 rt.wrap_option()
127 } else {
128 rt
129 };
130 }
131
132 typemap::map_column(col, db_kind, schema_info, type_overrides)
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // ---- fixtures ----

    /// Table in the `public` schema holding the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: String::from("public"),
            name: String::from(name),
            columns,
        }
    }

    /// Non-primary-key column in the `public` schema; `data_type` mirrors
    /// `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: String::from(name),
            data_type: String::from(udt_name),
            udt_name: String::from(udt_name),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("public"),
        }
    }

    /// `status` column of a MySQL inline enum living in `test_db`.
    fn mysql_enum_col(udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: String::from("enum"),
            schema_name: String::from("test_db"),
            ..make_col("status", udt_name, nullable)
        }
    }

    /// Override map sending `jsonb` to `MyJson`.
    fn jsonb_override() -> HashMap<String, String> {
        std::iter::once((String::from("jsonb"), String::from("MyJson"))).collect()
    }

    /// Runs the generator (as a table, not a view) and returns the formatted
    /// code plus the collected imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    /// Formatted code for `table` on Postgres with default settings.
    fn render(table: &TableInfo) -> String {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
        .0
    }

    /// Imports for `table` on Postgres, with optional extra derives.
    fn pg_imports(table: &TableInfo, derives: &[String]) -> BTreeSet<String> {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            derives,
            &HashMap::new(),
        )
        .1
    }

    /// True when any import line contains `needle`.
    fn mentions(imports: &BTreeSet<String>, needle: &str) -> bool {
        imports.iter().any(|line| line.contains(needle))
    }

    // ---- basic structure ----

    #[test]
    fn test_simple_table() {
        let columns = vec![make_col("id", "int4", false), make_col("name", "text", false)];
        let code = render(&make_table("users", columns));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let code = render(&make_table("user_roles", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct Users"));
    }

    // ---- nullability ----

    #[test]
    fn test_nullable_column() {
        let code = render(&make_table("users", vec![make_col("email", "text", true)]));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let columns = vec![make_col("id", "int4", false), make_col("bio", "text", true)];
        let code = render(&make_table("users", columns));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // ---- keyword columns ----

    #[test]
    fn test_keyword_type_renamed() {
        let code = render(&make_table("connector", vec![make_col("type", "text", false)]));
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = render(&make_table("item", vec![make_col("fn", "text", false)]));
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = render(&make_table("game", vec![make_col("match", "text", false)]));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(!code.contains("sqlx(rename"));
    }

    // ---- snake_case conversion ----

    #[test]
    fn test_camel_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("CreatedAt", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = render(&make_table("users", vec![make_col("firstName", "text", false)]));
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = render(&make_table("users", vec![make_col("created_at", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // ---- derives ----

    #[test]
    fn test_default_derives() {
        let code = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec![String::from("Serialize")];
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec![String::from("Serialize"), String::from("Deserialize")];
        let imports = pg_imports(&table, &derives);
        assert!(mentions(&imports, "serde"));
    }

    // ---- imports ----

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let imports = pg_imports(&table, &[]);
        assert!(mentions(&imports, "uuid::Uuid"));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let imports = pg_imports(&table, &[]);
        assert!(mentions(&imports, "chrono"));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let imports = pg_imports(&table, &[]);
        assert_eq!(imports.len(), 2);
        assert!(mentions(&imports, "serde"));
        assert!(mentions(&imports, "sqlx_gen::SqlxGen"));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let columns = vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ];
        let imports = pg_imports(&make_table("users", columns), &[]);
        assert!(mentions(&imports, "uuid"));
        assert!(mentions(&imports, "chrono"));
    }

    // ---- MySQL inline enums ----

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![mysql_enum_col("enum('active','inactive')", false)],
        );
        let (code, imports) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("UsersStatus"));
        assert!(mentions(&imports, "super::types::"));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![mysql_enum_col("enum('a','b')", true)]);
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Option<UsersStatus>"));
    }

    // ---- type overrides ----

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let code = render(&make_table("users", vec![make_col("data", "jsonb", false)]));
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("Option<MyJson>"));
    }
}