Skip to main content

tensorlogic_compiler/
context.rs

1//! Compiler context for tracking domains, variables, and axes.
2//!
3//! The [`CompilerContext`] is the central state manager for the compilation process.
4//! It tracks domain information, variable-to-domain bindings, axis assignments,
5//! and manages temporary tensor names.
6
7use anyhow::{bail, Result};
8use std::collections::HashMap;
9
10use crate::config::CompilationConfig;
11
12// Re-export DomainInfo from adapters for backward compatibility
13pub use tensorlogic_adapters::DomainInfo;
14
/// Compiler context for managing compilation state.
///
/// The `CompilerContext` tracks all stateful information needed during compilation:
/// - Domain definitions and their cardinalities
/// - Variable-to-domain bindings
/// - Variable-to-axis assignments (for einsum notation)
/// - Temporary tensor name generation
/// - Compilation configuration (logic-to-tensor mapping strategies)
/// - Optional SymbolTable integration for schema-driven compilation
///
/// # Lifecycle
///
/// 1. Create a new context with [`CompilerContext::new()`], [`CompilerContext::with_config()`],
///    or [`CompilerContext::from_symbol_table()`]
/// 2. Register domains with [`add_domain`](CompilerContext::add_domain)
/// 3. Optionally bind variables to domains with [`bind_var`](CompilerContext::bind_var)
/// 4. Pass the context to [`compile_to_einsum_with_context`](crate::compile_to_einsum_with_context)
/// 5. Axes are automatically assigned during compilation
///
/// # Examples
///
/// ## Basic Usage
///
/// ```
/// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
///
/// // Use default soft_differentiable strategy
/// let mut ctx = CompilerContext::new();
///
/// // Or use a specific strategy
/// let ctx_fuzzy = CompilerContext::with_config(
///     CompilationConfig::fuzzy_lukasiewicz()
/// );
///
/// // Register domains
/// ctx.add_domain("Person", 100);
/// ctx.add_domain("City", 50);
///
/// // Optionally bind variables (or let the compiler infer)
/// ctx.bind_var("x", "Person").unwrap();
/// ```
///
/// ## Schema-Driven Compilation
///
/// ```
/// use tensorlogic_compiler::CompilerContext;
/// use tensorlogic_adapters::{SymbolTable, DomainInfo};
///
/// // Create a symbol table with schema
/// let mut table = SymbolTable::new();
/// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
///
/// // Create context from symbol table
/// let ctx = CompilerContext::from_symbol_table(&table);
///
/// // Domains are automatically imported
/// assert!(ctx.domains.contains_key("Person"));
/// ```
#[derive(Debug, Clone)]
pub struct CompilerContext {
    /// Registered domains with their metadata, keyed by domain name.
    pub domains: HashMap<String, DomainInfo>,
    /// Variable-to-domain bindings (variable name -> domain name).
    pub var_to_domain: HashMap<String, String>,
    /// Variable-to-axis assignments (e.g., 'x' → 'a', 'y' → 'b').
    pub var_to_axis: HashMap<String, char>,
    /// Next axis character that `assign_axis` will hand out to a variable
    /// that has no axis yet.
    next_axis: char,
    /// Counter for generating unique temporary tensor names (`temp_<n>`).
    temp_counter: usize,
    /// Compilation configuration (strategies for logic operations).
    pub config: CompilationConfig,
    /// Marker recording whether this context was built from a SymbolTable:
    /// `from_symbol_table` sets it to `Some("imported")`, the other
    /// constructors leave it `None`. The table itself is not retained.
    symbol_table_ref: Option<String>,
    /// Let bindings: variable name to tensor index.
    pub let_bindings: HashMap<String, usize>,
}
92
93impl CompilerContext {
94    /// Creates a new, empty compiler context with default configuration.
95    ///
96    /// The context starts with no domains, no variable bindings, axis
97    /// assignment beginning at 'a', and uses the default `soft_differentiable`
98    /// compilation strategy.
99    ///
100    /// # Examples
101    ///
102    /// ```
103    /// use tensorlogic_compiler::CompilerContext;
104    ///
105    /// let ctx = CompilerContext::new();
106    /// assert!(ctx.domains.is_empty());
107    /// ```
108    pub fn new() -> Self {
109        CompilerContext {
110            domains: HashMap::new(),
111            var_to_domain: HashMap::new(),
112            var_to_axis: HashMap::new(),
113            next_axis: 'a',
114            temp_counter: 0,
115            config: CompilationConfig::default(),
116            symbol_table_ref: None,
117            let_bindings: HashMap::new(),
118        }
119    }
120
121    /// Creates a new compiler context with a specific configuration.
122    ///
123    /// Use this to control how logical operations are compiled to tensor operations.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
129    ///
130    /// // Use Łukasiewicz fuzzy logic (satisfies De Morgan's laws)
131    /// let ctx = CompilerContext::with_config(
132    ///     CompilationConfig::fuzzy_lukasiewicz()
133    /// );
134    ///
135    /// // Use hard Boolean logic
136    /// let ctx_bool = CompilerContext::with_config(
137    ///     CompilationConfig::hard_boolean()
138    /// );
139    /// ```
140    pub fn with_config(config: CompilationConfig) -> Self {
141        CompilerContext {
142            domains: HashMap::new(),
143            var_to_domain: HashMap::new(),
144            var_to_axis: HashMap::new(),
145            next_axis: 'a',
146            temp_counter: 0,
147            config,
148            symbol_table_ref: None,
149            let_bindings: HashMap::new(),
150        }
151    }
152
153    /// Creates a compiler context from a SymbolTable for schema-driven compilation.
154    ///
155    /// This constructor automatically imports all domains from the symbol table
156    /// and validates the schema. It enables type-safe compilation with rich
157    /// predicate signatures and domain hierarchies.
158    ///
159    /// # Arguments
160    ///
161    /// * `table` - The symbol table containing domain and predicate definitions
162    ///
163    /// # Examples
164    ///
165    /// ```
166    /// use tensorlogic_compiler::CompilerContext;
167    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo};
168    ///
169    /// let mut table = SymbolTable::new();
170    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
171    /// table.add_predicate(PredicateInfo::new(
172    ///     "knows",
173    ///     vec!["Person".to_string(), "Person".to_string()]
174    /// )).unwrap();
175    ///
176    /// let ctx = CompilerContext::from_symbol_table(&table);
177    ///
178    /// assert_eq!(ctx.domains.len(), 1);
179    /// assert!(ctx.domains.contains_key("Person"));
180    /// ```
181    pub fn from_symbol_table(table: &tensorlogic_adapters::SymbolTable) -> Self {
182        let mut ctx = Self::new();
183
184        // Import all domains from the symbol table
185        for domain in table.domains.values() {
186            ctx.domains.insert(domain.name.clone(), domain.clone());
187        }
188
189        // Import variable bindings if any
190        for (var, domain) in &table.variables {
191            ctx.var_to_domain.insert(var.clone(), domain.clone());
192        }
193
194        ctx.symbol_table_ref = Some("imported".to_string());
195        ctx
196    }
197
198    /// Creates a compiler context from a SymbolTable with a specific configuration.
199    ///
200    /// Combines schema-driven compilation with custom compilation strategies.
201    ///
202    /// # Examples
203    ///
204    /// ```
205    /// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
206    /// use tensorlogic_adapters::{SymbolTable, DomainInfo};
207    ///
208    /// let mut table = SymbolTable::new();
209    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
210    ///
211    /// let ctx = CompilerContext::from_symbol_table_with_config(
212    ///     &table,
213    ///     CompilationConfig::fuzzy_lukasiewicz()
214    /// );
215    /// ```
216    pub fn from_symbol_table_with_config(
217        table: &tensorlogic_adapters::SymbolTable,
218        config: CompilationConfig,
219    ) -> Self {
220        let mut ctx = Self::from_symbol_table(table);
221        ctx.config = config;
222        ctx
223    }
224
225    /// Registers a new domain with its cardinality.
226    ///
227    /// Domains must be registered before they can be used for variable bindings
228    /// or quantifiers. The cardinality determines the size of the tensor dimension
229    /// for variables in this domain.
230    ///
231    /// # Arguments
232    ///
233    /// * `name` - The domain name (e.g., "Person", "City")
234    /// * `cardinality` - The number of possible values in this domain
235    ///
236    /// # Examples
237    ///
238    /// ```
239    /// use tensorlogic_compiler::CompilerContext;
240    ///
241    /// let mut ctx = CompilerContext::new();
242    /// ctx.add_domain("Person", 100);
243    /// ctx.add_domain("City", 50);
244    ///
245    /// assert_eq!(ctx.domains.len(), 2);
246    /// assert_eq!(ctx.domains.get("Person").unwrap().cardinality, 100);
247    /// ```
248    pub fn add_domain(&mut self, name: impl Into<String>, cardinality: usize) {
249        let name = name.into();
250        self.domains
251            .insert(name.clone(), DomainInfo::new(name, cardinality));
252    }
253
254    /// Registers a domain with full metadata.
255    ///
256    /// Use this method when you have a complete DomainInfo instance with
257    /// metadata, descriptions, or parametric types.
258    ///
259    /// # Examples
260    ///
261    /// ```
262    /// use tensorlogic_compiler::CompilerContext;
263    /// use tensorlogic_adapters::DomainInfo;
264    ///
265    /// let mut ctx = CompilerContext::new();
266    /// let domain = DomainInfo::new("Person", 100)
267    ///     .with_description("All persons in the system");
268    ///
269    /// ctx.add_domain_info(domain);
270    ///
271    /// assert!(ctx.domains.get("Person").unwrap().description.is_some());
272    /// ```
273    pub fn add_domain_info(&mut self, domain: DomainInfo) {
274        self.domains.insert(domain.name.clone(), domain);
275    }
276
277    /// Binds a variable to a specific domain.
278    ///
279    /// This is optional - the compiler can often infer domains from quantifiers.
280    /// However, explicit bindings can be useful for type checking and validation.
281    ///
282    /// # Arguments
283    ///
284    /// * `var` - The variable name (e.g., "x", "y")
285    /// * `domain` - The domain name (must be already registered)
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if the specified domain has not been registered.
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// use tensorlogic_compiler::CompilerContext;
295    ///
296    /// let mut ctx = CompilerContext::new();
297    /// ctx.add_domain("Person", 100);
298    ///
299    /// ctx.bind_var("x", "Person").unwrap();
300    /// assert_eq!(ctx.var_to_domain.get("x"), Some(&"Person".to_string()));
301    ///
302    /// // Error: domain not registered
303    /// assert!(ctx.bind_var("y", "Unknown").is_err());
304    /// ```
305    pub fn bind_var(&mut self, var: &str, domain: &str) -> Result<()> {
306        if !self.domains.contains_key(domain) {
307            bail!("Domain '{}' not found", domain);
308        }
309        self.var_to_domain
310            .insert(var.to_string(), domain.to_string());
311        Ok(())
312    }
313
314    /// Assigns an einsum axis to a variable.
315    ///
316    /// Axes are assigned in lexicographic order ('a', 'b', 'c', ...).
317    /// If a variable already has an assigned axis, that axis is returned.
318    /// Otherwise, a new axis is assigned and the counter is incremented.
319    ///
320    /// # Arguments
321    ///
322    /// * `var` - The variable name
323    ///
324    /// # Returns
325    ///
326    /// The axis character assigned to this variable.
327    ///
328    /// # Examples
329    ///
330    /// ```
331    /// use tensorlogic_compiler::CompilerContext;
332    ///
333    /// let mut ctx = CompilerContext::new();
334    ///
335    /// let axis_x = ctx.assign_axis("x");
336    /// assert_eq!(axis_x, 'a');
337    ///
338    /// let axis_y = ctx.assign_axis("y");
339    /// assert_eq!(axis_y, 'b');
340    ///
341    /// // Re-assigning returns the same axis
342    /// let axis_x_again = ctx.assign_axis("x");
343    /// assert_eq!(axis_x_again, 'a');
344    /// ```
345    pub fn assign_axis(&mut self, var: &str) -> char {
346        if let Some(&axis) = self.var_to_axis.get(var) {
347            return axis;
348        }
349        let axis = self.next_axis;
350        self.var_to_axis.insert(var.to_string(), axis);
351        self.next_axis = ((axis as u8) + 1) as char;
352        axis
353    }
354
355    /// Generates a fresh temporary tensor name.
356    ///
357    /// Temporary tensors are used for intermediate results during compilation.
358    /// Names are generated as "temp_0", "temp_1", etc.
359    ///
360    /// # Returns
361    ///
362    /// A unique temporary tensor name.
363    ///
364    /// # Examples
365    ///
366    /// ```
367    /// use tensorlogic_compiler::CompilerContext;
368    ///
369    /// let mut ctx = CompilerContext::new();
370    ///
371    /// let temp1 = ctx.fresh_temp();
372    /// assert_eq!(temp1, "temp_0");
373    ///
374    /// let temp2 = ctx.fresh_temp();
375    /// assert_eq!(temp2, "temp_1");
376    /// ```
377    pub fn fresh_temp(&mut self) -> String {
378        let name = format!("temp_{}", self.temp_counter);
379        self.temp_counter += 1;
380        name
381    }
382
383    /// Gets the einsum axes string for a list of terms.
384    ///
385    /// This is used internally during predicate compilation to determine
386    /// the axes string for a predicate's arguments.
387    ///
388    /// # Arguments
389    ///
390    /// * `terms` - The list of terms (usually predicate arguments)
391    ///
392    /// # Returns
393    ///
394    /// A string of axis characters (e.g., "ab" for two variables)
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if a variable term has not been assigned an axis.
399    ///
400    /// # Examples
401    ///
402    /// ```
403    /// use tensorlogic_compiler::CompilerContext;
404    /// use tensorlogic_ir::Term;
405    ///
406    /// let mut ctx = CompilerContext::new();
407    /// ctx.assign_axis("x");
408    /// ctx.assign_axis("y");
409    ///
410    /// let terms = vec![Term::var("x"), Term::var("y")];
411    /// let axes = ctx.get_axes(&terms).unwrap();
412    /// assert_eq!(axes, "ab");
413    /// ```
414    pub fn get_axes(&self, terms: &[tensorlogic_ir::Term]) -> Result<String> {
415        use anyhow::anyhow;
416        use tensorlogic_ir::Term;
417
418        let mut axes = String::new();
419        for term in terms {
420            if let Term::Var(v) = term {
421                let axis = self
422                    .var_to_axis
423                    .get(v)
424                    .ok_or_else(|| anyhow!("Variable '{}' not assigned an axis", v))?;
425                axes.push(*axis);
426            }
427        }
428        Ok(axes)
429    }
430}
431
432impl Default for CompilerContext {
433    fn default() -> Self {
434        Self::new()
435    }
436}
437
/// Internal state during compilation of a single expression.
///
/// NOTE(review): presumably each compiled sub-expression yields one of these,
/// pairing the tensor that holds its result with that tensor's einsum axes —
/// confirm against the callers.
#[derive(Debug)]
pub(crate) struct CompileState {
    /// Index of the tensor holding this expression's result
    /// (presumably an index into the output graph's tensor list — TODO confirm).
    pub tensor_idx: usize,
    /// Einsum axis labels of that tensor, presumably one character per axis
    /// as produced by `CompilerContext::get_axes` — TODO confirm.
    pub axes: String,
}