Skip to main content

typewriter_python/
mapper.rs

1//! Python type mapper implementation.
2
3use typewriter_core::ir::*;
4use typewriter_core::mapper::TypeMapper;
5use typewriter_core::naming::{to_file_style, FileStyle};
6
7use crate::emitter;
8
/// Python language mapper.
///
/// Generates Pydantic v2 `BaseModel` classes from Rust structs and
/// Python `Enum` / `Union` types from Rust enums.
///
/// The mapper itself only carries configuration; the text of each model or
/// enum is produced by the `emitter` module, which receives this mapper so it
/// can translate field types.
pub struct PythonMapper {
    /// File naming style (default: `snake_case`).
    ///
    /// Applied to a type name to obtain the module file name, and therefore
    /// also the module path used in generated relative imports.
    pub file_style: FileStyle,
}
17
18impl PythonMapper {
19    pub fn new() -> Self {
20        Self {
21            file_style: FileStyle::SnakeCase,
22        }
23    }
24
25    pub fn with_file_style(mut self, style: FileStyle) -> Self {
26        self.file_style = style;
27        self
28    }
29}
30
31impl Default for PythonMapper {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl TypeMapper for PythonMapper {
38    fn map_primitive(&self, ty: &PrimitiveType) -> String {
39        match ty {
40            PrimitiveType::String => "str".to_string(),
41            PrimitiveType::Bool => "bool".to_string(),
42            PrimitiveType::U8
43            | PrimitiveType::U16
44            | PrimitiveType::U32
45            | PrimitiveType::U64
46            | PrimitiveType::U128
47            | PrimitiveType::I8
48            | PrimitiveType::I16
49            | PrimitiveType::I32
50            | PrimitiveType::I64
51            | PrimitiveType::I128 => "int".to_string(),
52            PrimitiveType::F32 | PrimitiveType::F64 => "float".to_string(),
53            PrimitiveType::Uuid => "UUID".to_string(),
54            PrimitiveType::DateTime => "datetime".to_string(),
55            PrimitiveType::NaiveDate => "date".to_string(),
56            PrimitiveType::JsonValue => "Any".to_string(),
57        }
58    }
59
60    fn map_option(&self, inner: &TypeKind) -> String {
61        format!("Optional[{}]", self.map_type(inner))
62    }
63
64    fn map_vec(&self, inner: &TypeKind) -> String {
65        format!("list[{}]", self.map_type(inner))
66    }
67
68    fn map_hashmap(&self, key: &TypeKind, value: &TypeKind) -> String {
69        format!("dict[{}, {}]", self.map_type(key), self.map_type(value))
70    }
71
72    fn map_tuple(&self, elements: &[TypeKind]) -> String {
73        let inner: Vec<String> = elements.iter().map(|e| self.map_type(e)).collect();
74        format!("tuple[{}]", inner.join(", "))
75    }
76
77    fn map_named(&self, name: &str) -> String {
78        name.to_string()
79    }
80
81    fn emit_struct(&self, def: &StructDef) -> String {
82        emitter::render_model(self, def)
83    }
84
85    fn emit_enum(&self, def: &EnumDef) -> String {
86        emitter::render_enum(self, def)
87    }
88
89    fn file_header(&self, type_name: &str) -> String {
90        format!(
91            "# Auto-generated by typewriter v0.2.0. DO NOT EDIT.\n\
92            # Source: {}\n\n",
93            type_name
94        )
95    }
96
97    fn file_extension(&self) -> &str {
98        "py"
99    }
100
101    fn emit_imports(&self, def: &TypeDef) -> String {
102        let refs = def.collect_referenced_types();
103        if refs.is_empty() {
104            return String::new();
105        }
106        let mut output = String::new();
107        for name in &refs {
108            let file_name = self.file_naming(name);
109            output.push_str(&format!(
110                "from .{} import {}\n",
111                file_name, name
112            ));
113        }
114        output
115    }
116
117    fn file_naming(&self, type_name: &str) -> String {
118        to_file_style(type_name, self.file_style)
119    }
120
121    fn map_generic(&self, name: &str, params: &[TypeKind]) -> String {
122        let param_strs: Vec<String> = params.iter().map(|p| self.map_type(p)).collect();
123        format!("{}[{}]", name, param_strs.join(", "))
124    }
125
126    fn map_type(&self, ty: &TypeKind) -> String {
127        match ty {
128            TypeKind::Primitive(p) => self.map_primitive(p),
129            TypeKind::Option(inner) => self.map_option(inner),
130            TypeKind::Vec(inner) => self.map_vec(inner),
131            TypeKind::HashMap(k, v) => self.map_hashmap(k, v),
132            TypeKind::Tuple(elements) => self.map_tuple(elements),
133            TypeKind::Named(name) => self.map_named(name),
134            TypeKind::Generic(name, params) => self.map_generic(name, params),
135            TypeKind::Unit => "None".to_string(),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper with the default (`snake_case`) file style.
    fn mapper() -> PythonMapper {
        PythonMapper::new()
    }

    /// Builds a required, non-skipped field with no rename, doc or override.
    ///
    /// Tests that need a different flag use struct-update syntax on the result.
    fn field(name: &str, ty: TypeKind) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty,
            optional: false,
            rename: None,
            doc: None,
            skip: false,
            flatten: false,
            type_override: None,
        }
    }

    /// Builds a unit enum variant with no rename or doc.
    fn unit_variant(name: &str) -> VariantDef {
        VariantDef {
            name: name.to_string(),
            rename: None,
            kind: VariantKind::Unit,
            doc: None,
        }
    }

    #[test]
    fn test_primitive_mappings() {
        let m = mapper();
        assert_eq!(m.map_primitive(&PrimitiveType::String), "str");
        assert_eq!(m.map_primitive(&PrimitiveType::Bool), "bool");
        assert_eq!(m.map_primitive(&PrimitiveType::U32), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::I64), "int");
        assert_eq!(m.map_primitive(&PrimitiveType::F64), "float");
        assert_eq!(m.map_primitive(&PrimitiveType::Uuid), "UUID");
        assert_eq!(m.map_primitive(&PrimitiveType::DateTime), "datetime");
        assert_eq!(m.map_primitive(&PrimitiveType::NaiveDate), "date");
        assert_eq!(m.map_primitive(&PrimitiveType::JsonValue), "Any");
    }

    #[test]
    fn test_option_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_option(&TypeKind::Primitive(PrimitiveType::String)),
            "Optional[str]"
        );
    }

    #[test]
    fn test_vec_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_vec(&TypeKind::Primitive(PrimitiveType::U32)),
            "list[int]"
        );
    }

    #[test]
    fn test_hashmap_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_hashmap(
                &TypeKind::Primitive(PrimitiveType::String),
                &TypeKind::Primitive(PrimitiveType::U32)
            ),
            "dict[str, int]"
        );
    }

    #[test]
    fn test_tuple_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_tuple(&[
                TypeKind::Primitive(PrimitiveType::String),
                TypeKind::Primitive(PrimitiveType::Bool)
            ]),
            "tuple[str, bool]"
        );
    }

    #[test]
    fn test_named_mapping() {
        let m = mapper();
        assert_eq!(m.map_named("UserProfile"), "UserProfile");
    }

    #[test]
    fn test_generic_mapping() {
        let m = mapper();
        assert_eq!(
            m.map_generic(
                "Page",
                &[
                    TypeKind::Primitive(PrimitiveType::String),
                    TypeKind::Primitive(PrimitiveType::U32)
                ]
            ),
            "Page[str, int]"
        );
    }

    #[test]
    fn test_unit_mapping() {
        let m = mapper();
        assert_eq!(m.map_type(&TypeKind::Unit), "None");
    }

    #[test]
    fn test_nested_mapping() {
        let m = mapper();
        let optional_string = TypeKind::Option(Box::new(TypeKind::Primitive(PrimitiveType::String)));
        assert_eq!(m.map_vec(&optional_string), "list[Optional[str]]");
        assert_eq!(m.map_type(&optional_string), "Optional[str]");
    }

    #[test]
    fn test_file_header() {
        let m = mapper();
        assert_eq!(
            m.file_header("User"),
            "# Auto-generated by typewriter v0.2.0. DO NOT EDIT.\n# Source: User\n\n"
        );
    }

    #[test]
    fn test_file_extension() {
        let m = mapper();
        assert_eq!(m.file_extension(), "py");
    }

    #[test]
    fn test_file_naming_snake() {
        let m = mapper();
        assert_eq!(m.file_naming("UserProfile"), "user_profile");
        assert_eq!(m.file_naming("User"), "user");
        assert_eq!(m.file_naming("HTTPResponse"), "http_response");
    }

    #[test]
    fn test_file_naming_kebab() {
        let m = PythonMapper::new().with_file_style(FileStyle::KebabCase);
        assert_eq!(m.file_naming("UserProfile"), "user-profile");
        assert_eq!(m.file_naming("HTTPResponse"), "http-response");
    }

    #[test]
    fn test_file_naming_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.file_naming("UserProfile"), "UserProfile");
    }

    #[test]
    fn test_output_filename() {
        let m = mapper();
        assert_eq!(m.output_filename("UserProfile"), "user_profile.py");
    }

    #[test]
    fn test_output_filename_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.output_filename("UserProfile"), "UserProfile.py");
    }

    #[test]
    fn test_emit_simple_struct() {
        let m = mapper();
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("id", TypeKind::Primitive(PrimitiveType::Uuid)),
                field("email", TypeKind::Primitive(PrimitiveType::String)),
                FieldDef {
                    optional: true,
                    ..field(
                        "age",
                        TypeKind::Option(Box::new(TypeKind::Primitive(PrimitiveType::U32))),
                    )
                },
            ],
            doc: None,
            generics: vec![],
        };

        let output = m.emit_struct(&def);
        assert!(output.contains("class User(BaseModel):"));
        assert!(output.contains("id: UUID"));
        assert!(output.contains("email: str"));
        assert!(output.contains("age: Optional[int] = None"));
    }

    #[test]
    fn test_skipped_field() {
        let m = mapper();
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("email", TypeKind::Primitive(PrimitiveType::String)),
                FieldDef {
                    skip: true,
                    ..field("password_hash", TypeKind::Primitive(PrimitiveType::String))
                },
            ],
            doc: None,
            generics: vec![],
        };

        let output = m.emit_struct(&def);
        assert!(output.contains("email: str"));
        assert!(!output.contains("password_hash"));
    }

    #[test]
    fn test_simple_enum() {
        let m = mapper();
        let def = EnumDef {
            name: "Role".to_string(),
            variants: vec![unit_variant("Admin"), unit_variant("User")],
            representation: EnumRepr::External,
            doc: None,
        };

        let output = m.emit_enum(&def);
        assert!(output.contains("class Role(str, Enum):"));
        assert!(output.contains("ADMIN = \"Admin\""));
        assert!(output.contains("USER = \"User\""));
    }
}
349}