Skip to main content

tensorlogic_adapters/
product.rs

1//! Domain product types for cross-domain reasoning.
2//!
3//! Product types allow combining multiple domains into composite types,
4//! enabling predicates over tuples of values from different domains.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use tensorlogic_adapters::ProductDomain;
10//!
11//! // Create Person × Location product
12//! let product = ProductDomain::new(vec!["Person".to_string(), "Location".to_string()]);
13//! assert_eq!(product.components(), &["Person", "Location"]);
14//!
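//! // Components are domain *names*. To nest products, register the inner
//! // product under a name first (see `ProductDomainExt::add_product_domain`)
//! // and use that name as a component: (Person × Location) × Time
//! let nested = ProductDomain::binary("PersonAtLocation", "Time");
//! assert_eq!(nested.to_string(), "PersonAtLocation × Time");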
20//! ```
21
22use serde::{Deserialize, Serialize};
23use std::fmt;
24
25use crate::{AdapterError, DomainInfo, SymbolTable};
26
27/// A product domain representing a tuple of component domains.
28///
29/// Product domains enable cross-domain reasoning by creating composite types
30/// from multiple base domains. The cardinality of a product domain is the
31/// product of its component cardinalities.
32///
33/// # Examples
34///
35/// ```rust
36/// use tensorlogic_adapters::ProductDomain;
37///
38/// let product = ProductDomain::new(vec![
39///     "Person".to_string(),
40///     "Location".to_string()
41/// ]);
42///
43/// assert_eq!(product.arity(), 2);
44/// assert_eq!(product.to_string(), "Person × Location");
45/// ```
46#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
47pub struct ProductDomain {
48    /// Component domains in the product.
49    components: Vec<String>,
50}
51
52impl ProductDomain {
53    /// Create a new product domain from component domains.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `components` has fewer than 2 elements.
58    ///
59    /// # Examples
60    ///
61    /// ```rust
62    /// use tensorlogic_adapters::ProductDomain;
63    ///
64    /// let product = ProductDomain::new(vec!["A".to_string(), "B".to_string()]);
65    /// assert_eq!(product.arity(), 2);
66    /// ```
67    pub fn new(components: Vec<String>) -> Self {
68        assert!(
69            components.len() >= 2,
70            "Product domain must have at least 2 components"
71        );
72        Self { components }
73    }
74
75    /// Create a binary product domain (A × B).
76    ///
77    /// # Examples
78    ///
79    /// ```rust
80    /// use tensorlogic_adapters::ProductDomain;
81    ///
82    /// let product = ProductDomain::binary("Person", "Location");
83    /// assert_eq!(product.to_string(), "Person × Location");
84    /// ```
85    pub fn binary(a: impl Into<String>, b: impl Into<String>) -> Self {
86        Self::new(vec![a.into(), b.into()])
87    }
88
89    /// Create a ternary product domain (A × B × C).
90    ///
91    /// # Examples
92    ///
93    /// ```rust
94    /// use tensorlogic_adapters::ProductDomain;
95    ///
96    /// let product = ProductDomain::ternary("Person", "Location", "Time");
97    /// assert_eq!(product.to_string(), "Person × Location × Time");
98    /// ```
99    pub fn ternary(a: impl Into<String>, b: impl Into<String>, c: impl Into<String>) -> Self {
100        Self::new(vec![a.into(), b.into(), c.into()])
101    }
102
103    /// Get the component domains.
104    ///
105    /// # Examples
106    ///
107    /// ```rust
108    /// use tensorlogic_adapters::ProductDomain;
109    ///
110    /// let product = ProductDomain::binary("A", "B");
111    /// assert_eq!(product.components(), &["A", "B"]);
112    /// ```
113    pub fn components(&self) -> &[String] {
114        &self.components
115    }
116
117    /// Get the arity (number of components) of this product.
118    ///
119    /// # Examples
120    ///
121    /// ```rust
122    /// use tensorlogic_adapters::ProductDomain;
123    ///
124    /// let product = ProductDomain::ternary("A", "B", "C");
125    /// assert_eq!(product.arity(), 3);
126    /// ```
127    pub fn arity(&self) -> usize {
128        self.components.len()
129    }
130
131    /// Compute the cardinality of this product domain.
132    ///
133    /// Returns the product of component cardinalities, or an error if
134    /// any component domain is not found in the symbol table.
135    ///
136    /// # Examples
137    ///
138    /// ```rust
139    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ProductDomain};
140    ///
141    /// let mut table = SymbolTable::new();
142    /// table.add_domain(DomainInfo::new("A", 10)).unwrap();
143    /// table.add_domain(DomainInfo::new("B", 20)).unwrap();
144    ///
145    /// let product = ProductDomain::binary("A", "B");
146    /// assert_eq!(product.cardinality(&table).unwrap(), 200);
147    /// ```
148    pub fn cardinality(&self, table: &SymbolTable) -> Result<usize, AdapterError> {
149        let mut result = 1_usize;
150        for component in &self.components {
151            let domain = table
152                .get_domain(component)
153                .ok_or_else(|| AdapterError::UnknownDomain(component.clone()))?;
154            result = result.checked_mul(domain.cardinality).ok_or_else(|| {
155                AdapterError::InvalidCardinality(format!(
156                    "Cardinality overflow in product domain: {}",
157                    self
158                ))
159            })?;
160        }
161        Ok(result)
162    }
163
164    /// Check if all component domains exist in the symbol table.
165    ///
166    /// # Examples
167    ///
168    /// ```rust
169    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ProductDomain};
170    ///
171    /// let mut table = SymbolTable::new();
172    /// table.add_domain(DomainInfo::new("A", 10)).unwrap();
173    /// table.add_domain(DomainInfo::new("B", 20)).unwrap();
174    ///
175    /// let product = ProductDomain::binary("A", "B");
176    /// assert!(product.validate(&table).is_ok());
177    ///
178    /// let invalid = ProductDomain::binary("A", "Unknown");
179    /// assert!(invalid.validate(&table).is_err());
180    /// ```
181    pub fn validate(&self, table: &SymbolTable) -> Result<(), AdapterError> {
182        for component in &self.components {
183            if table.get_domain(component).is_none() {
184                return Err(AdapterError::UnknownDomain(component.clone()));
185            }
186        }
187        Ok(())
188    }
189
190    /// Project to a specific component by index.
191    ///
192    /// Returns the domain name of the component at the given index.
193    ///
194    /// # Examples
195    ///
196    /// ```rust
197    /// use tensorlogic_adapters::ProductDomain;
198    ///
199    /// let product = ProductDomain::ternary("A", "B", "C");
200    /// assert_eq!(product.project(0), Some("A"));
201    /// assert_eq!(product.project(1), Some("B"));
202    /// assert_eq!(product.project(2), Some("C"));
203    /// assert_eq!(product.project(3), None);
204    /// ```
205    pub fn project(&self, index: usize) -> Option<&str> {
206        self.components.get(index).map(|s| s.as_str())
207    }
208
209    /// Get a subproduct by slicing component indices.
210    ///
211    /// # Examples
212    ///
213    /// ```rust
214    /// use tensorlogic_adapters::ProductDomain;
215    ///
216    /// let product = ProductDomain::new(vec![
217    ///     "A".to_string(),
218    ///     "B".to_string(),
219    ///     "C".to_string(),
220    ///     "D".to_string()
221    /// ]);
222    ///
223    /// // Get middle two components (B × C)
224    /// let sub = product.slice(1, 3).unwrap();
225    /// assert_eq!(sub.components(), &["B", "C"]);
226    /// ```
227    pub fn slice(&self, start: usize, end: usize) -> Result<ProductDomain, AdapterError> {
228        if start >= end || end > self.components.len() {
229            return Err(AdapterError::InvalidOperation(format!(
230                "Invalid slice indices: {}..{} for product of arity {}",
231                start,
232                end,
233                self.components.len()
234            )));
235        }
236        let components = self.components[start..end].to_vec();
237        Ok(ProductDomain::new(components))
238    }
239
240    /// Extend this product with additional components.
241    ///
242    /// # Examples
243    ///
244    /// ```rust
245    /// use tensorlogic_adapters::ProductDomain;
246    ///
247    /// let mut product = ProductDomain::binary("A", "B");
248    /// product.extend(vec!["C".to_string(), "D".to_string()]);
249    /// assert_eq!(product.arity(), 4);
250    /// ```
251    pub fn extend(&mut self, mut additional: Vec<String>) {
252        self.components.append(&mut additional);
253    }
254}
255
256impl fmt::Display for ProductDomain {
257    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258        write!(f, "{}", self.components.join(" × "))
259    }
260}
261
262impl From<Vec<String>> for ProductDomain {
263    fn from(components: Vec<String>) -> Self {
264        Self::new(components)
265    }
266}
267
268/// Extension trait for SymbolTable to support product domains.
269pub trait ProductDomainExt {
270    /// Add a product domain to the symbol table.
271    ///
272    /// The product domain's cardinality is computed from its components.
273    ///
274    /// # Examples
275    ///
276    /// ```rust
277    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ProductDomain, ProductDomainExt};
278    ///
279    /// let mut table = SymbolTable::new();
280    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
281    /// table.add_domain(DomainInfo::new("Location", 50)).unwrap();
282    ///
283    /// let product = ProductDomain::binary("Person", "Location");
284    /// table.add_product_domain("PersonAtLocation", product).unwrap();
285    ///
286    /// let domain = table.get_domain("PersonAtLocation").unwrap();
287    /// assert_eq!(domain.cardinality, 5000);
288    /// ```
289    fn add_product_domain(
290        &mut self,
291        name: impl Into<String>,
292        product: ProductDomain,
293    ) -> Result<(), AdapterError>;
294
295    /// Get a product domain by name.
296    ///
297    /// Returns `None` if the domain doesn't exist or is not a product domain.
298    fn get_product_domain(&self, name: &str) -> Option<&ProductDomain>;
299
300    /// List all product domains in the symbol table.
301    fn list_product_domains(&self) -> Vec<(&str, &ProductDomain)>;
302}
303
304impl ProductDomainExt for SymbolTable {
305    fn add_product_domain(
306        &mut self,
307        name: impl Into<String>,
308        product: ProductDomain,
309    ) -> Result<(), AdapterError> {
310        let name = name.into();
311
312        // Validate that all component domains exist
313        product.validate(self)?;
314
315        // Compute cardinality
316        let cardinality = product.cardinality(self)?;
317
318        // Create domain info with product type metadata
319        let mut domain_info = DomainInfo::new(&name, cardinality);
320        domain_info.description = Some(format!("Product domain: {}", product));
321
322        // Store product domain metadata (we'll use a custom attribute)
323        if let Some(ref mut meta) = domain_info.metadata {
324            let components_json = serde_json::to_string(&product.components).map_err(|e| {
325                AdapterError::InvalidOperation(format!(
326                    "Failed to serialize product components: {}",
327                    e
328                ))
329            })?;
330            meta.set_attribute("product_components", &components_json);
331        }
332
333        self.add_domain(domain_info)
334            .map_err(|_| AdapterError::DuplicateDomain(name.clone()))?;
335        Ok(())
336    }
337
338    fn get_product_domain(&self, name: &str) -> Option<&ProductDomain> {
339        // This is a simplified implementation
340        // In a real implementation, we'd store ProductDomain instances separately
341        // For now, we return None as a placeholder
342        let _domain = self.get_domain(name)?;
343        None
344    }
345
346    fn list_product_domains(&self) -> Vec<(&str, &ProductDomain)> {
347        // Simplified implementation
348        Vec::new()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a symbol table holding the given `(name, cardinality)` domains.
    fn table_with(entries: &[(&str, usize)]) -> SymbolTable {
        let mut table = SymbolTable::new();
        for &(name, size) in entries {
            table.add_domain(DomainInfo::new(name, size)).unwrap();
        }
        table
    }

    /// Turn string literals into the owned component list `ProductDomain::new` takes.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_binary_product() {
        let pair = ProductDomain::binary("A", "B");
        assert_eq!(pair.to_string(), "A × B");
        assert_eq!(pair.components(), &["A", "B"]);
        assert_eq!(pair.arity(), 2);
    }

    #[test]
    fn test_ternary_product() {
        let triple = ProductDomain::ternary("A", "B", "C");
        assert_eq!(triple.to_string(), "A × B × C");
        assert_eq!(triple.arity(), 3);
    }

    #[test]
    fn test_cardinality() {
        let table = table_with(&[("A", 10), ("B", 20), ("C", 5)]);
        // 10 * 20 * 5
        let triple = ProductDomain::ternary("A", "B", "C");
        assert_eq!(triple.cardinality(&table).unwrap(), 1000);
    }

    #[test]
    fn test_validate_success() {
        let table = table_with(&[("A", 10), ("B", 20)]);
        assert!(ProductDomain::binary("A", "B").validate(&table).is_ok());
    }

    #[test]
    fn test_validate_unknown_domain() {
        let table = table_with(&[("A", 10)]);
        assert!(ProductDomain::binary("A", "Unknown").validate(&table).is_err());
    }

    #[test]
    fn test_project() {
        let triple = ProductDomain::ternary("A", "B", "C");
        let expected = [Some("A"), Some("B"), Some("C"), None];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(triple.project(idx), *want);
        }
    }

    #[test]
    fn test_slice() {
        let quad = ProductDomain::new(owned(&["A", "B", "C", "D"]));
        let middle = quad.slice(1, 3).unwrap();
        assert_eq!(middle.to_string(), "B × C");
        assert_eq!(middle.components(), &["B", "C"]);
    }

    #[test]
    fn test_slice_invalid() {
        let pair = ProductDomain::binary("A", "B");
        assert!(pair.slice(2, 1).is_err());
        assert!(pair.slice(0, 3).is_err());
    }

    #[test]
    fn test_extend() {
        let mut grown = ProductDomain::binary("A", "B");
        grown.extend(owned(&["C", "D"]));
        assert_eq!(grown.to_string(), "A × B × C × D");
        assert_eq!(grown.arity(), 4);
    }

    #[test]
    fn test_add_product_domain() {
        let mut table = table_with(&[("Person", 100), ("Location", 50)]);
        table
            .add_product_domain("PersonAtLocation", ProductDomain::binary("Person", "Location"))
            .unwrap();

        let info = table.get_domain("PersonAtLocation").unwrap();
        assert_eq!(info.cardinality, 5000);
        let description = info.description.as_ref().unwrap();
        assert!(description.contains("Product domain"));
    }

    #[test]
    #[should_panic(expected = "Product domain must have at least 2 components")]
    fn test_invalid_single_component() {
        let _ = ProductDomain::new(owned(&["A"]));
    }

    #[test]
    fn test_nested_product() {
        let mut table = table_with(&[("A", 10), ("B", 20), ("C", 5)]);
        // Register A × B under a name so it can serve as a component.
        table
            .add_product_domain("AB", ProductDomain::binary("A", "B"))
            .unwrap();
        // (A × B) × C = 200 * 5
        let nested = ProductDomain::binary("AB", "C");
        assert_eq!(nested.cardinality(&table).unwrap(), 1000);
    }

    #[test]
    fn test_display() {
        let triple = ProductDomain::new(owned(&["Person", "Location", "Time"]));
        assert_eq!(format!("{}", triple), "Person × Location × Time");
    }

    #[test]
    fn test_from_vec() {
        let pair: ProductDomain = owned(&["A", "B"]).into();
        assert_eq!(pair.arity(), 2);
    }
}
497}