1use std::fmt::Write as FmtWrite;
2
3use crate::{DomainInfo, PredicateInfo, SymbolTable};
4
5use super::rust::RustCodegen;
6
/// Generates Python-facing code from a TensorLogic [`SymbolTable`].
///
/// Depending on configuration, the output is either Python source with typed stubs
/// or Rust source with PyO3 bindings.
#[derive(Debug, Clone)]
pub struct PythonCodegen {
    /// Name of the generated module.
    ///
    /// Appears in file headers and, sanitized, as the `#[pymodule]` function name.
    module_name: String,
    /// When `true`, emit PyO3 Rust bindings instead of Python type stubs.
    generate_pyo3: bool,
    /// When `true`, emit docstrings in the Python stubs.
    include_docs: bool,
    /// When `true`, decorate predicate classes with `@dataclass(frozen=True)` in the stubs.
    use_dataclasses: bool,
}
21
22impl PythonCodegen {
23 pub fn new(module_name: impl Into<String>) -> Self {
25 Self {
26 module_name: module_name.into(),
27 generate_pyo3: false,
28 include_docs: true,
29 use_dataclasses: true,
30 }
31 }
32
33 pub fn with_pyo3(mut self, enable: bool) -> Self {
35 self.generate_pyo3 = enable;
36 self
37 }
38
39 pub fn with_docs(mut self, enable: bool) -> Self {
41 self.include_docs = enable;
42 self
43 }
44
45 pub fn with_dataclasses(mut self, enable: bool) -> Self {
47 self.use_dataclasses = enable;
48 self
49 }
50
51 pub fn generate(&self, table: &SymbolTable) -> String {
53 if self.generate_pyo3 {
54 self.generate_pyo3_bindings(table)
55 } else {
56 self.generate_type_stubs(table)
57 }
58 }
59
60 fn generate_type_stubs(&self, table: &SymbolTable) -> String {
62 let mut code = String::new();
63
64 writeln!(code, "\"\"\"").expect("writing to String is infallible");
66 writeln!(code, "Generated from TensorLogic schema")
67 .expect("writing to String is infallible");
68 writeln!(code, "Module: {}", self.module_name).expect("writing to String is infallible");
69 writeln!(code).expect("writing to String is infallible");
70 writeln!(code, "This code was automatically generated.")
71 .expect("writing to String is infallible");
72 writeln!(code, "DO NOT EDIT MANUALLY.").expect("writing to String is infallible");
73 writeln!(code, "\"\"\"").expect("writing to String is infallible");
74 writeln!(code).expect("writing to String is infallible");
75
76 writeln!(code, "from typing import NewType, Final")
78 .expect("writing to String is infallible");
79 if self.use_dataclasses {
80 writeln!(code, "from dataclasses import dataclass")
81 .expect("writing to String is infallible");
82 }
83 writeln!(code).expect("writing to String is infallible");
84
85 writeln!(code, "# ==========================================")
87 .expect("writing to String is infallible");
88 writeln!(code, "# Domain Types").expect("writing to String is infallible");
89 writeln!(code, "# ==========================================")
90 .expect("writing to String is infallible");
91 writeln!(code).expect("writing to String is infallible");
92
93 for domain in table.domains.values() {
94 self.generate_domain_stub(&mut code, domain);
95 writeln!(code).expect("writing to String is infallible");
96 }
97
98 writeln!(code, "# ==========================================")
100 .expect("writing to String is infallible");
101 writeln!(code, "# Predicate Types").expect("writing to String is infallible");
102 writeln!(code, "# ==========================================")
103 .expect("writing to String is infallible");
104 writeln!(code).expect("writing to String is infallible");
105
106 for predicate in table.predicates.values() {
107 self.generate_predicate_stub(&mut code, predicate, table);
108 writeln!(code).expect("writing to String is infallible");
109 }
110
111 writeln!(code, "# ==========================================")
113 .expect("writing to String is infallible");
114 writeln!(code, "# Schema Metadata").expect("writing to String is infallible");
115 writeln!(code, "# ==========================================")
116 .expect("writing to String is infallible");
117 writeln!(code).expect("writing to String is infallible");
118 self.generate_schema_metadata_stub(&mut code, table);
119
120 code
121 }
122
123 fn generate_pyo3_bindings(&self, table: &SymbolTable) -> String {
125 let mut code = String::new();
126
127 writeln!(code, "//! PyO3 bindings for TensorLogic schema")
129 .expect("writing to String is infallible");
130 writeln!(code, "//! Module: {}", self.module_name)
131 .expect("writing to String is infallible");
132 writeln!(code, "//!").expect("writing to String is infallible");
133 writeln!(code, "//! This code was automatically generated.")
134 .expect("writing to String is infallible");
135 writeln!(code, "//! DO NOT EDIT MANUALLY.").expect("writing to String is infallible");
136 writeln!(code).expect("writing to String is infallible");
137
138 writeln!(code, "use pyo3::prelude::*;").expect("writing to String is infallible");
139 writeln!(code).expect("writing to String is infallible");
140
141 writeln!(code, "// ==========================================")
143 .expect("writing to String is infallible");
144 writeln!(code, "// Domain Types").expect("writing to String is infallible");
145 writeln!(code, "// ==========================================")
146 .expect("writing to String is infallible");
147 writeln!(code).expect("writing to String is infallible");
148
149 for domain in table.domains.values() {
150 self.generate_domain_pyo3(&mut code, domain);
151 writeln!(code).expect("writing to String is infallible");
152 }
153
154 writeln!(code, "// ==========================================")
156 .expect("writing to String is infallible");
157 writeln!(code, "// Predicate Types").expect("writing to String is infallible");
158 writeln!(code, "// ==========================================")
159 .expect("writing to String is infallible");
160 writeln!(code).expect("writing to String is infallible");
161
162 for predicate in table.predicates.values() {
163 self.generate_predicate_pyo3(&mut code, predicate);
164 writeln!(code).expect("writing to String is infallible");
165 }
166
167 writeln!(code, "// ==========================================")
169 .expect("writing to String is infallible");
170 writeln!(code, "// Module Registration").expect("writing to String is infallible");
171 writeln!(code, "// ==========================================")
172 .expect("writing to String is infallible");
173 writeln!(code).expect("writing to String is infallible");
174 self.generate_module_registration(&mut code, table);
175
176 code
177 }
178
179 fn generate_domain_stub(&self, code: &mut String, domain: &DomainInfo) {
181 let type_name = Self::to_python_class_name(&domain.name);
182
183 writeln!(code, "{} = NewType('{}', int)", type_name, type_name)
185 .expect("writing to String is infallible");
186 writeln!(code).expect("writing to String is infallible");
187
188 writeln!(
190 code,
191 "{}_CARDINALITY: Final[int] = {}",
192 domain.name.to_uppercase(),
193 domain.cardinality
194 )
195 .expect("writing to String is infallible");
196 writeln!(code).expect("writing to String is infallible");
197
198 writeln!(code, "def is_valid_{}(id: int) -> bool:", domain.name)
200 .expect("writing to String is infallible");
201 if self.include_docs {
202 writeln!(code, " \"\"\"").expect("writing to String is infallible");
203 if let Some(ref desc) = domain.description {
204 writeln!(code, " {}", desc).expect("writing to String is infallible");
205 writeln!(code).expect("writing to String is infallible");
206 }
207 writeln!(code, " Validate {} ID.", type_name)
208 .expect("writing to String is infallible");
209 writeln!(code).expect("writing to String is infallible");
210 writeln!(code, " Args:").expect("writing to String is infallible");
211 writeln!(code, " id: The ID to validate")
212 .expect("writing to String is infallible");
213 writeln!(code).expect("writing to String is infallible");
214 writeln!(code, " Returns:").expect("writing to String is infallible");
215 writeln!(
216 code,
217 " True if id is in range [0, {}), False otherwise",
218 domain.cardinality
219 )
220 .expect("writing to String is infallible");
221 writeln!(code, " \"\"\"").expect("writing to String is infallible");
222 }
223 writeln!(code, " ...").expect("writing to String is infallible");
224 }
225
226 fn generate_predicate_stub(
228 &self,
229 code: &mut String,
230 predicate: &PredicateInfo,
231 _table: &SymbolTable,
232 ) {
233 let class_name = Self::to_python_class_name(&predicate.name);
234
235 if self.include_docs {
236 writeln!(code, "\"\"\"").expect("writing to String is infallible");
237 if let Some(ref desc) = predicate.description {
238 writeln!(code, "{}", desc).expect("writing to String is infallible");
239 } else {
240 writeln!(code, "Predicate: {}", predicate.name)
241 .expect("writing to String is infallible");
242 }
243 writeln!(code).expect("writing to String is infallible");
244 writeln!(code, "Arity: {}", predicate.arg_domains.len())
245 .expect("writing to String is infallible");
246 writeln!(code, "\"\"\"").expect("writing to String is infallible");
247 }
248
249 if self.use_dataclasses {
250 writeln!(code, "@dataclass(frozen=True)").expect("writing to String is infallible");
251 }
252
253 writeln!(code, "class {}:", class_name).expect("writing to String is infallible");
254
255 if self.include_docs && predicate.description.is_none() {
256 writeln!(code, " \"\"\"{}\"\"\"", predicate.name)
257 .expect("writing to String is infallible");
258 }
259
260 for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
262 let field_name = format!("arg{}", i);
263 let field_type = Self::to_python_class_name(domain_name);
264 writeln!(code, " {}: {}", field_name, field_type)
265 .expect("writing to String is infallible");
266 }
267
268 if predicate.arg_domains.is_empty() {
269 writeln!(code, " pass").expect("writing to String is infallible");
270 }
271 }
272
273 fn generate_domain_pyo3(&self, code: &mut String, domain: &DomainInfo) {
275 let type_name = Self::to_python_class_name(&domain.name);
276
277 writeln!(code, "#[pyclass]").expect("writing to String is infallible");
278 writeln!(code, "#[derive(Clone, Copy, Debug)]").expect("writing to String is infallible");
279 writeln!(code, "pub struct {} {{", type_name).expect("writing to String is infallible");
280 writeln!(code, " #[pyo3(get)]").expect("writing to String is infallible");
281 writeln!(code, " pub id: usize,").expect("writing to String is infallible");
282 writeln!(code, "}}").expect("writing to String is infallible");
283 writeln!(code).expect("writing to String is infallible");
284
285 writeln!(code, "#[pymethods]").expect("writing to String is infallible");
286 writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
287
288 writeln!(code, " #[new]").expect("writing to String is infallible");
290 writeln!(code, " pub fn new(id: usize) -> PyResult<Self> {{")
291 .expect("writing to String is infallible");
292 writeln!(code, " if id >= {} {{", domain.cardinality)
293 .expect("writing to String is infallible");
294 writeln!(
295 code,
296 " return Err(pyo3::exceptions::PyValueError::new_err("
297 )
298 .expect("writing to String is infallible");
299 writeln!(
300 code,
301 " format!(\"ID {{}} exceeds cardinality {}\", id)",
302 domain.cardinality
303 )
304 .expect("writing to String is infallible");
305 writeln!(code, " ));").expect("writing to String is infallible");
306 writeln!(code, " }}").expect("writing to String is infallible");
307 writeln!(code, " Ok(Self {{ id }})").expect("writing to String is infallible");
308 writeln!(code, " }}").expect("writing to String is infallible");
309 writeln!(code).expect("writing to String is infallible");
310
311 writeln!(code, " fn __repr__(&self) -> String {{")
313 .expect("writing to String is infallible");
314 writeln!(code, " format!(\"{}({{}})\", self.id)", type_name)
315 .expect("writing to String is infallible");
316 writeln!(code, " }}").expect("writing to String is infallible");
317
318 writeln!(code, "}}").expect("writing to String is infallible");
319 }
320
321 fn generate_predicate_pyo3(&self, code: &mut String, predicate: &PredicateInfo) {
323 let type_name = Self::to_python_class_name(&predicate.name);
324
325 writeln!(code, "#[pyclass]").expect("writing to String is infallible");
326 writeln!(code, "#[derive(Clone, Debug)]").expect("writing to String is infallible");
327 writeln!(code, "pub struct {} {{", type_name).expect("writing to String is infallible");
328
329 for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
330 let field_type = Self::to_python_class_name(domain_name);
331 writeln!(code, " #[pyo3(get)]").expect("writing to String is infallible");
332 writeln!(code, " pub arg{}: {},", i, field_type)
333 .expect("writing to String is infallible");
334 }
335
336 writeln!(code, "}}").expect("writing to String is infallible");
337 writeln!(code).expect("writing to String is infallible");
338
339 writeln!(code, "#[pymethods]").expect("writing to String is infallible");
340 writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
341
342 writeln!(code, " #[new]").expect("writing to String is infallible");
344 write!(code, " pub fn new(").expect("writing to String is infallible");
345 for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
346 if i > 0 {
347 write!(code, ", ").expect("writing to String is infallible");
348 }
349 write!(
350 code,
351 "arg{}: {}",
352 i,
353 Self::to_python_class_name(domain_name)
354 )
355 .expect("writing to String is infallible");
356 }
357 writeln!(code, ") -> Self {{").expect("writing to String is infallible");
358
359 if predicate.arg_domains.is_empty() {
360 writeln!(code, " Self {{}}").expect("writing to String is infallible");
361 } else {
362 write!(code, " Self {{ ").expect("writing to String is infallible");
363 for i in 0..predicate.arg_domains.len() {
364 if i > 0 {
365 write!(code, ", ").expect("writing to String is infallible");
366 }
367 write!(code, "arg{}", i).expect("writing to String is infallible");
368 }
369 writeln!(code, " }}").expect("writing to String is infallible");
370 }
371 writeln!(code, " }}").expect("writing to String is infallible");
372
373 writeln!(code, "}}").expect("writing to String is infallible");
374 }
375
376 fn generate_module_registration(&self, code: &mut String, table: &SymbolTable) {
378 writeln!(code, "#[pymodule]").expect("writing to String is infallible");
379 writeln!(
380 code,
381 "fn {}(_py: Python, m: &PyModule) -> PyResult<()> {{",
382 self.module_name.replace('-', "_")
383 )
384 .expect("writing to String is infallible");
385
386 for domain in table.domains.values() {
388 let type_name = Self::to_python_class_name(&domain.name);
389 writeln!(code, " m.add_class::<{}>()?;", type_name)
390 .expect("writing to String is infallible");
391 }
392
393 for predicate in table.predicates.values() {
395 let type_name = Self::to_python_class_name(&predicate.name);
396 writeln!(code, " m.add_class::<{}>()?;", type_name)
397 .expect("writing to String is infallible");
398 }
399
400 writeln!(code, " Ok(())").expect("writing to String is infallible");
401 writeln!(code, "}}").expect("writing to String is infallible");
402 }
403
404 fn generate_schema_metadata_stub(&self, code: &mut String, table: &SymbolTable) {
406 writeln!(code, "class SchemaMetadata:").expect("writing to String is infallible");
407 if self.include_docs {
408 writeln!(code, " \"\"\"Schema metadata and statistics\"\"\"")
409 .expect("writing to String is infallible");
410 }
411 writeln!(
412 code,
413 " DOMAIN_COUNT: Final[int] = {}",
414 table.domains.len()
415 )
416 .expect("writing to String is infallible");
417 writeln!(
418 code,
419 " PREDICATE_COUNT: Final[int] = {}",
420 table.predicates.len()
421 )
422 .expect("writing to String is infallible");
423
424 let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();
425 writeln!(code, " TOTAL_CARDINALITY: Final[int] = {}", total_card)
426 .expect("writing to String is infallible");
427 }
428
429 fn to_python_class_name(name: &str) -> String {
431 RustCodegen::to_type_name(name) }
433}