Skip to main content

tensorlogic_adapters/
utilities.rs

//! Advanced utility functions for tensorlogic-adapters.
//!
//! This module provides helper functions for common operations
//! on symbol tables, domains, predicates, and related structures.
5
6use crate::{DomainInfo, PredicateInfo, SchemaStatistics, SymbolTable, ValidationReport};
7use anyhow::Result;
8use std::collections::{HashMap, HashSet};
9
10/// Batch operations for efficient bulk processing.
11pub struct BatchOperations;
12
13impl BatchOperations {
14    /// Add multiple domains at once with validation.
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, BatchOperations};
20    ///
21    /// let mut table = SymbolTable::new();
22    /// let domains = vec![
23    ///     DomainInfo::new("Person", 100),
24    ///     DomainInfo::new("Organization", 50),
25    /// ];
26    ///
27    /// let result = BatchOperations::add_domains(&mut table, domains);
28    /// assert!(result.is_ok());
29    /// assert_eq!(table.domains.len(), 2);
30    /// ```
31    pub fn add_domains(table: &mut SymbolTable, domains: Vec<DomainInfo>) -> Result<()> {
32        for domain in domains {
33            table.add_domain(domain)?;
34        }
35        Ok(())
36    }
37
38    /// Add multiple predicates at once with validation.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, BatchOperations};
44    ///
45    /// let mut table = SymbolTable::new();
46    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
47    ///
48    /// let predicates = vec![
49    ///     PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]),
50    ///     PredicateInfo::new("age", vec!["Person".to_string()]),
51    /// ];
52    ///
53    /// let result = BatchOperations::add_predicates(&mut table, predicates);
54    /// assert!(result.is_ok());
55    /// assert_eq!(table.predicates.len(), 2);
56    /// ```
57    pub fn add_predicates(table: &mut SymbolTable, predicates: Vec<PredicateInfo>) -> Result<()> {
58        for predicate in predicates {
59            table.add_predicate(predicate)?;
60        }
61        Ok(())
62    }
63
64    /// Bind multiple variables at once.
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, BatchOperations};
70    /// use std::collections::HashMap;
71    ///
72    /// let mut table = SymbolTable::new();
73    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
74    ///
75    /// let mut bindings = HashMap::new();
76    /// bindings.insert("x".to_string(), "Person".to_string());
77    /// bindings.insert("y".to_string(), "Person".to_string());
78    ///
79    /// let result = BatchOperations::bind_variables(&mut table, bindings);
80    /// assert!(result.is_ok());
81    /// assert_eq!(table.variables.len(), 2);
82    /// ```
83    pub fn bind_variables(
84        table: &mut SymbolTable,
85        bindings: HashMap<String, String>,
86    ) -> Result<()> {
87        for (var, domain) in bindings {
88            table.bind_variable(var, domain)?;
89        }
90        Ok(())
91    }
92}
93
94/// Conversion utilities for different formats.
95pub struct ConversionUtils;
96
97impl ConversionUtils {
98    /// Convert a symbol table to a compact summary string.
99    ///
100    /// # Examples
101    ///
102    /// ```
103    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ConversionUtils};
104    ///
105    /// let mut table = SymbolTable::new();
106    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
107    ///
108    /// let summary = ConversionUtils::to_summary(&table);
109    /// assert!(summary.contains("Domains: 1"));
110    /// ```
111    pub fn to_summary(table: &SymbolTable) -> String {
112        format!(
113            "SymbolTable Summary:\n  Domains: {}\n  Predicates: {}\n  Variables: {}",
114            table.domains.len(),
115            table.predicates.len(),
116            table.variables.len()
117        )
118    }
119
120    /// Extract domain names as a sorted vector.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ConversionUtils};
126    ///
127    /// let mut table = SymbolTable::new();
128    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
129    /// table.add_domain(DomainInfo::new("Organization", 50)).unwrap();
130    ///
131    /// let names = ConversionUtils::extract_domain_names(&table);
132    /// assert_eq!(names, vec!["Organization", "Person"]);
133    /// ```
134    pub fn extract_domain_names(table: &SymbolTable) -> Vec<String> {
135        let mut names: Vec<String> = table.domains.keys().cloned().collect();
136        names.sort();
137        names
138    }
139
140    /// Extract predicate names as a sorted vector.
141    ///
142    /// # Examples
143    ///
144    /// ```
145    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, ConversionUtils};
146    ///
147    /// let mut table = SymbolTable::new();
148    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
149    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string()])).unwrap();
150    ///
151    /// let names = ConversionUtils::extract_predicate_names(&table);
152    /// assert_eq!(names, vec!["knows"]);
153    /// ```
154    pub fn extract_predicate_names(table: &SymbolTable) -> Vec<String> {
155        let mut names: Vec<String> = table.predicates.keys().cloned().collect();
156        names.sort();
157        names
158    }
159}
160
161/// Query utilities for advanced filtering and searching.
162pub struct QueryUtils;
163
164impl QueryUtils {
165    /// Find all predicates that use a specific domain.
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
171    ///
172    /// let mut table = SymbolTable::new();
173    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
174    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
175    ///
176    /// let predicates = QueryUtils::find_predicates_using_domain(&table, "Person");
177    /// assert_eq!(predicates.len(), 1);
178    /// assert_eq!(predicates[0].name, "knows");
179    /// ```
180    pub fn find_predicates_using_domain(
181        table: &SymbolTable,
182        domain_name: &str,
183    ) -> Vec<PredicateInfo> {
184        table
185            .predicates
186            .values()
187            .filter(|p| p.arg_domains.contains(&domain_name.to_string()))
188            .cloned()
189            .collect()
190    }
191
192    /// Find all domains that are never used by any predicate.
193    ///
194    /// # Examples
195    ///
196    /// ```
197    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, QueryUtils};
198    ///
199    /// let mut table = SymbolTable::new();
200    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
201    /// table.add_domain(DomainInfo::new("Unused", 10)).unwrap();
202    ///
203    /// let unused = QueryUtils::find_unused_domains(&table);
204    /// assert_eq!(unused.len(), 2); // Both are unused as no predicates defined
205    /// ```
206    pub fn find_unused_domains(table: &SymbolTable) -> Vec<String> {
207        let used_domains: HashSet<&String> = table
208            .predicates
209            .values()
210            .flat_map(|p| &p.arg_domains)
211            .collect();
212
213        table
214            .domains
215            .keys()
216            .filter(|d| !used_domains.contains(d))
217            .cloned()
218            .collect()
219    }
220
221    /// Find predicates with a specific arity.
222    ///
223    /// # Examples
224    ///
225    /// ```
226    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
227    ///
228    /// let mut table = SymbolTable::new();
229    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
230    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
231    /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
232    ///
233    /// let binary = QueryUtils::find_predicates_by_arity(&table, 2);
234    /// assert_eq!(binary.len(), 1);
235    /// assert_eq!(binary[0].name, "knows");
236    /// ```
237    pub fn find_predicates_by_arity(table: &SymbolTable, arity: usize) -> Vec<PredicateInfo> {
238        table
239            .predicates
240            .values()
241            .filter(|p| p.arg_domains.len() == arity)
242            .cloned()
243            .collect()
244    }
245
246    /// Group predicates by their arity.
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
252    ///
253    /// let mut table = SymbolTable::new();
254    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
255    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
256    /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
257    ///
258    /// let grouped = QueryUtils::group_predicates_by_arity(&table);
259    /// assert_eq!(grouped.get(&1).unwrap().len(), 1);
260    /// assert_eq!(grouped.get(&2).unwrap().len(), 1);
261    /// ```
262    pub fn group_predicates_by_arity(table: &SymbolTable) -> HashMap<usize, Vec<PredicateInfo>> {
263        let mut groups: HashMap<usize, Vec<PredicateInfo>> = HashMap::new();
264        for predicate in table.predicates.values() {
265            let arity = predicate.arg_domains.len();
266            groups.entry(arity).or_default().push(predicate.clone());
267        }
268        groups
269    }
270}
271
272/// Validation utilities for enhanced checking.
273pub struct ValidationUtils;
274
275impl ValidationUtils {
276    /// Quick validation check (returns true if valid).
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ValidationUtils};
282    ///
283    /// let mut table = SymbolTable::new();
284    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
285    ///
286    /// assert!(ValidationUtils::is_valid(&table));
287    /// ```
288    pub fn is_valid(table: &SymbolTable) -> bool {
289        // Check if all predicates reference existing domains
290        for predicate in table.predicates.values() {
291            for domain in &predicate.arg_domains {
292                if !table.domains.contains_key(domain) {
293                    return false;
294                }
295            }
296        }
297
298        // Check if all variables reference existing domains
299        for domain in table.variables.values() {
300            if !table.domains.contains_key(domain) {
301                return false;
302            }
303        }
304
305        true
306    }
307
308    /// Get detailed validation report.
309    ///
310    /// # Examples
311    ///
312    /// ```
313    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ValidationUtils};
314    ///
315    /// let mut table = SymbolTable::new();
316    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
317    ///
318    /// let report = ValidationUtils::detailed_validation(&table);
319    /// assert!(report.is_ok());
320    /// ```
321    pub fn detailed_validation(table: &SymbolTable) -> Result<ValidationReport> {
322        use crate::SchemaValidator;
323        let validator = SchemaValidator::new(table);
324        validator.validate()
325    }
326
327    /// Check if a specific domain is used.
328    ///
329    /// # Examples
330    ///
331    /// ```
332    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, ValidationUtils};
333    ///
334    /// let mut table = SymbolTable::new();
335    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
336    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string()])).unwrap();
337    ///
338    /// assert!(ValidationUtils::is_domain_used(&table, "Person"));
339    /// assert!(!ValidationUtils::is_domain_used(&table, "Nonexistent"));
340    /// ```
341    pub fn is_domain_used(table: &SymbolTable, domain_name: &str) -> bool {
342        // Check predicates
343        for predicate in table.predicates.values() {
344            if predicate.arg_domains.contains(&domain_name.to_string()) {
345                return true;
346            }
347        }
348
349        // Check variables
350        for domain in table.variables.values() {
351            if domain == domain_name {
352                return true;
353            }
354        }
355
356        false
357    }
358}
359
360/// Statistics utilities for metrics collection.
361pub struct StatisticsUtils;
362
363impl StatisticsUtils {
364    /// Compute comprehensive statistics for a symbol table.
365    ///
366    /// # Examples
367    ///
368    /// ```
369    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, StatisticsUtils};
370    ///
371    /// let mut table = SymbolTable::new();
372    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
373    ///
374    /// let stats = StatisticsUtils::compute_statistics(&table);
375    /// assert!(stats.is_ok());
376    /// ```
377    pub fn compute_statistics(table: &SymbolTable) -> Result<SchemaStatistics> {
378        Ok(SchemaStatistics::compute(table))
379    }
380
381    /// Get total domain cardinality.
382    ///
383    /// # Examples
384    ///
385    /// ```
386    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, StatisticsUtils};
387    ///
388    /// let mut table = SymbolTable::new();
389    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
390    /// table.add_domain(DomainInfo::new("Organization", 50)).unwrap();
391    ///
392    /// let total = StatisticsUtils::total_domain_cardinality(&table);
393    /// assert_eq!(total, 150);
394    /// ```
395    pub fn total_domain_cardinality(table: &SymbolTable) -> usize {
396        table.domains.values().map(|d| d.cardinality).sum()
397    }
398
399    /// Get average predicate arity.
400    ///
401    /// # Examples
402    ///
403    /// ```
404    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, StatisticsUtils};
405    ///
406    /// let mut table = SymbolTable::new();
407    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
408    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
409    /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
410    ///
411    /// let avg = StatisticsUtils::average_predicate_arity(&table);
412    /// assert_eq!(avg, 1.5);
413    /// ```
414    pub fn average_predicate_arity(table: &SymbolTable) -> f64 {
415        if table.predicates.is_empty() {
416            return 0.0;
417        }
418
419        let total: usize = table.predicates.values().map(|p| p.arg_domains.len()).sum();
420
421        total as f64 / table.predicates.len() as f64
422    }
423
424    /// Get domain usage counts.
425    ///
426    /// # Examples
427    ///
428    /// ```
429    /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, StatisticsUtils};
430    ///
431    /// let mut table = SymbolTable::new();
432    /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
433    /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
434    ///
435    /// let usage = StatisticsUtils::domain_usage_counts(&table);
436    /// assert_eq!(usage.get("Person"), Some(&2));
437    /// ```
438    pub fn domain_usage_counts(table: &SymbolTable) -> HashMap<String, usize> {
439        let mut counts: HashMap<String, usize> = HashMap::new();
440
441        for predicate in table.predicates.values() {
442            for domain in &predicate.arg_domains {
443                *counts.entry(domain.clone()).or_insert(0) += 1;
444            }
445        }
446
447        for domain in table.variables.values() {
448            *counts.entry(domain.clone()).or_insert(0) += 1;
449        }
450
451        counts
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture shared by the tests below: domains `Person` (100)
    /// and `Organization` (50), predicates `knows(Person, Person)` and
    /// `age(Person)`, and variable `x` bound to `Person`.
    fn create_test_table() -> SymbolTable {
        let mut fixture = SymbolTable::new();

        fixture.add_domain(DomainInfo::new("Person", 100)).unwrap();
        fixture
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();

        let knows = PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]);
        fixture.add_predicate(knows).unwrap();

        let age = PredicateInfo::new("age", vec!["Person".to_string()]);
        fixture.add_predicate(age).unwrap();

        fixture.bind_variable("x", "Person").unwrap();
        fixture
    }

    #[test]
    fn test_batch_add_domains() {
        let mut empty_table = SymbolTable::new();
        let new_domains = vec![
            DomainInfo::new("Person", 100),
            DomainInfo::new("Organization", 50),
        ];

        BatchOperations::add_domains(&mut empty_table, new_domains).unwrap();

        assert_eq!(empty_table.domains.len(), 2);
    }

    #[test]
    fn test_conversion_summary() {
        let fixture = create_test_table();

        let text = ConversionUtils::to_summary(&fixture);

        assert!(text.contains("Domains: 2"));
        assert!(text.contains("Predicates: 2"));
    }

    #[test]
    fn test_query_predicates_using_domain() {
        let fixture = create_test_table();

        // Both `knows` and `age` take `Person` arguments.
        let matches = QueryUtils::find_predicates_using_domain(&fixture, "Person");

        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_query_by_arity() {
        let fixture = create_test_table();

        let arity_one = QueryUtils::find_predicates_by_arity(&fixture, 1);
        let arity_two = QueryUtils::find_predicates_by_arity(&fixture, 2);

        assert_eq!(arity_one.len(), 1);
        assert_eq!(arity_two.len(), 1);
    }

    #[test]
    fn test_validation_is_valid() {
        let fixture = create_test_table();

        assert!(ValidationUtils::is_valid(&fixture));
    }

    #[test]
    fn test_statistics_total_cardinality() {
        let fixture = create_test_table();

        // 100 (Person) + 50 (Organization)
        let cardinality_sum = StatisticsUtils::total_domain_cardinality(&fixture);

        assert_eq!(cardinality_sum, 150);
    }

    #[test]
    fn test_statistics_average_arity() {
        let fixture = create_test_table();

        // (2 + 1) arguments over 2 predicates
        let mean_arity = StatisticsUtils::average_predicate_arity(&fixture);

        assert_eq!(mean_arity, 1.5);
    }

    #[test]
    fn test_domain_usage_counts() {
        let fixture = create_test_table();

        let usage = StatisticsUtils::domain_usage_counts(&fixture);

        // 2 in knows + 1 in age + 1 variable
        assert_eq!(usage.get("Person"), Some(&4));
    }

    #[test]
    fn test_group_by_arity() {
        let fixture = create_test_table();

        let by_arity = QueryUtils::group_predicates_by_arity(&fixture);

        assert_eq!(by_arity.len(), 2);
        assert!(by_arity.contains_key(&1));
        assert!(by_arity.contains_key(&2));
    }

    #[test]
    fn test_extract_names() {
        let fixture = create_test_table();

        let domains = ConversionUtils::extract_domain_names(&fixture);
        let predicates = ConversionUtils::extract_predicate_names(&fixture);

        assert_eq!(domains.len(), 2);
        assert_eq!(predicates.len(), 2);
        assert!(domains.contains(&"Person".to_string()));
        assert!(predicates.contains(&"knows".to_string()));
    }
}