Skip to main content

typewriter_python/
mapper.rs

1//! Python type mapper implementation.
2
3use typewriter_core::ir::*;
4use typewriter_core::mapper::TypeMapper;
5use typewriter_core::naming::{to_file_style, FileStyle};
6
7use crate::emitter;
8
9/// Python language mapper.
10///
11/// Generates Pydantic v2 `BaseModel` classes from Rust structs and
12/// Python `Enum` / `Union` types from Rust enums.
13pub struct PythonMapper {
14    /// File naming style (default: `snake_case`)
15    pub file_style: FileStyle,
16}
17
18impl PythonMapper {
19    pub fn new() -> Self {
20        Self {
21            file_style: FileStyle::SnakeCase,
22        }
23    }
24
25    pub fn with_file_style(mut self, style: FileStyle) -> Self {
26        self.file_style = style;
27        self
28    }
29}
30
31impl Default for PythonMapper {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl TypeMapper for PythonMapper {
38    fn map_primitive(&self, ty: &PrimitiveType) -> String {
39        match ty {
40            PrimitiveType::String => "str".to_string(),
41            PrimitiveType::Bool => "bool".to_string(),
42            PrimitiveType::U8
43            | PrimitiveType::U16
44            | PrimitiveType::U32
45            | PrimitiveType::U64
46            | PrimitiveType::U128
47            | PrimitiveType::I8
48            | PrimitiveType::I16
49            | PrimitiveType::I32
50            | PrimitiveType::I64
51            | PrimitiveType::I128 => "int".to_string(),
52            PrimitiveType::F32 | PrimitiveType::F64 => "float".to_string(),
53            PrimitiveType::Uuid => "UUID".to_string(),
54            PrimitiveType::DateTime => "datetime".to_string(),
55            PrimitiveType::NaiveDate => "date".to_string(),
56            PrimitiveType::JsonValue => "Any".to_string(),
57        }
58    }
59
60    fn map_option(&self, inner: &TypeKind) -> String {
61        format!("Optional[{}]", self.map_type(inner))
62    }
63
64    fn map_vec(&self, inner: &TypeKind) -> String {
65        format!("list[{}]", self.map_type(inner))
66    }
67
68    fn map_hashmap(&self, key: &TypeKind, value: &TypeKind) -> String {
69        format!("dict[{}, {}]", self.map_type(key), self.map_type(value))
70    }
71
72    fn map_tuple(&self, elements: &[TypeKind]) -> String {
73        let inner: Vec<String> = elements.iter().map(|e| self.map_type(e)).collect();
74        format!("tuple[{}]", inner.join(", "))
75    }
76
77    fn map_named(&self, name: &str) -> String {
78        name.to_string()
79    }
80
81    fn emit_struct(&self, def: &StructDef) -> String {
82        emitter::render_model(self, def)
83    }
84
85    fn emit_enum(&self, def: &EnumDef) -> String {
86        emitter::render_enum(self, def)
87    }
88
89    fn file_header(&self, type_name: &str) -> String {
90        format!(
91            "# Auto-generated by typewriter v0.3.0. DO NOT EDIT.\n\
92            # Source: {}\n\n",
93            type_name
94        )
95    }
96
97    fn file_extension(&self) -> &str {
98        "py"
99    }
100
101    fn emit_imports(&self, def: &TypeDef) -> String {
102        let refs = def.collect_referenced_types();
103        if refs.is_empty() {
104            return String::new();
105        }
106        let mut output = String::new();
107        for name in &refs {
108            let file_name = self.file_naming(name);
109            output.push_str(&format!("from .{} import {}\n", file_name, name));
110        }
111        output
112    }
113
114    fn file_naming(&self, type_name: &str) -> String {
115        to_file_style(type_name, self.file_style)
116    }
117
118    fn map_generic(&self, name: &str, params: &[TypeKind]) -> String {
119        let param_strs: Vec<String> = params.iter().map(|p| self.map_type(p)).collect();
120        format!("{}[{}]", name, param_strs.join(", "))
121    }
122
123    fn map_type(&self, ty: &TypeKind) -> String {
124        match ty {
125            TypeKind::Primitive(p) => self.map_primitive(p),
126            TypeKind::Option(inner) => self.map_option(inner),
127            TypeKind::Vec(inner) => self.map_vec(inner),
128            TypeKind::HashMap(k, v) => self.map_hashmap(k, v),
129            TypeKind::Tuple(elements) => self.map_tuple(elements),
130            TypeKind::Named(name) => self.map_named(name),
131            TypeKind::Generic(name, params) => self.map_generic(name, params),
132            TypeKind::Unit => "None".to_string(),
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    fn mapper() -> PythonMapper {
        PythonMapper::new()
    }

    fn prim(p: PrimitiveType) -> TypeKind {
        TypeKind::Primitive(p)
    }

    /// Builds a required, non-skipped, non-flattened field with no rename,
    /// doc or type override; tests tweak it with struct-update syntax.
    fn field(name: &str, ty: TypeKind) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty,
            optional: false,
            rename: None,
            doc: None,
            skip: false,
            flatten: false,
            type_override: None,
        }
    }

    /// Builds a unit enum variant with no rename or doc.
    fn unit_variant(name: &str) -> VariantDef {
        VariantDef {
            name: name.to_string(),
            rename: None,
            kind: VariantKind::Unit,
            doc: None,
        }
    }

    #[test]
    fn test_primitive_mappings() {
        let m = mapper();
        assert_eq!(m.map_primitive(&PrimitiveType::String), "str");
        assert_eq!(m.map_primitive(&PrimitiveType::Bool), "bool");
        assert_eq!(m.map_primitive(&PrimitiveType::U8), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::U32), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::I64), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::I128), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::F32), "float");
        assert_eq!(m.map_primitive(&PrimitiveType::F64), "float");
        assert_eq!(m.map_primitive(&PrimitiveType::Uuid), "UUID");
        assert_eq!(m.map_primitive(&PrimitiveType::DateTime), "datetime");
        assert_eq!(m.map_primitive(&PrimitiveType::NaiveDate), "date");
        assert_eq!(m.map_primitive(&PrimitiveType::JsonValue), "Any");
    }

    #[test]
    fn test_option_mapping() {
        let m = mapper();
        assert_eq!(m.map_option(&prim(PrimitiveType::String)), "Optional[str]");
    }

    #[test]
    fn test_nested_option_via_map_type() {
        let m = mapper();
        let ty = TypeKind::Option(Box::new(TypeKind::Option(Box::new(prim(
            PrimitiveType::Bool,
        )))));
        assert_eq!(m.map_type(&ty), "Optional[Optional[bool]]");
    }

    #[test]
    fn test_unit_maps_to_none() {
        assert_eq!(mapper().map_type(&TypeKind::Unit), "None");
    }

    #[test]
    fn test_vec_mapping() {
        let m = mapper();
        assert_eq!(m.map_vec(&prim(PrimitiveType::U32)), "list[int]");
    }

    #[test]
    fn test_hashmap_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_hashmap(&prim(PrimitiveType::String), &prim(PrimitiveType::U32)),
            "dict[str, int]"
        );
    }

    #[test]
    fn test_tuple_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_tuple(&[prim(PrimitiveType::String), prim(PrimitiveType::Bool)]),
            "tuple[str, bool]"
        );
    }

    #[test]
    fn test_named_and_generic_mapping() {
        let m = mapper();
        assert_eq!(m.map_named("User"), "User");
        assert_eq!(
            m.map_generic("Page", &[prim(PrimitiveType::String)]),
            "Page[str]"
        );
        assert_eq!(
            m.map_generic(
                "Pair",
                &[prim(PrimitiveType::String), prim(PrimitiveType::U32)]
            ),
            "Pair[str, int]"
        );
    }

    #[test]
    fn test_file_header() {
        let header = mapper().file_header("User");
        assert!(header.starts_with("# Auto-generated by typewriter"));
        assert!(header.contains("DO NOT EDIT.\n# Source: User\n"));
        assert!(header.ends_with("\n\n"));
    }

    #[test]
    fn test_file_extension() {
        assert_eq!(mapper().file_extension(), "py");
    }

    #[test]
    fn test_file_naming_snake() {
        let m = mapper();
        assert_eq!(m.file_naming("UserProfile"), "user_profile");
        assert_eq!(m.file_naming("User"), "user");
        assert_eq!(m.file_naming("HTTPResponse"), "http_response");
    }

    #[test]
    fn test_file_naming_kebab() {
        let m = PythonMapper::new().with_file_style(FileStyle::KebabCase);
        assert_eq!(m.file_naming("UserProfile"), "user-profile");
        assert_eq!(m.file_naming("HTTPResponse"), "http-response");
    }

    #[test]
    fn test_file_naming_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.file_naming("UserProfile"), "UserProfile");
    }

    #[test]
    fn test_output_filename() {
        let m = mapper();
        assert_eq!(m.output_filename("UserProfile"), "user_profile.py");
    }

    #[test]
    fn test_output_filename_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.output_filename("UserProfile"), "UserProfile.py");
    }

    #[test]
    fn test_emit_simple_struct() {
        let m = mapper();
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("id", prim(PrimitiveType::Uuid)),
                field("email", prim(PrimitiveType::String)),
                FieldDef {
                    optional: true,
                    ..field(
                        "age",
                        TypeKind::Option(Box::new(prim(PrimitiveType::U32))),
                    )
                },
            ],
            doc: None,
            generics: vec![],
        };

        let output = m.emit_struct(&def);
        assert!(output.contains("class User(BaseModel):"));
        assert!(output.contains("id: UUID"));
        assert!(output.contains("email: str"));
        assert!(output.contains("age: Optional[int] = None"));
    }

    #[test]
    fn test_skipped_field() {
        let m = mapper();
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("email", prim(PrimitiveType::String)),
                FieldDef {
                    skip: true,
                    ..field("password_hash", prim(PrimitiveType::String))
                },
            ],
            doc: None,
            generics: vec![],
        };

        let output = m.emit_struct(&def);
        assert!(output.contains("email: str"));
        assert!(!output.contains("password_hash"));
    }

    #[test]
    fn test_simple_enum() {
        let m = mapper();
        let def = EnumDef {
            name: "Role".to_string(),
            variants: vec![unit_variant("Admin"), unit_variant("User")],
            representation: EnumRepr::External,
            doc: None,
        };

        let output = m.emit_enum(&def);
        assert!(output.contains("class Role(str, Enum):"));
        assert!(output.contains("ADMIN = \"Admin\""));
        assert!(output.contains("USER = \"User\""));
    }
}