Skip to main content

typewriter_python/
mapper.rs

1//! Python type mapper implementation.
2
3use typewriter_core::ir::*;
4use typewriter_core::mapper::TypeMapper;
5use typewriter_core::naming::{to_file_style, FileStyle};
6
7use crate::emitter;
8
9/// Python language mapper.
10///
11/// Generates Pydantic v2 `BaseModel` classes from Rust structs and
12/// Python `Enum` / `Union` types from Rust enums.
13pub struct PythonMapper {
14    /// File naming style (default: `snake_case`)
15    pub file_style: FileStyle,
16}
17
18impl PythonMapper {
19    pub fn new() -> Self {
20        Self {
21            file_style: FileStyle::SnakeCase,
22        }
23    }
24
25    pub fn with_file_style(mut self, style: FileStyle) -> Self {
26        self.file_style = style;
27        self
28    }
29}
30
31impl Default for PythonMapper {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl TypeMapper for PythonMapper {
38    fn map_primitive(&self, ty: &PrimitiveType) -> String {
39        match ty {
40            PrimitiveType::String => "str".to_string(),
41            PrimitiveType::Bool => "bool".to_string(),
42            PrimitiveType::U8
43            | PrimitiveType::U16
44            | PrimitiveType::U32
45            | PrimitiveType::U64
46            | PrimitiveType::U128
47            | PrimitiveType::I8
48            | PrimitiveType::I16
49            | PrimitiveType::I32
50            | PrimitiveType::I64
51            | PrimitiveType::I128 => "int".to_string(),
52            PrimitiveType::F32 | PrimitiveType::F64 => "float".to_string(),
53            PrimitiveType::Uuid => "UUID".to_string(),
54            PrimitiveType::DateTime => "datetime".to_string(),
55            PrimitiveType::NaiveDate => "date".to_string(),
56            PrimitiveType::JsonValue => "Any".to_string(),
57        }
58    }
59
60    fn map_option(&self, inner: &TypeKind) -> String {
61        format!("Optional[{}]", self.map_type(inner))
62    }
63
64    fn map_vec(&self, inner: &TypeKind) -> String {
65        format!("list[{}]", self.map_type(inner))
66    }
67
68    fn map_hashmap(&self, key: &TypeKind, value: &TypeKind) -> String {
69        format!("dict[{}, {}]", self.map_type(key), self.map_type(value))
70    }
71
72    fn map_tuple(&self, elements: &[TypeKind]) -> String {
73        let inner: Vec<String> = elements.iter().map(|e| self.map_type(e)).collect();
74        format!("tuple[{}]", inner.join(", "))
75    }
76
77    fn map_named(&self, name: &str) -> String {
78        name.to_string()
79    }
80
81    fn emit_struct(&self, def: &StructDef) -> String {
82        emitter::render_model(self, def)
83    }
84
85    fn emit_enum(&self, def: &EnumDef) -> String {
86        emitter::render_enum(self, def)
87    }
88
89    fn file_header(&self, type_name: &str) -> String {
90        format!(
91            "# Auto-generated by typewriter v0.1.2. DO NOT EDIT.\n\
92             # Source: {}\n\
93             # Regenerate: cargo typewriter generate\n\n",
94            type_name
95        )
96    }
97
98    fn file_extension(&self) -> &str {
99        "py"
100    }
101
102    fn emit_imports(&self, def: &TypeDef) -> String {
103        let refs = def.collect_referenced_types();
104        if refs.is_empty() {
105            return String::new();
106        }
107        let mut output = String::new();
108        for name in &refs {
109            let file_name = self.file_naming(name);
110            output.push_str(&format!(
111                "from .{} import {}\n",
112                file_name, name
113            ));
114        }
115        output
116    }
117
118    fn file_naming(&self, type_name: &str) -> String {
119        to_file_style(type_name, self.file_style)
120    }
121
122    fn map_generic(&self, name: &str, params: &[TypeKind]) -> String {
123        let param_strs: Vec<String> = params.iter().map(|p| self.map_type(p)).collect();
124        format!("{}[{}]", name, param_strs.join(", "))
125    }
126
127    fn map_type(&self, ty: &TypeKind) -> String {
128        match ty {
129            TypeKind::Primitive(p) => self.map_primitive(p),
130            TypeKind::Option(inner) => self.map_option(inner),
131            TypeKind::Vec(inner) => self.map_vec(inner),
132            TypeKind::HashMap(k, v) => self.map_hashmap(k, v),
133            TypeKind::Tuple(elements) => self.map_tuple(elements),
134            TypeKind::Named(name) => self.map_named(name),
135            TypeKind::Generic(name, params) => self.map_generic(name, params),
136            TypeKind::Unit => "None".to_string(),
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper with the default (`snake_case`) file style.
    fn mapper() -> PythonMapper {
        PythonMapper::new()
    }

    /// Wraps a primitive in a `TypeKind`.
    fn prim(p: PrimitiveType) -> TypeKind {
        TypeKind::Primitive(p)
    }

    /// Required, non-skipped field with no rename, doc or flattening.
    /// Override individual flags with struct-update syntax.
    fn field(name: &str, ty: TypeKind) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty,
            optional: false,
            rename: None,
            doc: None,
            skip: false,
            flatten: false,
        }
    }

    /// Undocumented, non-generic struct definition.
    fn struct_def(name: &str, fields: Vec<FieldDef>) -> StructDef {
        StructDef {
            name: name.to_string(),
            fields,
            doc: None,
            generics: vec![],
        }
    }

    /// Undocumented unit variant without a rename.
    fn unit_variant(name: &str) -> VariantDef {
        VariantDef {
            name: name.to_string(),
            rename: None,
            kind: VariantKind::Unit,
            doc: None,
        }
    }

    #[test]
    fn test_primitive_mappings() {
        let m = mapper();
        let expectations = [
            (PrimitiveType::String, "str"),
            (PrimitiveType::Bool, "bool"),
            (PrimitiveType::U32, "int"),
            (PrimitiveType::I64, "int"),
            (PrimitiveType::F64, "float"),
            (PrimitiveType::Uuid, "UUID"),
            (PrimitiveType::DateTime, "datetime"),
            (PrimitiveType::NaiveDate, "date"),
            (PrimitiveType::JsonValue, "Any"),
        ];
        for (primitive, python) in &expectations {
            assert_eq!(m.map_primitive(primitive), *python);
        }
    }

    #[test]
    fn test_option_mapping() {
        let rendered = mapper().map_option(&prim(PrimitiveType::String));
        assert_eq!(rendered, "Optional[str]");
    }

    #[test]
    fn test_vec_mapping() {
        let rendered = mapper().map_vec(&prim(PrimitiveType::U32));
        assert_eq!(rendered, "list[int]");
    }

    #[test]
    fn test_hashmap_mapping() {
        let key = prim(PrimitiveType::String);
        let value = prim(PrimitiveType::U32);
        assert_eq!(mapper().map_hashmap(&key, &value), "dict[str, int]");
    }

    #[test]
    fn test_tuple_mapping() {
        let elements = [prim(PrimitiveType::String), prim(PrimitiveType::Bool)];
        assert_eq!(mapper().map_tuple(&elements), "tuple[str, bool]");
    }

    #[test]
    fn test_file_naming_snake() {
        let m = mapper();
        for &(type_name, file) in &[
            ("UserProfile", "user_profile"),
            ("User", "user"),
            ("HTTPResponse", "http_response"),
        ] {
            assert_eq!(m.file_naming(type_name), file);
        }
    }

    #[test]
    fn test_file_naming_kebab() {
        let m = PythonMapper::new().with_file_style(FileStyle::KebabCase);
        for &(type_name, file) in &[
            ("UserProfile", "user-profile"),
            ("HTTPResponse", "http-response"),
        ] {
            assert_eq!(m.file_naming(type_name), file);
        }
    }

    #[test]
    fn test_file_naming_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.file_naming("UserProfile"), "UserProfile");
    }

    #[test]
    fn test_output_filename() {
        assert_eq!(mapper().output_filename("UserProfile"), "user_profile.py");
    }

    #[test]
    fn test_output_filename_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.output_filename("UserProfile"), "UserProfile.py");
    }

    #[test]
    fn test_emit_simple_struct() {
        let def = struct_def(
            "User",
            vec![
                field("id", prim(PrimitiveType::Uuid)),
                field("email", prim(PrimitiveType::String)),
                FieldDef {
                    optional: true,
                    ..field("age", TypeKind::Option(Box::new(prim(PrimitiveType::U32))))
                },
            ],
        );

        let output = mapper().emit_struct(&def);
        for expected in &[
            "class User(BaseModel):",
            "id: UUID",
            "email: str",
            "age: Optional[int] = None",
        ] {
            assert!(output.contains(*expected));
        }
    }

    #[test]
    fn test_skipped_field() {
        let def = struct_def(
            "User",
            vec![
                field("email", prim(PrimitiveType::String)),
                FieldDef {
                    skip: true,
                    ..field("password_hash", prim(PrimitiveType::String))
                },
            ],
        );

        let output = mapper().emit_struct(&def);
        assert!(output.contains("email: str"));
        assert!(!output.contains("password_hash"));
    }

    #[test]
    fn test_simple_enum() {
        let def = EnumDef {
            name: "Role".to_string(),
            variants: vec![unit_variant("Admin"), unit_variant("User")],
            representation: EnumRepr::External,
            doc: None,
        };

        let output = mapper().emit_enum(&def);
        for expected in &[
            "class Role(str, Enum):",
            "ADMIN = \"Admin\"",
            "USER = \"User\"",
        ] {
            assert!(output.contains(*expected));
        }
    }
}