1use crate::error::Result;
2use crate::generator::{GeneratedFile, SchemaGenerator};
3use crate::ir::{ColumnDef, EnumDef, JsonShape, Overrides, SqlType, SqlTypeCategory, SqlcxIR};
4use crate::utils::{escape_string, pascal_case};
5use std::collections::BTreeSet;
6
/// Schema generator that renders the sqlcx IR as a single Pydantic v2 `models.py`.
///
/// The type carries no state, so it derives the common value traits: it can be
/// debug-printed, copied freely, and built with `PydanticGenerator::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PydanticGenerator;
8
9fn json_shape_to_python(shape: &JsonShape) -> String {
12 match shape {
13 JsonShape::String => "str".to_string(),
14 JsonShape::Number => "float".to_string(),
15 JsonShape::Boolean => "bool".to_string(),
16 JsonShape::Object { .. } => "dict[str, Any]".to_string(),
17 JsonShape::Array { element } => {
18 format!("list[{}]", json_shape_to_python(element))
19 }
20 JsonShape::Nullable { inner } => {
21 format!("{} | None", json_shape_to_python(inner))
22 }
23 }
24}
25
26fn python_type(sql_type: &SqlType, overrides: &Overrides) -> String {
27 if let Some(override_type) = overrides.get(&sql_type.normalized) {
29 return match override_type.as_str() {
30 "string" | "str" => "str".to_string(),
31 "number" | "int" => "int".to_string(),
32 "float" => "float".to_string(),
33 "boolean" | "bool" => "bool".to_string(),
34 other => other.to_string(),
35 };
36 }
37
38 if let Some(enum_values) = &sql_type.enum_values {
40 let literals = enum_values
41 .iter()
42 .map(|v| format!("\"{}\"", escape_string(v)))
43 .collect::<Vec<_>>()
44 .join(", ");
45 return format!("Literal[{}]", literals);
46 }
47
48 if let Some(json_shape) = &sql_type.json_shape {
50 return json_shape_to_python(json_shape);
51 }
52
53 if let Some(element_type) = &sql_type.element_type {
55 return format!("list[{}]", python_type(element_type, overrides));
56 }
57
58 match &sql_type.category {
59 SqlTypeCategory::String | SqlTypeCategory::Uuid => "str".to_string(),
60 SqlTypeCategory::Number => {
61 let raw_upper = sql_type.raw.to_uppercase();
62 if raw_upper.contains("REAL")
63 || raw_upper.contains("FLOAT")
64 || raw_upper.contains("DOUBLE")
65 || raw_upper.contains("DECIMAL")
66 || raw_upper.contains("NUMERIC")
67 {
68 "float".to_string()
69 } else {
70 "int".to_string()
71 }
72 }
73 SqlTypeCategory::Boolean => "bool".to_string(),
74 SqlTypeCategory::Date => {
75 let raw_upper = sql_type.raw.to_uppercase();
76 if raw_upper.contains("TIMESTAMP") {
77 "datetime".to_string()
78 } else if raw_upper.contains("TIME") {
79 "time".to_string()
80 } else {
81 "date".to_string()
82 }
83 }
84 SqlTypeCategory::Json => "Any".to_string(),
85 SqlTypeCategory::Binary => "bytes".to_string(),
86 SqlTypeCategory::Enum => {
87 if let Some(enum_name) = &sql_type.enum_name {
88 pascal_case(enum_name)
89 } else {
90 "str".to_string()
91 }
92 }
93 SqlTypeCategory::Unknown => "Any".to_string(),
94 }
95}
96
97fn select_field(col: &ColumnDef, overrides: &Overrides) -> String {
98 let base = python_type(&col.sql_type, overrides);
99 let field_name = &col.name;
100 if col.nullable {
101 format!(" {}: {} | None", field_name, base)
102 } else {
103 format!(" {}: {}", field_name, base)
104 }
105}
106
107fn insert_field(col: &ColumnDef, overrides: &Overrides) -> String {
108 let base = python_type(&col.sql_type, overrides);
109 let field_name = &col.name;
110 if col.has_default || col.nullable {
111 format!(" {}: {} | None = None", field_name, base)
112 } else {
113 format!(" {}: {}", field_name, base)
114 }
115}
116
117fn collect_imports(ir: &SqlcxIR, overrides: &Overrides) -> BTreeSet<String> {
120 let mut imports = BTreeSet::new();
121 imports.insert("from pydantic import BaseModel, ConfigDict".to_string());
122
123 let mut needs_datetime = false;
124 let mut needs_time = false;
125 let mut needs_date = false;
126 let mut needs_any = false;
127 let mut needs_literal = false;
128
129 for table in &ir.tables {
130 for col in &table.columns {
131 collect_type_imports(
132 &col.sql_type,
133 overrides,
134 &mut needs_datetime,
135 &mut needs_time,
136 &mut needs_date,
137 &mut needs_any,
138 &mut needs_literal,
139 );
140 }
141 }
142
143 if !ir.enums.is_empty() {
144 imports.insert("from enum import Enum".to_string());
145 }
146
147 let mut typing_imports = Vec::new();
148 if needs_any {
149 typing_imports.push("Any");
150 }
151 if needs_literal {
152 typing_imports.push("Literal");
153 }
154 if !typing_imports.is_empty() {
155 imports.insert(format!("from typing import {}", typing_imports.join(", ")));
156 }
157
158 let mut dt_imports = Vec::new();
159 if needs_datetime {
160 dt_imports.push("datetime");
161 }
162 if needs_date {
163 dt_imports.push("date");
164 }
165 if needs_time {
166 dt_imports.push("time");
167 }
168 if !dt_imports.is_empty() {
169 imports.insert(format!("from datetime import {}", dt_imports.join(", ")));
170 }
171
172 imports
173}
174
175fn collect_type_imports(
176 sql_type: &SqlType,
177 overrides: &Overrides,
178 needs_datetime: &mut bool,
179 needs_time: &mut bool,
180 needs_date: &mut bool,
181 needs_any: &mut bool,
182 needs_literal: &mut bool,
183) {
184 if overrides.contains_key(&sql_type.normalized) {
185 return;
186 }
187
188 if sql_type.enum_values.is_some() {
189 *needs_literal = true;
190 return;
191 }
192
193 if sql_type.json_shape.is_some() {
194 *needs_any = true;
195 return;
196 }
197
198 if let Some(elem) = &sql_type.element_type {
199 collect_type_imports(
200 elem,
201 overrides,
202 needs_datetime,
203 needs_time,
204 needs_date,
205 needs_any,
206 needs_literal,
207 );
208 return;
209 }
210
211 match sql_type.category {
212 SqlTypeCategory::Date => {
213 let raw_upper = sql_type.raw.to_uppercase();
214 if raw_upper.contains("TIMESTAMP") {
215 *needs_datetime = true;
216 } else if raw_upper.contains("TIME") {
217 *needs_time = true;
218 } else {
219 *needs_date = true;
220 }
221 }
222 SqlTypeCategory::Json | SqlTypeCategory::Unknown => {
223 *needs_any = true;
224 }
225 _ => {}
226 }
227}
228
/// Converts an enum label into a Python identifier usable as an `Enum` member name.
///
/// The label is upper-cased, and every character that cannot appear in a
/// Python identifier (anything other than a letter, a digit or `_`) becomes `_`.
/// Then:
/// - an empty label becomes `EMPTY_`, since an empty member name is a syntax error;
/// - a leading digit gets a `_` prefix, since identifiers cannot start with a digit;
/// - a name that starts with `_` and then either starts with `__` or ends with
///   `_` gets a `V` prefix. `Enum` rejects `_sunder_` / `__dunder__` member
///   names, and Python mangles class attributes that start with `__`;
/// - a name spelling an upper-cased Python keyword gets a trailing `_`.
///   Python keywords are case-sensitive, so this is not strictly needed, but it
///   is kept so existing generated output does not change.
fn sanitize_variant(v: &str) -> String {
    let mut ident: String = v
        .to_uppercase()
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();

    if ident.is_empty() {
        return "EMPTY_".to_string();
    }
    if ident.starts_with(|c: char| c.is_numeric()) {
        ident.insert(0, '_');
    }
    if ident.starts_with('_') && (ident.starts_with("__") || ident.ends_with('_')) {
        ident.insert(0, 'V');
    }

    match ident.as_str() {
        "FALSE" | "TRUE" | "NONE" | "AND" | "OR" | "NOT" | "IS" | "IN" | "IF" | "ELSE" | "FOR"
        | "WHILE" | "CLASS" | "DEF" | "RETURN" | "IMPORT" | "FROM" | "AS" | "WITH" | "YIELD"
        | "BREAK" | "CONTINUE" | "PASS" | "RAISE" | "TRY" | "EXCEPT" | "FINALLY" => {
            format!("{}_", ident)
        }
        _ => ident,
    }
}
251
252fn generate_enum(enum_def: &EnumDef) -> String {
253 let name = pascal_case(&enum_def.name);
254 let variants: Vec<String> = enum_def
255 .values
256 .iter()
257 .map(|v| {
258 let variant_name = sanitize_variant(v);
259 format!(" {} = \"{}\"", variant_name, escape_string(v))
260 })
261 .collect();
262 format!("class {}(str, Enum):\n{}", name, variants.join("\n"))
263}
264
265fn generate_select_model(table: &crate::ir::TableDef, overrides: &Overrides) -> String {
266 let name = format!("Select{}", pascal_case(&table.name));
267 let fields: Vec<String> = table
268 .columns
269 .iter()
270 .map(|col| select_field(col, overrides))
271 .collect();
272 format!(
273 "class {}(BaseModel):\n model_config = ConfigDict(from_attributes=True)\n\n{}",
274 name,
275 fields.join("\n")
276 )
277}
278
279fn generate_insert_model(table: &crate::ir::TableDef, overrides: &Overrides) -> String {
280 let name = format!("Insert{}", pascal_case(&table.name));
281
282 let mut required: Vec<String> = Vec::new();
284 let mut optional: Vec<String> = Vec::new();
285
286 for col in &table.columns {
287 let field = insert_field(col, overrides);
288 if col.has_default || col.nullable {
289 optional.push(field);
290 } else {
291 required.push(field);
292 }
293 }
294
295 let mut fields = required;
296 fields.extend(optional);
297
298 format!(
299 "class {}(BaseModel):\n model_config = ConfigDict(from_attributes=True)\n\n{}",
300 name,
301 fields.join("\n")
302 )
303}
304
305impl PydanticGenerator {
306 pub fn generate_models_file(&self, ir: &SqlcxIR, overrides: &Overrides) -> String {
307 let mut parts: Vec<String> = Vec::new();
308
309 parts.push("# Code generated by sqlcx. DO NOT EDIT.".to_string());
310
311 let imports = collect_imports(ir, overrides);
313 parts.push(imports.into_iter().collect::<Vec<_>>().join("\n"));
314
315 for enum_def in &ir.enums {
317 parts.push(generate_enum(enum_def));
318 }
319
320 for table in &ir.tables {
322 parts.push(generate_select_model(table, overrides));
323 parts.push(generate_insert_model(table, overrides));
324 }
325
326 parts.join("\n\n") + "\n"
327 }
328}
329
330impl SchemaGenerator for PydanticGenerator {
331 fn generate(&self, ir: &SqlcxIR, overrides: &Overrides) -> Result<GeneratedFile> {
332 Ok(GeneratedFile {
333 path: "models.py".to_string(),
334 content: self.generate_models_file(ir, overrides),
335 })
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;
    use std::collections::HashMap;

    /// Parses the shared Postgres fixture schema into an IR with no queries.
    fn parse_fixture_ir() -> SqlcxIR {
        let schema_sql = include_str!("../../../../../tests/fixtures/schema.sql");
        let (tables, enums) = PostgresParser::new().parse_schema(schema_sql).unwrap();
        SqlcxIR {
            tables,
            queries: vec![],
            enums,
        }
    }

    /// Builds a plain SQL type with no element type, enum data or JSON shape.
    fn scalar(raw: &str, normalized: &str, category: SqlTypeCategory) -> SqlType {
        SqlType {
            raw: raw.to_string(),
            normalized: normalized.to_string(),
            category,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        }
    }

    /// Builds a column with no alias and no source table.
    fn column(name: &str, sql_type: SqlType, nullable: bool, has_default: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            alias: None,
            source_table: None,
            sql_type,
            nullable,
            has_default,
        }
    }

    #[test]
    fn generates_models_file() {
        let content = PydanticGenerator.generate_models_file(&parse_fixture_ir(), &HashMap::new());
        let expected_fragments = [
            "from pydantic import BaseModel, ConfigDict",
            "from enum import Enum",
            "class UserStatus(str, Enum):",
            "class SelectUsers(BaseModel):",
            "model_config = ConfigDict(from_attributes=True)",
            "class InsertUsers(BaseModel):",
            "class SelectPosts(BaseModel):",
            "class InsertPosts(BaseModel):",
        ];
        for fragment in expected_fragments.iter() {
            assert!(
                content.contains(fragment),
                "generated models.py is missing `{}`",
                fragment
            );
        }
        insta::assert_snapshot!("pydantic_models", content);
    }

    #[test]
    fn nullable_column_uses_union_none() {
        let col = column("bio", scalar("text", "text", SqlTypeCategory::String), true, false);
        assert_eq!(select_field(&col, &HashMap::new()), " bio: str | None");
    }

    #[test]
    fn default_column_is_optional_in_insert() {
        let col = column("status", scalar("text", "text", SqlTypeCategory::String), false, true);
        assert_eq!(insert_field(&col, &HashMap::new()), " status: str | None = None");
    }

    #[test]
    fn sanitize_variant_handles_hyphens_and_keywords() {
        let cases = [
            ("in-progress", "IN_PROGRESS"),
            ("class", "CLASS_"),
            ("active", "ACTIVE"),
            ("123bad", "_123BAD"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(sanitize_variant(input), *expected);
        }
    }

    #[test]
    fn enum_type_uses_pascal_case() {
        let sql_type = SqlType {
            enum_name: Some("user_status".to_string()),
            ..scalar("user_status", "user_status", SqlTypeCategory::Enum)
        };
        assert_eq!(python_type(&sql_type, &HashMap::new()), "UserStatus");
    }

    #[test]
    fn array_type_maps_to_list() {
        let sql_type = SqlType {
            element_type: Some(Box::new(scalar("text", "text", SqlTypeCategory::String))),
            ..scalar("text[]", "text[]", SqlTypeCategory::String)
        };
        assert_eq!(python_type(&sql_type, &HashMap::new()), "list[str]");
    }

    #[test]
    fn timestamp_maps_to_datetime() {
        let sql_type = scalar("TIMESTAMP", "timestamp", SqlTypeCategory::Date);
        assert_eq!(python_type(&sql_type, &HashMap::new()), "datetime");
    }
}