Skip to main content

typewriter_python/
mapper.rs

1//! Python type mapper implementation.
2
3use typewriter_core::ir::*;
4use typewriter_core::mapper::TypeMapper;
5
6use crate::emitter;
7
/// Python language mapper.
///
/// Generates Pydantic v2 `BaseModel` classes from Rust structs and
/// Python `Enum` / `Union` types from Rust enums.
///
/// The mapper carries no state, so it is trivially copyable and every
/// instance behaves identically.
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonMapper;

impl PythonMapper {
    /// Create a new Python mapper.
    ///
    /// Equivalent to [`PythonMapper::default`]; `const` so it can be used to
    /// initialise constants and statics.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
25
26impl TypeMapper for PythonMapper {
27    fn map_primitive(&self, ty: &PrimitiveType) -> String {
28        match ty {
29            PrimitiveType::String => "str".to_string(),
30            PrimitiveType::Bool => "bool".to_string(),
31            PrimitiveType::U8
32            | PrimitiveType::U16
33            | PrimitiveType::U32
34            | PrimitiveType::U64
35            | PrimitiveType::U128
36            | PrimitiveType::I8
37            | PrimitiveType::I16
38            | PrimitiveType::I32
39            | PrimitiveType::I64
40            | PrimitiveType::I128 => "int".to_string(),
41            PrimitiveType::F32 | PrimitiveType::F64 => "float".to_string(),
42            PrimitiveType::Uuid => "UUID".to_string(),
43            PrimitiveType::DateTime => "datetime".to_string(),
44            PrimitiveType::NaiveDate => "date".to_string(),
45            PrimitiveType::JsonValue => "Any".to_string(),
46        }
47    }
48
49    fn map_option(&self, inner: &TypeKind) -> String {
50        format!("Optional[{}]", self.map_type(inner))
51    }
52
53    fn map_vec(&self, inner: &TypeKind) -> String {
54        format!("list[{}]", self.map_type(inner))
55    }
56
57    fn map_hashmap(&self, key: &TypeKind, value: &TypeKind) -> String {
58        format!("dict[{}, {}]", self.map_type(key), self.map_type(value))
59    }
60
61    fn map_tuple(&self, elements: &[TypeKind]) -> String {
62        let inner: Vec<String> = elements.iter().map(|e| self.map_type(e)).collect();
63        format!("tuple[{}]", inner.join(", "))
64    }
65
66    fn map_named(&self, name: &str) -> String {
67        name.to_string()
68    }
69
70    fn emit_struct(&self, def: &StructDef) -> String {
71        emitter::render_model(self, def)
72    }
73
74    fn emit_enum(&self, def: &EnumDef) -> String {
75        emitter::render_enum(self, def)
76    }
77
78    fn file_header(&self, type_name: &str) -> String {
79        format!(
80            "# Auto-generated by typewriter v0.1.0. DO NOT EDIT.\n\
81             # Source: {}\n\
82             # Regenerate: cargo typewriter generate\n\n",
83            type_name
84        )
85    }
86
87    fn file_extension(&self) -> &str {
88        "py"
89    }
90
91    fn file_naming(&self, type_name: &str) -> String {
92        to_snake_case(type_name)
93    }
94
95    fn map_type(&self, ty: &TypeKind) -> String {
96        match ty {
97            TypeKind::Primitive(p) => self.map_primitive(p),
98            TypeKind::Option(inner) => self.map_option(inner),
99            TypeKind::Vec(inner) => self.map_vec(inner),
100            TypeKind::HashMap(k, v) => self.map_hashmap(k, v),
101            TypeKind::Tuple(elements) => self.map_tuple(elements),
102            TypeKind::Named(name) => self.map_named(name),
103            TypeKind::Generic(name, _params) => self.map_named(name),
104            TypeKind::Unit => "None".to_string(),
105        }
106    }
107}
108
/// Convert a PascalCase type name to snake_case.
///
/// An underscore is inserted before an uppercase character when the previous
/// character is lowercase (`userProfile` → `user_profile`), or when the next
/// character is lowercase, which splits a trailing acronym from the following
/// word (`HTTPResponse` → `http_response`). No underscore is inserted at the
/// start of the name or directly after an existing underscore, so
/// `User_Profile` becomes `user_profile` rather than `user__profile`.
///
/// # Examples
/// - `"UserProfile"` → `"user_profile"`
/// - `"HTTPResponse"` → `"http_response"`
fn to_snake_case(name: &str) -> String {
    // Collect once so neighbour lookups are O(1); calling `chars().nth()` for
    // every character made the conversion quadratic in the name length.
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + chars.len() / 2);
    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            // `i > 0` guarantees `result` is non-empty and `chars[i - 1]` exists.
            if i > 0 && !result.ends_with('_') {
                let prev_lower = chars[i - 1].is_lowercase();
                let next_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
                if prev_lower || next_lower {
                    result.push('_');
                }
            }
            // Some characters lowercase to several chars (e.g. 'İ' → "i\u{307}");
            // keep all of them instead of only the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
132
#[cfg(test)]
mod tests {
    use super::*;

    fn mapper() -> PythonMapper {
        PythonMapper::new()
    }

    /// Shorthand for a primitive `TypeKind`.
    fn prim(p: PrimitiveType) -> TypeKind {
        TypeKind::Primitive(p)
    }

    /// Field with no rename, no doc and no flattening.
    fn field(name: &str, ty: TypeKind, optional: bool, skip: bool) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty,
            optional,
            rename: None,
            doc: None,
            skip,
            flatten: false,
        }
    }

    /// Undocumented, non-generic struct definition.
    fn struct_def(name: &str, fields: Vec<FieldDef>) -> StructDef {
        StructDef {
            name: name.to_string(),
            fields,
            doc: None,
            generics: vec![],
        }
    }

    /// Unit enum variant with no rename and no doc.
    fn unit_variant(name: &str) -> VariantDef {
        VariantDef {
            name: name.to_string(),
            rename: None,
            kind: VariantKind::Unit,
            doc: None,
        }
    }

    #[test]
    fn test_primitive_mappings() {
        let m = mapper();
        let cases = [
            (PrimitiveType::String, "str"),
            (PrimitiveType::Bool, "bool"),
            (PrimitiveType::U32, "int"),
            (PrimitiveType::I64, "int"),
            (PrimitiveType::F64, "float"),
            (PrimitiveType::Uuid, "UUID"),
            (PrimitiveType::DateTime, "datetime"),
            (PrimitiveType::NaiveDate, "date"),
            (PrimitiveType::JsonValue, "Any"),
        ];
        for (primitive, expected) in cases.iter() {
            assert_eq!(m.map_primitive(primitive), *expected);
        }
    }

    #[test]
    fn test_option_mapping() {
        let rendered = mapper().map_option(&prim(PrimitiveType::String));
        assert_eq!(rendered, "Optional[str]");
    }

    #[test]
    fn test_vec_mapping() {
        let rendered = mapper().map_vec(&prim(PrimitiveType::U32));
        assert_eq!(rendered, "list[int]");
    }

    #[test]
    fn test_hashmap_mapping() {
        let key = prim(PrimitiveType::String);
        let value = prim(PrimitiveType::U32);
        assert_eq!(mapper().map_hashmap(&key, &value), "dict[str, int]");
    }

    #[test]
    fn test_tuple_mapping() {
        let members = [prim(PrimitiveType::String), prim(PrimitiveType::Bool)];
        assert_eq!(mapper().map_tuple(&members), "tuple[str, bool]");
    }

    #[test]
    fn test_file_naming() {
        let m = mapper();
        let cases = [
            ("UserProfile", "user_profile"),
            ("User", "user"),
            ("HTTPResponse", "http_response"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(m.file_naming(input), *expected);
        }
    }

    #[test]
    fn test_output_filename() {
        assert_eq!(mapper().output_filename("UserProfile"), "user_profile.py");
    }

    #[test]
    fn test_emit_simple_struct() {
        let def = struct_def(
            "User",
            vec![
                field("id", prim(PrimitiveType::Uuid), false, false),
                field("email", prim(PrimitiveType::String), false, false),
                field(
                    "age",
                    TypeKind::Option(Box::new(prim(PrimitiveType::U32))),
                    true,
                    false,
                ),
            ],
        );

        let output = mapper().emit_struct(&def);
        let expected_fragments = [
            "class User(BaseModel):",
            "id: UUID",
            "email: str",
            "age: Optional[int] = None",
        ];
        for fragment in expected_fragments.iter() {
            assert!(output.contains(*fragment), "missing {:?} in:\n{}", fragment, output);
        }
    }

    #[test]
    fn test_skipped_field() {
        let def = struct_def(
            "User",
            vec![
                field("email", prim(PrimitiveType::String), false, false),
                field("password_hash", prim(PrimitiveType::String), false, true),
            ],
        );

        let output = mapper().emit_struct(&def);
        assert!(output.contains("email: str"));
        assert!(!output.contains("password_hash"));
    }

    #[test]
    fn test_simple_enum() {
        let def = EnumDef {
            name: "Role".to_string(),
            variants: vec![unit_variant("Admin"), unit_variant("User")],
            representation: EnumRepr::External,
            doc: None,
        };

        let output = mapper().emit_enum(&def);
        let expected_fragments = [
            "class Role(str, Enum):",
            "ADMIN = \"Admin\"",
            "USER = \"User\"",
        ];
        for fragment in expected_fragments.iter() {
            assert!(output.contains(*fragment), "missing {:?} in:\n{}", fragment, output);
        }
    }
}