Skip to main content

sqlcx_core/generator/python/
pydantic.rs

1use crate::error::Result;
2use crate::generator::{GeneratedFile, SchemaGenerator};
3use crate::ir::{ColumnDef, EnumDef, JsonShape, Overrides, SqlType, SqlTypeCategory, SqlcxIR};
4use crate::utils::{escape_string, pascal_case};
5use std::collections::BTreeSet;
6
/// Schema generator that renders a single `models.py` module of Pydantic
/// models (one `Select*` and one `Insert*` model per table, plus enums).
///
/// The generator carries no state; everything it emits is derived from the
/// IR and overrides passed to each call.
#[derive(Debug, Clone, Copy, Default)]
pub struct PydanticGenerator;
8
9// ── Type mapping ──────────────────────────────────────────────────────────────
10
11fn json_shape_to_python(shape: &JsonShape) -> String {
12    match shape {
13        JsonShape::String => "str".to_string(),
14        JsonShape::Number => "float".to_string(),
15        JsonShape::Boolean => "bool".to_string(),
16        JsonShape::Object { .. } => "dict[str, Any]".to_string(),
17        JsonShape::Array { element } => {
18            format!("list[{}]", json_shape_to_python(element))
19        }
20        JsonShape::Nullable { inner } => {
21            format!("{} | None", json_shape_to_python(inner))
22        }
23    }
24}
25
26fn python_type(sql_type: &SqlType, overrides: &Overrides) -> String {
27    // Check overrides first
28    if let Some(override_type) = overrides.get(&sql_type.normalized) {
29        return match override_type.as_str() {
30            "string" | "str" => "str".to_string(),
31            "number" | "int" => "int".to_string(),
32            "float" => "float".to_string(),
33            "boolean" | "bool" => "bool".to_string(),
34            other => other.to_string(),
35        };
36    }
37
38    // Inline @enum annotation
39    if let Some(enum_values) = &sql_type.enum_values {
40        let literals = enum_values
41            .iter()
42            .map(|v| format!("\"{}\"", escape_string(v)))
43            .collect::<Vec<_>>()
44            .join(", ");
45        return format!("Literal[{}]", literals);
46    }
47
48    // Inline @json annotation
49    if let Some(json_shape) = &sql_type.json_shape {
50        return json_shape_to_python(json_shape);
51    }
52
53    // Array type
54    if let Some(element_type) = &sql_type.element_type {
55        return format!("list[{}]", python_type(element_type, overrides));
56    }
57
58    match &sql_type.category {
59        SqlTypeCategory::String | SqlTypeCategory::Uuid => "str".to_string(),
60        SqlTypeCategory::Number => {
61            let raw_upper = sql_type.raw.to_uppercase();
62            if raw_upper.contains("REAL")
63                || raw_upper.contains("FLOAT")
64                || raw_upper.contains("DOUBLE")
65                || raw_upper.contains("DECIMAL")
66                || raw_upper.contains("NUMERIC")
67            {
68                "float".to_string()
69            } else {
70                "int".to_string()
71            }
72        }
73        SqlTypeCategory::Boolean => "bool".to_string(),
74        SqlTypeCategory::Date => {
75            let raw_upper = sql_type.raw.to_uppercase();
76            if raw_upper.contains("TIMESTAMP") {
77                "datetime".to_string()
78            } else if raw_upper.contains("TIME") {
79                "time".to_string()
80            } else {
81                "date".to_string()
82            }
83        }
84        SqlTypeCategory::Json => "Any".to_string(),
85        SqlTypeCategory::Binary => "bytes".to_string(),
86        SqlTypeCategory::Enum => {
87            if let Some(enum_name) = &sql_type.enum_name {
88                pascal_case(enum_name)
89            } else {
90                "str".to_string()
91            }
92        }
93        SqlTypeCategory::Unknown => "Any".to_string(),
94    }
95}
96
97fn select_field(col: &ColumnDef, overrides: &Overrides) -> String {
98    let base = python_type(&col.sql_type, overrides);
99    let field_name = &col.name;
100    if col.nullable {
101        format!("    {}: {} | None", field_name, base)
102    } else {
103        format!("    {}: {}", field_name, base)
104    }
105}
106
107fn insert_field(col: &ColumnDef, overrides: &Overrides) -> String {
108    let base = python_type(&col.sql_type, overrides);
109    let field_name = &col.name;
110    if col.has_default || col.nullable {
111        format!("    {}: {} | None = None", field_name, base)
112    } else {
113        format!("    {}: {}", field_name, base)
114    }
115}
116
117// ── Import collection ─────────────────────────────────────────────────────────
118
119fn collect_imports(ir: &SqlcxIR, overrides: &Overrides) -> BTreeSet<String> {
120    let mut imports = BTreeSet::new();
121    imports.insert("from pydantic import BaseModel, ConfigDict".to_string());
122
123    let mut needs_datetime = false;
124    let mut needs_time = false;
125    let mut needs_date = false;
126    let mut needs_any = false;
127    let mut needs_literal = false;
128
129    for table in &ir.tables {
130        for col in &table.columns {
131            collect_type_imports(
132                &col.sql_type,
133                overrides,
134                &mut needs_datetime,
135                &mut needs_time,
136                &mut needs_date,
137                &mut needs_any,
138                &mut needs_literal,
139            );
140        }
141    }
142
143    if !ir.enums.is_empty() {
144        imports.insert("from enum import Enum".to_string());
145    }
146
147    let mut typing_imports = Vec::new();
148    if needs_any {
149        typing_imports.push("Any");
150    }
151    if needs_literal {
152        typing_imports.push("Literal");
153    }
154    if !typing_imports.is_empty() {
155        imports.insert(format!("from typing import {}", typing_imports.join(", ")));
156    }
157
158    let mut dt_imports = Vec::new();
159    if needs_datetime {
160        dt_imports.push("datetime");
161    }
162    if needs_date {
163        dt_imports.push("date");
164    }
165    if needs_time {
166        dt_imports.push("time");
167    }
168    if !dt_imports.is_empty() {
169        imports.insert(format!("from datetime import {}", dt_imports.join(", ")));
170    }
171
172    imports
173}
174
175fn collect_type_imports(
176    sql_type: &SqlType,
177    overrides: &Overrides,
178    needs_datetime: &mut bool,
179    needs_time: &mut bool,
180    needs_date: &mut bool,
181    needs_any: &mut bool,
182    needs_literal: &mut bool,
183) {
184    if overrides.contains_key(&sql_type.normalized) {
185        return;
186    }
187
188    if sql_type.enum_values.is_some() {
189        *needs_literal = true;
190        return;
191    }
192
193    if sql_type.json_shape.is_some() {
194        *needs_any = true;
195        return;
196    }
197
198    if let Some(elem) = &sql_type.element_type {
199        collect_type_imports(
200            elem,
201            overrides,
202            needs_datetime,
203            needs_time,
204            needs_date,
205            needs_any,
206            needs_literal,
207        );
208        return;
209    }
210
211    match sql_type.category {
212        SqlTypeCategory::Date => {
213            let raw_upper = sql_type.raw.to_uppercase();
214            if raw_upper.contains("TIMESTAMP") {
215                *needs_datetime = true;
216            } else if raw_upper.contains("TIME") {
217                *needs_time = true;
218            } else {
219                *needs_date = true;
220            }
221        }
222        SqlTypeCategory::Json | SqlTypeCategory::Unknown => {
223            *needs_any = true;
224        }
225        _ => {}
226    }
227}
228
229// ── Generator ─────────────────────────────────────────────────────────────────
230
/// Sanitize an enum value into a valid Python identifier for an `Enum` member.
///
/// The value is upper-cased, and every character that cannot appear in a
/// Python identifier (anything other than `_`, ASCII digits and letters) is
/// replaced by `_`. A result that is empty or starts with a digit gets a
/// leading `_`.
///
/// "in-progress" → "IN_PROGRESS", "v1.0" → "V1_0", "123bad" → "_123BAD",
/// "class" → "CLASS_", "" → "_"
fn sanitize_variant(v: &str) -> String {
    // Python accepts Unicode letters in identifiers, but not punctuation such
    // as `.`, `/`, `+` or `(`, so those are replaced along with `-` and space.
    let upper: String = v
        .to_uppercase()
        .chars()
        .map(|c| {
            if c == '_' || c.is_ascii_digit() || c.is_alphabetic() {
                c
            } else {
                '_'
            }
        })
        .collect();

    // Identifiers can be neither empty nor start with a digit.
    let safe = if upper.is_empty() || upper.starts_with(|c: char| c.is_ascii_digit()) {
        format!("_{}", upper)
    } else {
        upper
    };

    // Python keywords are case-sensitive, so an upper-cased name cannot
    // actually collide with one; the `_` suffix is kept so generated member
    // names stay stable.
    match safe.as_str() {
        "FALSE" | "TRUE" | "NONE" | "AND" | "OR" | "NOT" | "IS" | "IN" | "IF" | "ELSE" | "FOR"
        | "WHILE" | "CLASS" | "DEF" | "RETURN" | "IMPORT" | "FROM" | "AS" | "WITH" | "YIELD"
        | "BREAK" | "CONTINUE" | "PASS" | "RAISE" | "TRY" | "EXCEPT" | "FINALLY" => {
            format!("{}_", safe)
        }
        _ => safe,
    }
}
251
252fn generate_enum(enum_def: &EnumDef) -> String {
253    let name = pascal_case(&enum_def.name);
254    let variants: Vec<String> = enum_def
255        .values
256        .iter()
257        .map(|v| {
258            let variant_name = sanitize_variant(v);
259            format!("    {} = \"{}\"", variant_name, escape_string(v))
260        })
261        .collect();
262    format!("class {}(str, Enum):\n{}", name, variants.join("\n"))
263}
264
265fn generate_select_model(table: &crate::ir::TableDef, overrides: &Overrides) -> String {
266    let name = format!("Select{}", pascal_case(&table.name));
267    let fields: Vec<String> = table
268        .columns
269        .iter()
270        .map(|col| select_field(col, overrides))
271        .collect();
272    format!(
273        "class {}(BaseModel):\n    model_config = ConfigDict(from_attributes=True)\n\n{}",
274        name,
275        fields.join("\n")
276    )
277}
278
279fn generate_insert_model(table: &crate::ir::TableDef, overrides: &Overrides) -> String {
280    let name = format!("Insert{}", pascal_case(&table.name));
281
282    // Required fields first, optional fields last (Python syntax)
283    let mut required: Vec<String> = Vec::new();
284    let mut optional: Vec<String> = Vec::new();
285
286    for col in &table.columns {
287        let field = insert_field(col, overrides);
288        if col.has_default || col.nullable {
289            optional.push(field);
290        } else {
291            required.push(field);
292        }
293    }
294
295    let mut fields = required;
296    fields.extend(optional);
297
298    format!(
299        "class {}(BaseModel):\n    model_config = ConfigDict(from_attributes=True)\n\n{}",
300        name,
301        fields.join("\n")
302    )
303}
304
305impl PydanticGenerator {
306    pub fn generate_models_file(&self, ir: &SqlcxIR, overrides: &Overrides) -> String {
307        let mut parts: Vec<String> = Vec::new();
308
309        parts.push("# Code generated by sqlcx. DO NOT EDIT.".to_string());
310
311        // Imports
312        let imports = collect_imports(ir, overrides);
313        parts.push(imports.into_iter().collect::<Vec<_>>().join("\n"));
314
315        // Enums
316        for enum_def in &ir.enums {
317            parts.push(generate_enum(enum_def));
318        }
319
320        // Models
321        for table in &ir.tables {
322            parts.push(generate_select_model(table, overrides));
323            parts.push(generate_insert_model(table, overrides));
324        }
325
326        parts.join("\n\n") + "\n"
327    }
328}
329
330impl SchemaGenerator for PydanticGenerator {
331    fn generate(&self, ir: &SqlcxIR, overrides: &Overrides) -> Result<GeneratedFile> {
332        Ok(GeneratedFile {
333            path: "models.py".to_string(),
334            content: self.generate_models_file(ir, overrides),
335        })
336    }
337}
338
339// ── Tests ─────────────────────────────────────────────────────────────────────
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;
    use std::collections::HashMap;

    /// Parse the shared Postgres fixture schema into an IR with no queries.
    fn parse_fixture_ir() -> SqlcxIR {
        let sql = include_str!("../../../../../tests/fixtures/schema.sql");
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(sql).unwrap();
        SqlcxIR {
            tables,
            queries: vec![],
            enums,
        }
    }

    /// A SQL type with no element type and no enum/json annotations.
    fn plain_type(raw: &str, normalized: &str, category: SqlTypeCategory) -> SqlType {
        SqlType {
            raw: raw.to_string(),
            normalized: normalized.to_string(),
            category,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        }
    }

    /// A `text` column with the given nullability and default flag.
    fn text_column(name: &str, nullable: bool, has_default: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            alias: None,
            source_table: None,
            sql_type: plain_type("text", "text", SqlTypeCategory::String),
            nullable,
            has_default,
        }
    }

    #[test]
    fn generates_models_file() {
        let ir = parse_fixture_ir();
        let content = PydanticGenerator.generate_models_file(&ir, &HashMap::new());
        let expected_fragments = [
            "from pydantic import BaseModel, ConfigDict",
            "from enum import Enum",
            "class UserStatus(str, Enum):",
            "class SelectUsers(BaseModel):",
            "model_config = ConfigDict(from_attributes=True)",
            "class InsertUsers(BaseModel):",
            "class SelectPosts(BaseModel):",
            "class InsertPosts(BaseModel):",
        ];
        for fragment in expected_fragments.iter() {
            assert!(content.contains(fragment), "models.py is missing `{}`", fragment);
        }
        insta::assert_snapshot!("pydantic_models", content);
    }

    #[test]
    fn nullable_column_uses_union_none() {
        let column = text_column("bio", true, false);
        assert_eq!(select_field(&column, &HashMap::new()), "    bio: str | None");
    }

    #[test]
    fn default_column_is_optional_in_insert() {
        let column = text_column("status", false, true);
        assert_eq!(
            insert_field(&column, &HashMap::new()),
            "    status: str | None = None"
        );
    }

    #[test]
    fn sanitize_variant_handles_hyphens_and_keywords() {
        let cases = [
            ("in-progress", "IN_PROGRESS"),
            ("class", "CLASS_"),
            ("active", "ACTIVE"),
            ("123bad", "_123BAD"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(sanitize_variant(input), *expected);
        }
    }

    #[test]
    fn enum_type_uses_pascal_case() {
        let sql_type = SqlType {
            enum_name: Some("user_status".to_string()),
            ..plain_type("user_status", "user_status", SqlTypeCategory::Enum)
        };
        assert_eq!(python_type(&sql_type, &HashMap::new()), "UserStatus");
    }

    #[test]
    fn array_type_maps_to_list() {
        let element = plain_type("text", "text", SqlTypeCategory::String);
        let sql_type = SqlType {
            element_type: Some(Box::new(element)),
            ..plain_type("text[]", "text[]", SqlTypeCategory::String)
        };
        assert_eq!(python_type(&sql_type, &HashMap::new()), "list[str]");
    }

    #[test]
    fn timestamp_maps_to_datetime() {
        let sql_type = plain_type("TIMESTAMP", "timestamp", SqlTypeCategory::Date);
        assert_eq!(python_type(&sql_type, &HashMap::new()), "datetime");
    }
}