tensorlogic_compiler/context.rs
1//! Compiler context for tracking domains, variables, and axes.
2//!
3//! The [`CompilerContext`] is the central state manager for the compilation process.
4//! It tracks domain information, variable-to-domain bindings, axis assignments,
5//! and manages temporary tensor names.
6
7use anyhow::{bail, Result};
8use std::collections::HashMap;
9
10use crate::config::CompilationConfig;
11
12// Re-export DomainInfo from adapters for backward compatibility
13pub use tensorlogic_adapters::DomainInfo;
14
15/// Compiler context for managing compilation state.
16///
17/// The `CompilerContext` tracks all stateful information needed during compilation:
18/// - Domain definitions and their cardinalities
19/// - Variable-to-domain bindings
20/// - Variable-to-axis assignments (for einsum notation)
21/// - Temporary tensor name generation
22/// - Compilation configuration (logic-to-tensor mapping strategies)
23/// - Optional SymbolTable integration for schema-driven compilation
24///
25/// # Lifecycle
26///
27/// 1. Create a new context with [`CompilerContext::new()`], [`CompilerContext::with_config()`],
28/// or [`CompilerContext::from_symbol_table()`]
29/// 2. Register domains with [`add_domain`](CompilerContext::add_domain)
30/// 3. Optionally bind variables to domains with [`bind_var`](CompilerContext::bind_var)
31/// 4. Pass the context to [`compile_to_einsum_with_context`](crate::compile_to_einsum_with_context)
32/// 5. Axes are automatically assigned during compilation
33///
34/// # Examples
35///
36/// ## Basic Usage
37///
38/// ```
39/// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
40///
41/// // Use default soft_differentiable strategy
42/// let mut ctx = CompilerContext::new();
43///
44/// // Or use a specific strategy
45/// let mut ctx_fuzzy = CompilerContext::with_config(
46/// CompilationConfig::fuzzy_lukasiewicz()
47/// );
48///
49/// // Register domains
50/// ctx.add_domain("Person", 100);
51/// ctx.add_domain("City", 50);
52///
53/// // Optionally bind variables (or let the compiler infer)
54/// ctx.bind_var("x", "Person").unwrap();
55/// ```
56///
57/// ## Schema-Driven Compilation
58///
59/// ```
60/// use tensorlogic_compiler::CompilerContext;
61/// use tensorlogic_adapters::{SymbolTable, DomainInfo};
62///
63/// // Create a symbol table with schema
64/// let mut table = SymbolTable::new();
65/// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
66///
67/// // Create context from symbol table
68/// let ctx = CompilerContext::from_symbol_table(&table);
69///
70/// // Domains are automatically imported
71/// assert!(ctx.domains.contains_key("Person"));
72/// ```
73#[derive(Debug, Clone)]
74pub struct CompilerContext {
75 /// Registered domains with their metadata
76 pub domains: HashMap<String, DomainInfo>,
77 /// Variable-to-domain bindings
78 pub var_to_domain: HashMap<String, String>,
79 /// Variable-to-axis assignments (e.g., 'x' → 'a', 'y' → 'b')
80 pub var_to_axis: HashMap<String, char>,
81 /// Next available axis character
82 next_axis: char,
83 /// Counter for generating unique temporary tensor names
84 temp_counter: usize,
85 /// Compilation configuration (strategies for logic operations)
86 pub config: CompilationConfig,
87 /// Optional reference to symbol table for schema-driven compilation
88 symbol_table_ref: Option<String>, // Just a marker for now
89 /// Let bindings: variable name to tensor index
90 pub let_bindings: HashMap<String, usize>,
91}
92
93impl CompilerContext {
94 /// Creates a new, empty compiler context with default configuration.
95 ///
96 /// The context starts with no domains, no variable bindings, axis
97 /// assignment beginning at 'a', and uses the default `soft_differentiable`
98 /// compilation strategy.
99 ///
100 /// # Examples
101 ///
102 /// ```
103 /// use tensorlogic_compiler::CompilerContext;
104 ///
105 /// let ctx = CompilerContext::new();
106 /// assert!(ctx.domains.is_empty());
107 /// ```
108 pub fn new() -> Self {
109 CompilerContext {
110 domains: HashMap::new(),
111 var_to_domain: HashMap::new(),
112 var_to_axis: HashMap::new(),
113 next_axis: 'a',
114 temp_counter: 0,
115 config: CompilationConfig::default(),
116 symbol_table_ref: None,
117 let_bindings: HashMap::new(),
118 }
119 }
120
121 /// Creates a new compiler context with a specific configuration.
122 ///
123 /// Use this to control how logical operations are compiled to tensor operations.
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
129 ///
130 /// // Use Łukasiewicz fuzzy logic (satisfies De Morgan's laws)
131 /// let ctx = CompilerContext::with_config(
132 /// CompilationConfig::fuzzy_lukasiewicz()
133 /// );
134 ///
135 /// // Use hard Boolean logic
136 /// let ctx_bool = CompilerContext::with_config(
137 /// CompilationConfig::hard_boolean()
138 /// );
139 /// ```
140 pub fn with_config(config: CompilationConfig) -> Self {
141 CompilerContext {
142 domains: HashMap::new(),
143 var_to_domain: HashMap::new(),
144 var_to_axis: HashMap::new(),
145 next_axis: 'a',
146 temp_counter: 0,
147 config,
148 symbol_table_ref: None,
149 let_bindings: HashMap::new(),
150 }
151 }
152
153 /// Creates a compiler context from a SymbolTable for schema-driven compilation.
154 ///
155 /// This constructor automatically imports all domains from the symbol table
156 /// and validates the schema. It enables type-safe compilation with rich
157 /// predicate signatures and domain hierarchies.
158 ///
159 /// # Arguments
160 ///
161 /// * `table` - The symbol table containing domain and predicate definitions
162 ///
163 /// # Examples
164 ///
165 /// ```
166 /// use tensorlogic_compiler::CompilerContext;
167 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo};
168 ///
169 /// let mut table = SymbolTable::new();
170 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
171 /// table.add_predicate(PredicateInfo::new(
172 /// "knows",
173 /// vec!["Person".to_string(), "Person".to_string()]
174 /// )).unwrap();
175 ///
176 /// let ctx = CompilerContext::from_symbol_table(&table);
177 ///
178 /// assert_eq!(ctx.domains.len(), 1);
179 /// assert!(ctx.domains.contains_key("Person"));
180 /// ```
181 pub fn from_symbol_table(table: &tensorlogic_adapters::SymbolTable) -> Self {
182 let mut ctx = Self::new();
183
184 // Import all domains from the symbol table
185 for domain in table.domains.values() {
186 ctx.domains.insert(domain.name.clone(), domain.clone());
187 }
188
189 // Import variable bindings if any
190 for (var, domain) in &table.variables {
191 ctx.var_to_domain.insert(var.clone(), domain.clone());
192 }
193
194 ctx.symbol_table_ref = Some("imported".to_string());
195 ctx
196 }
197
198 /// Creates a compiler context from a SymbolTable with a specific configuration.
199 ///
200 /// Combines schema-driven compilation with custom compilation strategies.
201 ///
202 /// # Examples
203 ///
204 /// ```
205 /// use tensorlogic_compiler::{CompilerContext, CompilationConfig};
206 /// use tensorlogic_adapters::{SymbolTable, DomainInfo};
207 ///
208 /// let mut table = SymbolTable::new();
209 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
210 ///
211 /// let ctx = CompilerContext::from_symbol_table_with_config(
212 /// &table,
213 /// CompilationConfig::fuzzy_lukasiewicz()
214 /// );
215 /// ```
216 pub fn from_symbol_table_with_config(
217 table: &tensorlogic_adapters::SymbolTable,
218 config: CompilationConfig,
219 ) -> Self {
220 let mut ctx = Self::from_symbol_table(table);
221 ctx.config = config;
222 ctx
223 }
224
225 /// Registers a new domain with its cardinality.
226 ///
227 /// Domains must be registered before they can be used for variable bindings
228 /// or quantifiers. The cardinality determines the size of the tensor dimension
229 /// for variables in this domain.
230 ///
231 /// # Arguments
232 ///
233 /// * `name` - The domain name (e.g., "Person", "City")
234 /// * `cardinality` - The number of possible values in this domain
235 ///
236 /// # Examples
237 ///
238 /// ```
239 /// use tensorlogic_compiler::CompilerContext;
240 ///
241 /// let mut ctx = CompilerContext::new();
242 /// ctx.add_domain("Person", 100);
243 /// ctx.add_domain("City", 50);
244 ///
245 /// assert_eq!(ctx.domains.len(), 2);
246 /// assert_eq!(ctx.domains.get("Person").unwrap().cardinality, 100);
247 /// ```
248 pub fn add_domain(&mut self, name: impl Into<String>, cardinality: usize) {
249 let name = name.into();
250 self.domains
251 .insert(name.clone(), DomainInfo::new(name, cardinality));
252 }
253
254 /// Registers a domain with full metadata.
255 ///
256 /// Use this method when you have a complete DomainInfo instance with
257 /// metadata, descriptions, or parametric types.
258 ///
259 /// # Examples
260 ///
261 /// ```
262 /// use tensorlogic_compiler::CompilerContext;
263 /// use tensorlogic_adapters::DomainInfo;
264 ///
265 /// let mut ctx = CompilerContext::new();
266 /// let domain = DomainInfo::new("Person", 100)
267 /// .with_description("All persons in the system");
268 ///
269 /// ctx.add_domain_info(domain);
270 ///
271 /// assert!(ctx.domains.get("Person").unwrap().description.is_some());
272 /// ```
273 pub fn add_domain_info(&mut self, domain: DomainInfo) {
274 self.domains.insert(domain.name.clone(), domain);
275 }
276
277 /// Binds a variable to a specific domain.
278 ///
279 /// This is optional - the compiler can often infer domains from quantifiers.
280 /// However, explicit bindings can be useful for type checking and validation.
281 ///
282 /// # Arguments
283 ///
284 /// * `var` - The variable name (e.g., "x", "y")
285 /// * `domain` - The domain name (must be already registered)
286 ///
287 /// # Errors
288 ///
289 /// Returns an error if the specified domain has not been registered.
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// use tensorlogic_compiler::CompilerContext;
295 ///
296 /// let mut ctx = CompilerContext::new();
297 /// ctx.add_domain("Person", 100);
298 ///
299 /// ctx.bind_var("x", "Person").unwrap();
300 /// assert_eq!(ctx.var_to_domain.get("x"), Some(&"Person".to_string()));
301 ///
302 /// // Error: domain not registered
303 /// assert!(ctx.bind_var("y", "Unknown").is_err());
304 /// ```
305 pub fn bind_var(&mut self, var: &str, domain: &str) -> Result<()> {
306 if !self.domains.contains_key(domain) {
307 bail!("Domain '{}' not found", domain);
308 }
309 self.var_to_domain
310 .insert(var.to_string(), domain.to_string());
311 Ok(())
312 }
313
314 /// Assigns an einsum axis to a variable.
315 ///
316 /// Axes are assigned in lexicographic order ('a', 'b', 'c', ...).
317 /// If a variable already has an assigned axis, that axis is returned.
318 /// Otherwise, a new axis is assigned and the counter is incremented.
319 ///
320 /// # Arguments
321 ///
322 /// * `var` - The variable name
323 ///
324 /// # Returns
325 ///
326 /// The axis character assigned to this variable.
327 ///
328 /// # Examples
329 ///
330 /// ```
331 /// use tensorlogic_compiler::CompilerContext;
332 ///
333 /// let mut ctx = CompilerContext::new();
334 ///
335 /// let axis_x = ctx.assign_axis("x");
336 /// assert_eq!(axis_x, 'a');
337 ///
338 /// let axis_y = ctx.assign_axis("y");
339 /// assert_eq!(axis_y, 'b');
340 ///
341 /// // Re-assigning returns the same axis
342 /// let axis_x_again = ctx.assign_axis("x");
343 /// assert_eq!(axis_x_again, 'a');
344 /// ```
345 pub fn assign_axis(&mut self, var: &str) -> char {
346 if let Some(&axis) = self.var_to_axis.get(var) {
347 return axis;
348 }
349 let axis = self.next_axis;
350 self.var_to_axis.insert(var.to_string(), axis);
351 self.next_axis = ((axis as u8) + 1) as char;
352 axis
353 }
354
355 /// Generates a fresh temporary tensor name.
356 ///
357 /// Temporary tensors are used for intermediate results during compilation.
358 /// Names are generated as "temp_0", "temp_1", etc.
359 ///
360 /// # Returns
361 ///
362 /// A unique temporary tensor name.
363 ///
364 /// # Examples
365 ///
366 /// ```
367 /// use tensorlogic_compiler::CompilerContext;
368 ///
369 /// let mut ctx = CompilerContext::new();
370 ///
371 /// let temp1 = ctx.fresh_temp();
372 /// assert_eq!(temp1, "temp_0");
373 ///
374 /// let temp2 = ctx.fresh_temp();
375 /// assert_eq!(temp2, "temp_1");
376 /// ```
377 pub fn fresh_temp(&mut self) -> String {
378 let name = format!("temp_{}", self.temp_counter);
379 self.temp_counter += 1;
380 name
381 }
382
383 /// Gets the einsum axes string for a list of terms.
384 ///
385 /// This is used internally during predicate compilation to determine
386 /// the axes string for a predicate's arguments.
387 ///
388 /// # Arguments
389 ///
390 /// * `terms` - The list of terms (usually predicate arguments)
391 ///
392 /// # Returns
393 ///
394 /// A string of axis characters (e.g., "ab" for two variables)
395 ///
396 /// # Errors
397 ///
398 /// Returns an error if a variable term has not been assigned an axis.
399 ///
400 /// # Examples
401 ///
402 /// ```
403 /// use tensorlogic_compiler::CompilerContext;
404 /// use tensorlogic_ir::Term;
405 ///
406 /// let mut ctx = CompilerContext::new();
407 /// ctx.assign_axis("x");
408 /// ctx.assign_axis("y");
409 ///
410 /// let terms = vec![Term::var("x"), Term::var("y")];
411 /// let axes = ctx.get_axes(&terms).unwrap();
412 /// assert_eq!(axes, "ab");
413 /// ```
414 pub fn get_axes(&self, terms: &[tensorlogic_ir::Term]) -> Result<String> {
415 use anyhow::anyhow;
416 use tensorlogic_ir::Term;
417
418 let mut axes = String::new();
419 for term in terms {
420 if let Term::Var(v) = term {
421 let axis = self
422 .var_to_axis
423 .get(v)
424 .ok_or_else(|| anyhow!("Variable '{}' not assigned an axis", v))?;
425 axes.push(*axis);
426 }
427 }
428 Ok(axes)
429 }
430}
431
432impl Default for CompilerContext {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
/// Crate-internal bookkeeping produced while compiling a single expression.
#[derive(Debug)]
pub(crate) struct CompileState {
    /// Index of the tensor produced for the expression (presumably an index
    /// into the graph's tensor list; confirm against callers).
    pub tensor_idx: usize,
    /// Axis-label string associated with that tensor.
    pub axes: String,
}