Skip to main content

tensorlogic_adapters/codegen/
python.rs

1use std::fmt::Write as FmtWrite;
2
3use crate::{DomainInfo, PredicateInfo, SymbolTable};
4
5use super::rust::RustCodegen;
6
/// Code generator for Python type stubs and PyO3 bindings.
///
/// A `PythonCodegen` holds only configuration. Depending on `generate_pyo3`
/// it renders either a Python type-stub module (`.pyi`) or Rust source
/// containing PyO3 binding code for a TensorLogic schema.
///
/// The common traits are derived so that a configured generator can be
/// logged, duplicated and compared by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PythonCodegen {
    /// Name of the generated module; written into the generated header and,
    /// in PyO3 mode, used for the `#[pymodule]` function name.
    module_name: String,
    /// `true` to emit PyO3 (Rust) bindings, `false` to emit Python stubs.
    generate_pyo3: bool,
    /// Whether docstrings are written into the generated Python stubs.
    include_docs: bool,
    /// Whether predicate classes are decorated with `@dataclass(frozen=True)`.
    use_dataclasses: bool,
}
21
22impl PythonCodegen {
23    /// Create a new Python code generator.
24    pub fn new(module_name: impl Into<String>) -> Self {
25        Self {
26            module_name: module_name.into(),
27            generate_pyo3: false,
28            include_docs: true,
29            use_dataclasses: true,
30        }
31    }
32
33    /// Set whether to generate PyO3 bindings.
34    pub fn with_pyo3(mut self, enable: bool) -> Self {
35        self.generate_pyo3 = enable;
36        self
37    }
38
39    /// Set whether to include docstrings.
40    pub fn with_docs(mut self, enable: bool) -> Self {
41        self.include_docs = enable;
42        self
43    }
44
45    /// Set whether to use dataclasses.
46    pub fn with_dataclasses(mut self, enable: bool) -> Self {
47        self.use_dataclasses = enable;
48        self
49    }
50
51    /// Generate complete Python module from a symbol table.
52    pub fn generate(&self, table: &SymbolTable) -> String {
53        if self.generate_pyo3 {
54            self.generate_pyo3_bindings(table)
55        } else {
56            self.generate_type_stubs(table)
57        }
58    }
59
60    /// Generate Python type stubs (.pyi file).
61    fn generate_type_stubs(&self, table: &SymbolTable) -> String {
62        let mut code = String::new();
63
64        // Module header
65        writeln!(code, "\"\"\"").expect("writing to String is infallible");
66        writeln!(code, "Generated from TensorLogic schema")
67            .expect("writing to String is infallible");
68        writeln!(code, "Module: {}", self.module_name).expect("writing to String is infallible");
69        writeln!(code).expect("writing to String is infallible");
70        writeln!(code, "This code was automatically generated.")
71            .expect("writing to String is infallible");
72        writeln!(code, "DO NOT EDIT MANUALLY.").expect("writing to String is infallible");
73        writeln!(code, "\"\"\"").expect("writing to String is infallible");
74        writeln!(code).expect("writing to String is infallible");
75
76        // Imports
77        writeln!(code, "from typing import NewType, Final")
78            .expect("writing to String is infallible");
79        if self.use_dataclasses {
80            writeln!(code, "from dataclasses import dataclass")
81                .expect("writing to String is infallible");
82        }
83        writeln!(code).expect("writing to String is infallible");
84
85        // Generate domain types
86        writeln!(code, "# ==========================================")
87            .expect("writing to String is infallible");
88        writeln!(code, "# Domain Types").expect("writing to String is infallible");
89        writeln!(code, "# ==========================================")
90            .expect("writing to String is infallible");
91        writeln!(code).expect("writing to String is infallible");
92
93        for domain in table.domains.values() {
94            self.generate_domain_stub(&mut code, domain);
95            writeln!(code).expect("writing to String is infallible");
96        }
97
98        // Generate predicate types
99        writeln!(code, "# ==========================================")
100            .expect("writing to String is infallible");
101        writeln!(code, "# Predicate Types").expect("writing to String is infallible");
102        writeln!(code, "# ==========================================")
103            .expect("writing to String is infallible");
104        writeln!(code).expect("writing to String is infallible");
105
106        for predicate in table.predicates.values() {
107            self.generate_predicate_stub(&mut code, predicate, table);
108            writeln!(code).expect("writing to String is infallible");
109        }
110
111        // Generate schema metadata
112        writeln!(code, "# ==========================================")
113            .expect("writing to String is infallible");
114        writeln!(code, "# Schema Metadata").expect("writing to String is infallible");
115        writeln!(code, "# ==========================================")
116            .expect("writing to String is infallible");
117        writeln!(code).expect("writing to String is infallible");
118        self.generate_schema_metadata_stub(&mut code, table);
119
120        code
121    }
122
123    /// Generate PyO3 Rust bindings.
124    fn generate_pyo3_bindings(&self, table: &SymbolTable) -> String {
125        let mut code = String::new();
126
127        // Module header
128        writeln!(code, "//! PyO3 bindings for TensorLogic schema")
129            .expect("writing to String is infallible");
130        writeln!(code, "//! Module: {}", self.module_name)
131            .expect("writing to String is infallible");
132        writeln!(code, "//!").expect("writing to String is infallible");
133        writeln!(code, "//! This code was automatically generated.")
134            .expect("writing to String is infallible");
135        writeln!(code, "//! DO NOT EDIT MANUALLY.").expect("writing to String is infallible");
136        writeln!(code).expect("writing to String is infallible");
137
138        writeln!(code, "use pyo3::prelude::*;").expect("writing to String is infallible");
139        writeln!(code).expect("writing to String is infallible");
140
141        // Generate domain classes
142        writeln!(code, "// ==========================================")
143            .expect("writing to String is infallible");
144        writeln!(code, "// Domain Types").expect("writing to String is infallible");
145        writeln!(code, "// ==========================================")
146            .expect("writing to String is infallible");
147        writeln!(code).expect("writing to String is infallible");
148
149        for domain in table.domains.values() {
150            self.generate_domain_pyo3(&mut code, domain);
151            writeln!(code).expect("writing to String is infallible");
152        }
153
154        // Generate predicate classes
155        writeln!(code, "// ==========================================")
156            .expect("writing to String is infallible");
157        writeln!(code, "// Predicate Types").expect("writing to String is infallible");
158        writeln!(code, "// ==========================================")
159            .expect("writing to String is infallible");
160        writeln!(code).expect("writing to String is infallible");
161
162        for predicate in table.predicates.values() {
163            self.generate_predicate_pyo3(&mut code, predicate);
164            writeln!(code).expect("writing to String is infallible");
165        }
166
167        // Generate module registration
168        writeln!(code, "// ==========================================")
169            .expect("writing to String is infallible");
170        writeln!(code, "// Module Registration").expect("writing to String is infallible");
171        writeln!(code, "// ==========================================")
172            .expect("writing to String is infallible");
173        writeln!(code).expect("writing to String is infallible");
174        self.generate_module_registration(&mut code, table);
175
176        code
177    }
178
179    /// Generate Python type stub for a domain.
180    fn generate_domain_stub(&self, code: &mut String, domain: &DomainInfo) {
181        let type_name = Self::to_python_class_name(&domain.name);
182
183        // NewType for branded ID
184        writeln!(code, "{} = NewType('{}', int)", type_name, type_name)
185            .expect("writing to String is infallible");
186        writeln!(code).expect("writing to String is infallible");
187
188        // Cardinality constant
189        writeln!(
190            code,
191            "{}_CARDINALITY: Final[int] = {}",
192            domain.name.to_uppercase(),
193            domain.cardinality
194        )
195        .expect("writing to String is infallible");
196        writeln!(code).expect("writing to String is infallible");
197
198        // Validator function
199        writeln!(code, "def is_valid_{}(id: int) -> bool:", domain.name)
200            .expect("writing to String is infallible");
201        if self.include_docs {
202            writeln!(code, "    \"\"\"").expect("writing to String is infallible");
203            if let Some(ref desc) = domain.description {
204                writeln!(code, "    {}", desc).expect("writing to String is infallible");
205                writeln!(code).expect("writing to String is infallible");
206            }
207            writeln!(code, "    Validate {} ID.", type_name)
208                .expect("writing to String is infallible");
209            writeln!(code).expect("writing to String is infallible");
210            writeln!(code, "    Args:").expect("writing to String is infallible");
211            writeln!(code, "        id: The ID to validate")
212                .expect("writing to String is infallible");
213            writeln!(code).expect("writing to String is infallible");
214            writeln!(code, "    Returns:").expect("writing to String is infallible");
215            writeln!(
216                code,
217                "        True if id is in range [0, {}), False otherwise",
218                domain.cardinality
219            )
220            .expect("writing to String is infallible");
221            writeln!(code, "    \"\"\"").expect("writing to String is infallible");
222        }
223        writeln!(code, "    ...").expect("writing to String is infallible");
224    }
225
226    /// Generate Python type stub for a predicate.
227    fn generate_predicate_stub(
228        &self,
229        code: &mut String,
230        predicate: &PredicateInfo,
231        _table: &SymbolTable,
232    ) {
233        let class_name = Self::to_python_class_name(&predicate.name);
234
235        if self.include_docs {
236            writeln!(code, "\"\"\"").expect("writing to String is infallible");
237            if let Some(ref desc) = predicate.description {
238                writeln!(code, "{}", desc).expect("writing to String is infallible");
239            } else {
240                writeln!(code, "Predicate: {}", predicate.name)
241                    .expect("writing to String is infallible");
242            }
243            writeln!(code).expect("writing to String is infallible");
244            writeln!(code, "Arity: {}", predicate.arg_domains.len())
245                .expect("writing to String is infallible");
246            writeln!(code, "\"\"\"").expect("writing to String is infallible");
247        }
248
249        if self.use_dataclasses {
250            writeln!(code, "@dataclass(frozen=True)").expect("writing to String is infallible");
251        }
252
253        writeln!(code, "class {}:", class_name).expect("writing to String is infallible");
254
255        if self.include_docs && predicate.description.is_none() {
256            writeln!(code, "    \"\"\"{}\"\"\"", predicate.name)
257                .expect("writing to String is infallible");
258        }
259
260        // Add fields
261        for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
262            let field_name = format!("arg{}", i);
263            let field_type = Self::to_python_class_name(domain_name);
264            writeln!(code, "    {}: {}", field_name, field_type)
265                .expect("writing to String is infallible");
266        }
267
268        if predicate.arg_domains.is_empty() {
269            writeln!(code, "    pass").expect("writing to String is infallible");
270        }
271    }
272
273    /// Generate PyO3 class for a domain.
274    fn generate_domain_pyo3(&self, code: &mut String, domain: &DomainInfo) {
275        let type_name = Self::to_python_class_name(&domain.name);
276
277        writeln!(code, "#[pyclass]").expect("writing to String is infallible");
278        writeln!(code, "#[derive(Clone, Copy, Debug)]").expect("writing to String is infallible");
279        writeln!(code, "pub struct {} {{", type_name).expect("writing to String is infallible");
280        writeln!(code, "    #[pyo3(get)]").expect("writing to String is infallible");
281        writeln!(code, "    pub id: usize,").expect("writing to String is infallible");
282        writeln!(code, "}}").expect("writing to String is infallible");
283        writeln!(code).expect("writing to String is infallible");
284
285        writeln!(code, "#[pymethods]").expect("writing to String is infallible");
286        writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
287
288        // Constructor
289        writeln!(code, "    #[new]").expect("writing to String is infallible");
290        writeln!(code, "    pub fn new(id: usize) -> PyResult<Self> {{")
291            .expect("writing to String is infallible");
292        writeln!(code, "        if id >= {} {{", domain.cardinality)
293            .expect("writing to String is infallible");
294        writeln!(
295            code,
296            "            return Err(pyo3::exceptions::PyValueError::new_err("
297        )
298        .expect("writing to String is infallible");
299        writeln!(
300            code,
301            "                format!(\"ID {{}} exceeds cardinality {}\", id)",
302            domain.cardinality
303        )
304        .expect("writing to String is infallible");
305        writeln!(code, "            ));").expect("writing to String is infallible");
306        writeln!(code, "        }}").expect("writing to String is infallible");
307        writeln!(code, "        Ok(Self {{ id }})").expect("writing to String is infallible");
308        writeln!(code, "    }}").expect("writing to String is infallible");
309        writeln!(code).expect("writing to String is infallible");
310
311        // String representation
312        writeln!(code, "    fn __repr__(&self) -> String {{")
313            .expect("writing to String is infallible");
314        writeln!(code, "        format!(\"{}({{}})\", self.id)", type_name)
315            .expect("writing to String is infallible");
316        writeln!(code, "    }}").expect("writing to String is infallible");
317
318        writeln!(code, "}}").expect("writing to String is infallible");
319    }
320
321    /// Generate PyO3 class for a predicate.
322    fn generate_predicate_pyo3(&self, code: &mut String, predicate: &PredicateInfo) {
323        let type_name = Self::to_python_class_name(&predicate.name);
324
325        writeln!(code, "#[pyclass]").expect("writing to String is infallible");
326        writeln!(code, "#[derive(Clone, Debug)]").expect("writing to String is infallible");
327        writeln!(code, "pub struct {} {{", type_name).expect("writing to String is infallible");
328
329        for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
330            let field_type = Self::to_python_class_name(domain_name);
331            writeln!(code, "    #[pyo3(get)]").expect("writing to String is infallible");
332            writeln!(code, "    pub arg{}: {},", i, field_type)
333                .expect("writing to String is infallible");
334        }
335
336        writeln!(code, "}}").expect("writing to String is infallible");
337        writeln!(code).expect("writing to String is infallible");
338
339        writeln!(code, "#[pymethods]").expect("writing to String is infallible");
340        writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
341
342        // Constructor
343        writeln!(code, "    #[new]").expect("writing to String is infallible");
344        write!(code, "    pub fn new(").expect("writing to String is infallible");
345        for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
346            if i > 0 {
347                write!(code, ", ").expect("writing to String is infallible");
348            }
349            write!(
350                code,
351                "arg{}: {}",
352                i,
353                Self::to_python_class_name(domain_name)
354            )
355            .expect("writing to String is infallible");
356        }
357        writeln!(code, ") -> Self {{").expect("writing to String is infallible");
358
359        if predicate.arg_domains.is_empty() {
360            writeln!(code, "        Self {{}}").expect("writing to String is infallible");
361        } else {
362            write!(code, "        Self {{ ").expect("writing to String is infallible");
363            for i in 0..predicate.arg_domains.len() {
364                if i > 0 {
365                    write!(code, ", ").expect("writing to String is infallible");
366                }
367                write!(code, "arg{}", i).expect("writing to String is infallible");
368            }
369            writeln!(code, " }}").expect("writing to String is infallible");
370        }
371        writeln!(code, "    }}").expect("writing to String is infallible");
372
373        writeln!(code, "}}").expect("writing to String is infallible");
374    }
375
376    /// Generate module registration for PyO3.
377    fn generate_module_registration(&self, code: &mut String, table: &SymbolTable) {
378        writeln!(code, "#[pymodule]").expect("writing to String is infallible");
379        writeln!(
380            code,
381            "fn {}(_py: Python, m: &PyModule) -> PyResult<()> {{",
382            self.module_name.replace('-', "_")
383        )
384        .expect("writing to String is infallible");
385
386        // Register domain classes
387        for domain in table.domains.values() {
388            let type_name = Self::to_python_class_name(&domain.name);
389            writeln!(code, "    m.add_class::<{}>()?;", type_name)
390                .expect("writing to String is infallible");
391        }
392
393        // Register predicate classes
394        for predicate in table.predicates.values() {
395            let type_name = Self::to_python_class_name(&predicate.name);
396            writeln!(code, "    m.add_class::<{}>()?;", type_name)
397                .expect("writing to String is infallible");
398        }
399
400        writeln!(code, "    Ok(())").expect("writing to String is infallible");
401        writeln!(code, "}}").expect("writing to String is infallible");
402    }
403
404    /// Generate schema metadata stub.
405    fn generate_schema_metadata_stub(&self, code: &mut String, table: &SymbolTable) {
406        writeln!(code, "class SchemaMetadata:").expect("writing to String is infallible");
407        if self.include_docs {
408            writeln!(code, "    \"\"\"Schema metadata and statistics\"\"\"")
409                .expect("writing to String is infallible");
410        }
411        writeln!(
412            code,
413            "    DOMAIN_COUNT: Final[int] = {}",
414            table.domains.len()
415        )
416        .expect("writing to String is infallible");
417        writeln!(
418            code,
419            "    PREDICATE_COUNT: Final[int] = {}",
420            table.predicates.len()
421        )
422        .expect("writing to String is infallible");
423
424        let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();
425        writeln!(code, "    TOTAL_CARDINALITY: Final[int] = {}", total_card)
426            .expect("writing to String is infallible");
427    }
428
429    /// Convert a name to Python class name (PascalCase).
430    fn to_python_class_name(name: &str) -> String {
431        RustCodegen::to_type_name(name) // Reuse Rust converter
432    }
433}