Skip to main content

tensorlogic_adapters/codegen/
rust.rs

1use std::fmt::Write as FmtWrite;
2
3use crate::{DomainInfo, PredicateInfo, SymbolTable};
4
/// Code generator that renders Rust type definitions from a schema.
///
/// Configure it with the builder-style methods, then call `generate` to
/// produce the source text of a module.
#[derive(Debug, Clone)]
pub struct RustCodegen {
    /// Module name written into the generated file header.
    module_name: String,
    /// Whether `#[derive(...)]` attributes are emitted on generated types.
    derive_common: bool,
    /// Whether description/arity doc comments are emitted on generated types.
    include_docs: bool,
}
14
15impl RustCodegen {
16    /// Create a new Rust code generator.
17    pub fn new(module_name: impl Into<String>) -> Self {
18        Self {
19            module_name: module_name.into(),
20            derive_common: true,
21            include_docs: true,
22        }
23    }
24
25    /// Set whether to derive common traits (Clone, Debug, etc.).
26    pub fn with_common_derives(mut self, enable: bool) -> Self {
27        self.derive_common = enable;
28        self
29    }
30
31    /// Set whether to include documentation comments.
32    pub fn with_docs(mut self, enable: bool) -> Self {
33        self.include_docs = enable;
34        self
35    }
36
37    /// Generate complete Rust module from a symbol table.
38    pub fn generate(&self, table: &SymbolTable) -> String {
39        let mut code = String::new();
40
41        // Module header
42        writeln!(code, "//! Generated from TensorLogic schema.")
43            .expect("writing to String is infallible");
44        writeln!(code, "//! Module: {}", self.module_name)
45            .expect("writing to String is infallible");
46        writeln!(code, "//!").expect("writing to String is infallible");
47        writeln!(code, "//! This code was automatically generated.")
48            .expect("writing to String is infallible");
49        writeln!(code, "//! DO NOT EDIT MANUALLY.").expect("writing to String is infallible");
50        writeln!(code).expect("writing to String is infallible");
51
52        // Use statements
53        writeln!(code, "#![allow(dead_code)]").expect("writing to String is infallible");
54        writeln!(code).expect("writing to String is infallible");
55
56        // Generate domain types
57        writeln!(code, "// ============================================")
58            .expect("writing to String is infallible");
59        writeln!(code, "// Domain Types").expect("writing to String is infallible");
60        writeln!(code, "// ============================================")
61            .expect("writing to String is infallible");
62        writeln!(code).expect("writing to String is infallible");
63
64        for domain in table.domains.values() {
65            self.generate_domain(&mut code, domain);
66            writeln!(code).expect("writing to String is infallible");
67        }
68
69        // Generate predicate types
70        writeln!(code, "// ============================================")
71            .expect("writing to String is infallible");
72        writeln!(code, "// Predicate Types").expect("writing to String is infallible");
73        writeln!(code, "// ============================================")
74            .expect("writing to String is infallible");
75        writeln!(code).expect("writing to String is infallible");
76
77        for predicate in table.predicates.values() {
78            self.generate_predicate(&mut code, predicate, table);
79            writeln!(code).expect("writing to String is infallible");
80        }
81
82        // Generate schema metadata type
83        writeln!(code, "// ============================================")
84            .expect("writing to String is infallible");
85        writeln!(code, "// Schema Metadata").expect("writing to String is infallible");
86        writeln!(code, "// ============================================")
87            .expect("writing to String is infallible");
88        writeln!(code).expect("writing to String is infallible");
89        self.generate_schema_metadata(&mut code, table);
90
91        code
92    }
93
94    /// Generate domain type.
95    fn generate_domain(&self, code: &mut String, domain: &DomainInfo) {
96        if self.include_docs {
97            if let Some(ref desc) = domain.description {
98                writeln!(code, "/// {}", desc).expect("writing to String is infallible");
99            } else {
100                writeln!(code, "/// Domain: {}", domain.name)
101                    .expect("writing to String is infallible");
102            }
103            writeln!(code, "///").expect("writing to String is infallible");
104            writeln!(code, "/// Cardinality: {}", domain.cardinality)
105                .expect("writing to String is infallible");
106        }
107
108        if self.derive_common {
109            writeln!(code, "#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]")
110                .expect("writing to String is infallible");
111        }
112
113        let type_name = Self::to_type_name(&domain.name);
114        writeln!(code, "pub struct {}(pub usize);", type_name)
115            .expect("writing to String is infallible");
116        writeln!(code).expect("writing to String is infallible");
117
118        // Generate constructor and accessors
119        writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
120        writeln!(
121            code,
122            "    /// Maximum valid ID for this domain (exclusive)."
123        )
124        .expect("writing to String is infallible");
125        writeln!(
126            code,
127            "    pub const CARDINALITY: usize = {};",
128            domain.cardinality
129        )
130        .expect("writing to String is infallible");
131        writeln!(code).expect("writing to String is infallible");
132
133        writeln!(code, "    /// Create a new {} instance.", type_name)
134            .expect("writing to String is infallible");
135        writeln!(code, "    ///").expect("writing to String is infallible");
136        writeln!(code, "    /// # Panics").expect("writing to String is infallible");
137        writeln!(code, "    ///").expect("writing to String is infallible");
138        writeln!(code, "    /// Panics if `id >= {}`.", domain.cardinality)
139            .expect("writing to String is infallible");
140        writeln!(code, "    pub fn new(id: usize) -> Self {{")
141            .expect("writing to String is infallible");
142        writeln!(code, "        assert!(id < Self::CARDINALITY, \"ID {{}} exceeds cardinality {{}}\", id, Self::CARDINALITY);", ).expect("writing to String is infallible");
143        writeln!(code, "        Self(id)").expect("writing to String is infallible");
144        writeln!(code, "    }}").expect("writing to String is infallible");
145        writeln!(code).expect("writing to String is infallible");
146
147        writeln!(
148            code,
149            "    /// Create a new {} instance without bounds checking.",
150            type_name
151        )
152        .expect("writing to String is infallible");
153        writeln!(code, "    ///").expect("writing to String is infallible");
154        writeln!(code, "    /// # Safety").expect("writing to String is infallible");
155        writeln!(code, "    ///").expect("writing to String is infallible");
156        writeln!(
157            code,
158            "    /// Caller must ensure `id < {}`.",
159            domain.cardinality
160        )
161        .expect("writing to String is infallible");
162        writeln!(
163            code,
164            "    pub unsafe fn new_unchecked(id: usize) -> Self {{"
165        )
166        .expect("writing to String is infallible");
167        writeln!(code, "        Self(id)").expect("writing to String is infallible");
168        writeln!(code, "    }}").expect("writing to String is infallible");
169        writeln!(code).expect("writing to String is infallible");
170
171        writeln!(code, "    /// Get the underlying ID.").expect("writing to String is infallible");
172        writeln!(code, "    pub fn id(&self) -> usize {{")
173            .expect("writing to String is infallible");
174        writeln!(code, "        self.0").expect("writing to String is infallible");
175        writeln!(code, "    }}").expect("writing to String is infallible");
176
177        writeln!(code, "}}").expect("writing to String is infallible");
178    }
179
180    /// Generate predicate type.
181    fn generate_predicate(
182        &self,
183        code: &mut String,
184        predicate: &PredicateInfo,
185        _table: &SymbolTable,
186    ) {
187        if self.include_docs {
188            if let Some(ref desc) = predicate.description {
189                writeln!(code, "/// {}", desc).expect("writing to String is infallible");
190            } else {
191                writeln!(code, "/// Predicate: {}", predicate.name)
192                    .expect("writing to String is infallible");
193            }
194            writeln!(code, "///").expect("writing to String is infallible");
195            writeln!(code, "/// Arity: {}", predicate.arg_domains.len())
196                .expect("writing to String is infallible");
197
198            if let Some(ref constraints) = predicate.constraints {
199                if !constraints.properties.is_empty() {
200                    writeln!(code, "///").expect("writing to String is infallible");
201                    writeln!(code, "/// Properties:").expect("writing to String is infallible");
202                    for prop in &constraints.properties {
203                        writeln!(code, "/// - {:?}", prop)
204                            .expect("writing to String is infallible");
205                    }
206                }
207            }
208        }
209
210        if self.derive_common {
211            writeln!(code, "#[derive(Clone, Debug, PartialEq, Eq, Hash)]")
212                .expect("writing to String is infallible");
213        }
214
215        let type_name = Self::to_type_name(&predicate.name);
216
217        // Generate struct with typed fields
218        if predicate.arg_domains.is_empty() {
219            // Nullary predicate
220            writeln!(code, "pub struct {};", type_name).expect("writing to String is infallible");
221        } else if predicate.arg_domains.len() == 1 {
222            // Unary predicate
223            let domain_type = Self::to_type_name(&predicate.arg_domains[0]);
224            writeln!(code, "pub struct {}(pub {});", type_name, domain_type)
225                .expect("writing to String is infallible");
226        } else {
227            // N-ary predicate - use tuple struct
228            write!(code, "pub struct {}(", type_name).expect("writing to String is infallible");
229            for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
230                if i > 0 {
231                    write!(code, ", ").expect("writing to String is infallible");
232                }
233                write!(code, "pub {}", Self::to_type_name(domain_name))
234                    .expect("writing to String is infallible");
235            }
236            writeln!(code, ");").expect("writing to String is infallible");
237        }
238
239        writeln!(code).expect("writing to String is infallible");
240
241        // Generate constructor and accessors
242        writeln!(code, "impl {} {{", type_name).expect("writing to String is infallible");
243
244        if !predicate.arg_domains.is_empty() {
245            // Constructor
246            writeln!(code, "    /// Create a new {} instance.", type_name)
247                .expect("writing to String is infallible");
248            write!(code, "    pub fn new(").expect("writing to String is infallible");
249            for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
250                if i > 0 {
251                    write!(code, ", ").expect("writing to String is infallible");
252                }
253                write!(code, "arg{}: {}", i, Self::to_type_name(domain_name))
254                    .expect("writing to String is infallible");
255            }
256            writeln!(code, ") -> Self {{").expect("writing to String is infallible");
257
258            if predicate.arg_domains.len() == 1 {
259                writeln!(code, "        Self(arg0)").expect("writing to String is infallible");
260            } else {
261                write!(code, "        Self(").expect("writing to String is infallible");
262                for i in 0..predicate.arg_domains.len() {
263                    if i > 0 {
264                        write!(code, ", ").expect("writing to String is infallible");
265                    }
266                    write!(code, "arg{}", i).expect("writing to String is infallible");
267                }
268                writeln!(code, ")").expect("writing to String is infallible");
269            }
270            writeln!(code, "    }}").expect("writing to String is infallible");
271            writeln!(code).expect("writing to String is infallible");
272
273            // Accessor methods
274            for (i, domain_name) in predicate.arg_domains.iter().enumerate() {
275                writeln!(code, "    /// Get argument {}.", i)
276                    .expect("writing to String is infallible");
277                writeln!(
278                    code,
279                    "    pub fn arg{}(&self) -> {} {{",
280                    i,
281                    Self::to_type_name(domain_name)
282                )
283                .expect("writing to String is infallible");
284                if predicate.arg_domains.len() == 1 {
285                    writeln!(code, "        self.0").expect("writing to String is infallible");
286                } else {
287                    writeln!(code, "        self.{}", i).expect("writing to String is infallible");
288                }
289                writeln!(code, "    }}").expect("writing to String is infallible");
290                writeln!(code).expect("writing to String is infallible");
291            }
292        }
293
294        writeln!(code, "}}").expect("writing to String is infallible");
295    }
296
297    /// Generate schema metadata type.
298    fn generate_schema_metadata(&self, code: &mut String, table: &SymbolTable) {
299        writeln!(code, "/// Schema metadata and statistics.")
300            .expect("writing to String is infallible");
301        writeln!(code, "pub struct SchemaMetadata;").expect("writing to String is infallible");
302        writeln!(code).expect("writing to String is infallible");
303
304        writeln!(code, "impl SchemaMetadata {{").expect("writing to String is infallible");
305        writeln!(code, "    /// Number of domains in the schema.")
306            .expect("writing to String is infallible");
307        writeln!(
308            code,
309            "    pub const DOMAIN_COUNT: usize = {};",
310            table.domains.len()
311        )
312        .expect("writing to String is infallible");
313        writeln!(code).expect("writing to String is infallible");
314
315        writeln!(code, "    /// Number of predicates in the schema.")
316            .expect("writing to String is infallible");
317        writeln!(
318            code,
319            "    pub const PREDICATE_COUNT: usize = {};",
320            table.predicates.len()
321        )
322        .expect("writing to String is infallible");
323        writeln!(code).expect("writing to String is infallible");
324
325        writeln!(code, "    /// Total cardinality across all domains.")
326            .expect("writing to String is infallible");
327        let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();
328        writeln!(
329            code,
330            "    pub const TOTAL_CARDINALITY: usize = {};",
331            total_card
332        )
333        .expect("writing to String is infallible");
334
335        writeln!(code, "}}").expect("writing to String is infallible");
336    }
337
338    /// Convert a domain/predicate name to a Rust type name (PascalCase).
339    pub(super) fn to_type_name(name: &str) -> String {
340        // Simple conversion: capitalize first letter of each word
341        name.split('_')
342            .map(|word| {
343                let mut chars = word.chars();
344                match chars.next() {
345                    None => String::new(),
346                    Some(first) => first.to_uppercase().chain(chars).collect(),
347                }
348            })
349            .collect()
350    }
351}