Skip to main content

typewriter_python/
mapper.rs

1//! Python type mapper implementation.
2
3use typewriter_core::ir::*;
4use typewriter_core::mapper::TypeMapper;
5use typewriter_core::naming::{to_file_style, FileStyle};
6
7use crate::emitter;
8
9/// Python language mapper.
10///
11/// Generates Pydantic v2 `BaseModel` classes from Rust structs and
12/// Python `Enum` / `Union` types from Rust enums.
13pub struct PythonMapper {
14    /// File naming style (default: `snake_case`)
15    pub file_style: FileStyle,
16}
17
18impl PythonMapper {
19    pub fn new() -> Self {
20        Self {
21            file_style: FileStyle::SnakeCase,
22        }
23    }
24
25    pub fn with_file_style(mut self, style: FileStyle) -> Self {
26        self.file_style = style;
27        self
28    }
29}
30
31impl Default for PythonMapper {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl TypeMapper for PythonMapper {
38    fn map_primitive(&self, ty: &PrimitiveType) -> String {
39        match ty {
40            PrimitiveType::String => "str".to_string(),
41            PrimitiveType::Bool => "bool".to_string(),
42            PrimitiveType::U8
43            | PrimitiveType::U16
44            | PrimitiveType::U32
45            | PrimitiveType::U64
46            | PrimitiveType::U128
47            | PrimitiveType::I8
48            | PrimitiveType::I16
49            | PrimitiveType::I32
50            | PrimitiveType::I64
51            | PrimitiveType::I128 => "int".to_string(),
52            PrimitiveType::F32 | PrimitiveType::F64 => "float".to_string(),
53            PrimitiveType::Uuid => "UUID".to_string(),
54            PrimitiveType::DateTime => "datetime".to_string(),
55            PrimitiveType::NaiveDate => "date".to_string(),
56            PrimitiveType::JsonValue => "Any".to_string(),
57        }
58    }
59
60    fn map_option(&self, inner: &TypeKind) -> String {
61        format!("Optional[{}]", self.map_type(inner))
62    }
63
64    fn map_vec(&self, inner: &TypeKind) -> String {
65        format!("list[{}]", self.map_type(inner))
66    }
67
68    fn map_hashmap(&self, key: &TypeKind, value: &TypeKind) -> String {
69        format!("dict[{}, {}]", self.map_type(key), self.map_type(value))
70    }
71
72    fn map_tuple(&self, elements: &[TypeKind]) -> String {
73        let inner: Vec<String> = elements.iter().map(|e| self.map_type(e)).collect();
74        format!("tuple[{}]", inner.join(", "))
75    }
76
77    fn map_named(&self, name: &str) -> String {
78        name.to_string()
79    }
80
81    fn emit_struct(&self, def: &StructDef) -> String {
82        emitter::render_model(self, def)
83    }
84
85    fn emit_enum(&self, def: &EnumDef) -> String {
86        emitter::render_enum(self, def)
87    }
88
89    fn file_header(&self, type_name: &str) -> String {
90        format!(
91            "# Auto-generated by typewriter v0.1.1. DO NOT EDIT.\n\
92             # Source: {}\n\
93             # Regenerate: cargo typewriter generate\n\n",
94            type_name
95        )
96    }
97
98    fn file_extension(&self) -> &str {
99        "py"
100    }
101
102    fn file_naming(&self, type_name: &str) -> String {
103        to_file_style(type_name, self.file_style)
104    }
105
106    fn map_type(&self, ty: &TypeKind) -> String {
107        match ty {
108            TypeKind::Primitive(p) => self.map_primitive(p),
109            TypeKind::Option(inner) => self.map_option(inner),
110            TypeKind::Vec(inner) => self.map_vec(inner),
111            TypeKind::HashMap(k, v) => self.map_hashmap(k, v),
112            TypeKind::Tuple(elements) => self.map_tuple(elements),
113            TypeKind::Named(name) => self.map_named(name),
114            TypeKind::Generic(name, _params) => self.map_named(name),
115            TypeKind::Unit => "None".to_string(),
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper with default settings, shared by most tests.
    fn mapper() -> PythonMapper {
        PythonMapper::new()
    }

    /// Shorthand for wrapping a primitive in `TypeKind::Primitive`.
    fn prim(p: PrimitiveType) -> TypeKind {
        TypeKind::Primitive(p)
    }

    /// Builds a field with no rename, no doc and no flattening.
    fn field(name: &str, ty: TypeKind, optional: bool, skip: bool) -> FieldDef {
        FieldDef {
            name: name.to_string(),
            ty,
            optional,
            rename: None,
            doc: None,
            skip,
            flatten: false,
        }
    }

    /// Builds a unit variant with no rename and no doc.
    fn unit_variant(name: &str) -> VariantDef {
        VariantDef {
            name: name.to_string(),
            rename: None,
            kind: VariantKind::Unit,
            doc: None,
        }
    }

    #[test]
    fn test_primitive_mappings() {
        let m = mapper();
        let expectations = [
            (PrimitiveType::String, "str"),
            (PrimitiveType::Bool, "bool"),
            (PrimitiveType::U32, "int"),
            (PrimitiveType::I64, "int"),
            (PrimitiveType::F64, "float"),
            (PrimitiveType::Uuid, "UUID"),
            (PrimitiveType::DateTime, "datetime"),
            (PrimitiveType::NaiveDate, "date"),
            (PrimitiveType::JsonValue, "Any"),
        ];
        for (ty, python_name) in expectations {
            assert_eq!(m.map_primitive(&ty), python_name);
        }
    }

    #[test]
    fn test_option_mapping() {
        let rendered = mapper().map_option(&prim(PrimitiveType::String));
        assert_eq!(rendered, "Optional[str]");
    }

    #[test]
    fn test_vec_mapping() {
        let rendered = mapper().map_vec(&prim(PrimitiveType::U32));
        assert_eq!(rendered, "list[int]");
    }

    #[test]
    fn test_hashmap_mapping() {
        let key = prim(PrimitiveType::String);
        let value = prim(PrimitiveType::U32);
        assert_eq!(mapper().map_hashmap(&key, &value), "dict[str, int]");
    }

    #[test]
    fn test_tuple_mapping() {
        let elements = [prim(PrimitiveType::String), prim(PrimitiveType::Bool)];
        assert_eq!(mapper().map_tuple(&elements), "tuple[str, bool]");
    }

    #[test]
    fn test_file_naming_snake() {
        let m = mapper();
        let expectations = [
            ("UserProfile", "user_profile"),
            ("User", "user"),
            ("HTTPResponse", "http_response"),
        ];
        for (type_name, stem) in expectations {
            assert_eq!(m.file_naming(type_name), stem);
        }
    }

    #[test]
    fn test_file_naming_kebab() {
        let m = PythonMapper::new().with_file_style(FileStyle::KebabCase);
        let expectations = [
            ("UserProfile", "user-profile"),
            ("HTTPResponse", "http-response"),
        ];
        for (type_name, stem) in expectations {
            assert_eq!(m.file_naming(type_name), stem);
        }
    }

    #[test]
    fn test_file_naming_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.file_naming("UserProfile"), "UserProfile");
    }

    #[test]
    fn test_output_filename() {
        assert_eq!(mapper().output_filename("UserProfile"), "user_profile.py");
    }

    #[test]
    fn test_output_filename_pascal() {
        let m = PythonMapper::new().with_file_style(FileStyle::PascalCase);
        assert_eq!(m.output_filename("UserProfile"), "UserProfile.py");
    }

    #[test]
    fn test_emit_simple_struct() {
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("id", prim(PrimitiveType::Uuid), false, false),
                field("email", prim(PrimitiveType::String), false, false),
                field(
                    "age",
                    TypeKind::Option(Box::new(prim(PrimitiveType::U32))),
                    true,
                    false,
                ),
            ],
            doc: None,
            generics: vec![],
        };

        let output = mapper().emit_struct(&def);
        for needle in [
            "class User(BaseModel):",
            "id: UUID",
            "email: str",
            "age: Optional[int] = None",
        ] {
            assert!(output.contains(needle));
        }
    }

    #[test]
    fn test_skipped_field() {
        let def = StructDef {
            name: "User".to_string(),
            fields: vec![
                field("email", prim(PrimitiveType::String), false, false),
                // Marked `skip`, so it must not appear in the output.
                field("password_hash", prim(PrimitiveType::String), false, true),
            ],
            doc: None,
            generics: vec![],
        };

        let output = mapper().emit_struct(&def);
        assert!(output.contains("email: str"));
        assert!(!output.contains("password_hash"));
    }

    #[test]
    fn test_simple_enum() {
        let def = EnumDef {
            name: "Role".to_string(),
            variants: vec![unit_variant("Admin"), unit_variant("User")],
            representation: EnumRepr::External,
            doc: None,
        };

        let output = mapper().emit_enum(&def);
        for needle in [
            "class Role(str, Enum):",
            "ADMIN = \"Admin\"",
            "USER = \"User\"",
        ] {
            assert!(output.contains(needle));
        }
    }
}