tensorlogic_adapters/utilities.rs
1//! Advanced utility functions for tensorlogic-adapters.
2//!
3//! This module provides helpful utility functions for common operations
4//! on symbol tables, domains, predicates, and related structures.
5
6use crate::{DomainInfo, PredicateInfo, SchemaStatistics, SymbolTable, ValidationReport};
7use anyhow::Result;
8use std::collections::{HashMap, HashSet};
9
10/// Batch operations for efficient bulk processing.
11pub struct BatchOperations;
12
13impl BatchOperations {
14 /// Add multiple domains at once with validation.
15 ///
16 /// # Examples
17 ///
18 /// ```
19 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, BatchOperations};
20 ///
21 /// let mut table = SymbolTable::new();
22 /// let domains = vec![
23 /// DomainInfo::new("Person", 100),
24 /// DomainInfo::new("Organization", 50),
25 /// ];
26 ///
27 /// let result = BatchOperations::add_domains(&mut table, domains);
28 /// assert!(result.is_ok());
29 /// assert_eq!(table.domains.len(), 2);
30 /// ```
31 pub fn add_domains(table: &mut SymbolTable, domains: Vec<DomainInfo>) -> Result<()> {
32 for domain in domains {
33 table.add_domain(domain)?;
34 }
35 Ok(())
36 }
37
38 /// Add multiple predicates at once with validation.
39 ///
40 /// # Examples
41 ///
42 /// ```
43 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, BatchOperations};
44 ///
45 /// let mut table = SymbolTable::new();
46 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
47 ///
48 /// let predicates = vec![
49 /// PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]),
50 /// PredicateInfo::new("age", vec!["Person".to_string()]),
51 /// ];
52 ///
53 /// let result = BatchOperations::add_predicates(&mut table, predicates);
54 /// assert!(result.is_ok());
55 /// assert_eq!(table.predicates.len(), 2);
56 /// ```
57 pub fn add_predicates(table: &mut SymbolTable, predicates: Vec<PredicateInfo>) -> Result<()> {
58 for predicate in predicates {
59 table.add_predicate(predicate)?;
60 }
61 Ok(())
62 }
63
64 /// Bind multiple variables at once.
65 ///
66 /// # Examples
67 ///
68 /// ```
69 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, BatchOperations};
70 /// use std::collections::HashMap;
71 ///
72 /// let mut table = SymbolTable::new();
73 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
74 ///
75 /// let mut bindings = HashMap::new();
76 /// bindings.insert("x".to_string(), "Person".to_string());
77 /// bindings.insert("y".to_string(), "Person".to_string());
78 ///
79 /// let result = BatchOperations::bind_variables(&mut table, bindings);
80 /// assert!(result.is_ok());
81 /// assert_eq!(table.variables.len(), 2);
82 /// ```
83 pub fn bind_variables(
84 table: &mut SymbolTable,
85 bindings: HashMap<String, String>,
86 ) -> Result<()> {
87 for (var, domain) in bindings {
88 table.bind_variable(var, domain)?;
89 }
90 Ok(())
91 }
92}
93
94/// Conversion utilities for different formats.
95pub struct ConversionUtils;
96
97impl ConversionUtils {
98 /// Convert a symbol table to a compact summary string.
99 ///
100 /// # Examples
101 ///
102 /// ```
103 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ConversionUtils};
104 ///
105 /// let mut table = SymbolTable::new();
106 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
107 ///
108 /// let summary = ConversionUtils::to_summary(&table);
109 /// assert!(summary.contains("Domains: 1"));
110 /// ```
111 pub fn to_summary(table: &SymbolTable) -> String {
112 format!(
113 "SymbolTable Summary:\n Domains: {}\n Predicates: {}\n Variables: {}",
114 table.domains.len(),
115 table.predicates.len(),
116 table.variables.len()
117 )
118 }
119
120 /// Extract domain names as a sorted vector.
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ConversionUtils};
126 ///
127 /// let mut table = SymbolTable::new();
128 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
129 /// table.add_domain(DomainInfo::new("Organization", 50)).unwrap();
130 ///
131 /// let names = ConversionUtils::extract_domain_names(&table);
132 /// assert_eq!(names, vec!["Organization", "Person"]);
133 /// ```
134 pub fn extract_domain_names(table: &SymbolTable) -> Vec<String> {
135 let mut names: Vec<String> = table.domains.keys().cloned().collect();
136 names.sort();
137 names
138 }
139
140 /// Extract predicate names as a sorted vector.
141 ///
142 /// # Examples
143 ///
144 /// ```
145 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, ConversionUtils};
146 ///
147 /// let mut table = SymbolTable::new();
148 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
149 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string()])).unwrap();
150 ///
151 /// let names = ConversionUtils::extract_predicate_names(&table);
152 /// assert_eq!(names, vec!["knows"]);
153 /// ```
154 pub fn extract_predicate_names(table: &SymbolTable) -> Vec<String> {
155 let mut names: Vec<String> = table.predicates.keys().cloned().collect();
156 names.sort();
157 names
158 }
159}
160
161/// Query utilities for advanced filtering and searching.
162pub struct QueryUtils;
163
164impl QueryUtils {
165 /// Find all predicates that use a specific domain.
166 ///
167 /// # Examples
168 ///
169 /// ```
170 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
171 ///
172 /// let mut table = SymbolTable::new();
173 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
174 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
175 ///
176 /// let predicates = QueryUtils::find_predicates_using_domain(&table, "Person");
177 /// assert_eq!(predicates.len(), 1);
178 /// assert_eq!(predicates[0].name, "knows");
179 /// ```
180 pub fn find_predicates_using_domain(
181 table: &SymbolTable,
182 domain_name: &str,
183 ) -> Vec<PredicateInfo> {
184 table
185 .predicates
186 .values()
187 .filter(|p| p.arg_domains.contains(&domain_name.to_string()))
188 .cloned()
189 .collect()
190 }
191
192 /// Find all domains that are never used by any predicate.
193 ///
194 /// # Examples
195 ///
196 /// ```
197 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, QueryUtils};
198 ///
199 /// let mut table = SymbolTable::new();
200 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
201 /// table.add_domain(DomainInfo::new("Unused", 10)).unwrap();
202 ///
203 /// let unused = QueryUtils::find_unused_domains(&table);
204 /// assert_eq!(unused.len(), 2); // Both are unused as no predicates defined
205 /// ```
206 pub fn find_unused_domains(table: &SymbolTable) -> Vec<String> {
207 let used_domains: HashSet<&String> = table
208 .predicates
209 .values()
210 .flat_map(|p| &p.arg_domains)
211 .collect();
212
213 table
214 .domains
215 .keys()
216 .filter(|d| !used_domains.contains(d))
217 .cloned()
218 .collect()
219 }
220
221 /// Find predicates with a specific arity.
222 ///
223 /// # Examples
224 ///
225 /// ```
226 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
227 ///
228 /// let mut table = SymbolTable::new();
229 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
230 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
231 /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
232 ///
233 /// let binary = QueryUtils::find_predicates_by_arity(&table, 2);
234 /// assert_eq!(binary.len(), 1);
235 /// assert_eq!(binary[0].name, "knows");
236 /// ```
237 pub fn find_predicates_by_arity(table: &SymbolTable, arity: usize) -> Vec<PredicateInfo> {
238 table
239 .predicates
240 .values()
241 .filter(|p| p.arg_domains.len() == arity)
242 .cloned()
243 .collect()
244 }
245
246 /// Group predicates by their arity.
247 ///
248 /// # Examples
249 ///
250 /// ```
251 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, QueryUtils};
252 ///
253 /// let mut table = SymbolTable::new();
254 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
255 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
256 /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
257 ///
258 /// let grouped = QueryUtils::group_predicates_by_arity(&table);
259 /// assert_eq!(grouped.get(&1).unwrap().len(), 1);
260 /// assert_eq!(grouped.get(&2).unwrap().len(), 1);
261 /// ```
262 pub fn group_predicates_by_arity(table: &SymbolTable) -> HashMap<usize, Vec<PredicateInfo>> {
263 let mut groups: HashMap<usize, Vec<PredicateInfo>> = HashMap::new();
264 for predicate in table.predicates.values() {
265 let arity = predicate.arg_domains.len();
266 groups.entry(arity).or_default().push(predicate.clone());
267 }
268 groups
269 }
270}
271
272/// Validation utilities for enhanced checking.
273pub struct ValidationUtils;
274
275impl ValidationUtils {
276 /// Quick validation check (returns true if valid).
277 ///
278 /// # Examples
279 ///
280 /// ```
281 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ValidationUtils};
282 ///
283 /// let mut table = SymbolTable::new();
284 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
285 ///
286 /// assert!(ValidationUtils::is_valid(&table));
287 /// ```
288 pub fn is_valid(table: &SymbolTable) -> bool {
289 // Check if all predicates reference existing domains
290 for predicate in table.predicates.values() {
291 for domain in &predicate.arg_domains {
292 if !table.domains.contains_key(domain) {
293 return false;
294 }
295 }
296 }
297
298 // Check if all variables reference existing domains
299 for domain in table.variables.values() {
300 if !table.domains.contains_key(domain) {
301 return false;
302 }
303 }
304
305 true
306 }
307
308 /// Get detailed validation report.
309 ///
310 /// # Examples
311 ///
312 /// ```
313 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, ValidationUtils};
314 ///
315 /// let mut table = SymbolTable::new();
316 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
317 ///
318 /// let report = ValidationUtils::detailed_validation(&table);
319 /// assert!(report.is_ok());
320 /// ```
321 pub fn detailed_validation(table: &SymbolTable) -> Result<ValidationReport> {
322 use crate::SchemaValidator;
323 let validator = SchemaValidator::new(table);
324 validator.validate()
325 }
326
327 /// Check if a specific domain is used.
328 ///
329 /// # Examples
330 ///
331 /// ```
332 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, ValidationUtils};
333 ///
334 /// let mut table = SymbolTable::new();
335 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
336 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string()])).unwrap();
337 ///
338 /// assert!(ValidationUtils::is_domain_used(&table, "Person"));
339 /// assert!(!ValidationUtils::is_domain_used(&table, "Nonexistent"));
340 /// ```
341 pub fn is_domain_used(table: &SymbolTable, domain_name: &str) -> bool {
342 // Check predicates
343 for predicate in table.predicates.values() {
344 if predicate.arg_domains.contains(&domain_name.to_string()) {
345 return true;
346 }
347 }
348
349 // Check variables
350 for domain in table.variables.values() {
351 if domain == domain_name {
352 return true;
353 }
354 }
355
356 false
357 }
358}
359
360/// Statistics utilities for metrics collection.
361pub struct StatisticsUtils;
362
363impl StatisticsUtils {
364 /// Compute comprehensive statistics for a symbol table.
365 ///
366 /// # Examples
367 ///
368 /// ```
369 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, StatisticsUtils};
370 ///
371 /// let mut table = SymbolTable::new();
372 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
373 ///
374 /// let stats = StatisticsUtils::compute_statistics(&table);
375 /// assert!(stats.is_ok());
376 /// ```
377 pub fn compute_statistics(table: &SymbolTable) -> Result<SchemaStatistics> {
378 Ok(SchemaStatistics::compute(table))
379 }
380
381 /// Get total domain cardinality.
382 ///
383 /// # Examples
384 ///
385 /// ```
386 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, StatisticsUtils};
387 ///
388 /// let mut table = SymbolTable::new();
389 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
390 /// table.add_domain(DomainInfo::new("Organization", 50)).unwrap();
391 ///
392 /// let total = StatisticsUtils::total_domain_cardinality(&table);
393 /// assert_eq!(total, 150);
394 /// ```
395 pub fn total_domain_cardinality(table: &SymbolTable) -> usize {
396 table.domains.values().map(|d| d.cardinality).sum()
397 }
398
399 /// Get average predicate arity.
400 ///
401 /// # Examples
402 ///
403 /// ```
404 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, StatisticsUtils};
405 ///
406 /// let mut table = SymbolTable::new();
407 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
408 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
409 /// table.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()])).unwrap();
410 ///
411 /// let avg = StatisticsUtils::average_predicate_arity(&table);
412 /// assert_eq!(avg, 1.5);
413 /// ```
414 pub fn average_predicate_arity(table: &SymbolTable) -> f64 {
415 if table.predicates.is_empty() {
416 return 0.0;
417 }
418
419 let total: usize = table.predicates.values().map(|p| p.arg_domains.len()).sum();
420
421 total as f64 / table.predicates.len() as f64
422 }
423
424 /// Get domain usage counts.
425 ///
426 /// # Examples
427 ///
428 /// ```
429 /// use tensorlogic_adapters::{SymbolTable, DomainInfo, PredicateInfo, StatisticsUtils};
430 ///
431 /// let mut table = SymbolTable::new();
432 /// table.add_domain(DomainInfo::new("Person", 100)).unwrap();
433 /// table.add_predicate(PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()])).unwrap();
434 ///
435 /// let usage = StatisticsUtils::domain_usage_counts(&table);
436 /// assert_eq!(usage.get("Person"), Some(&2));
437 /// ```
438 pub fn domain_usage_counts(table: &SymbolTable) -> HashMap<String, usize> {
439 let mut counts: HashMap<String, usize> = HashMap::new();
440
441 for predicate in table.predicates.values() {
442 for domain in &predicate.arg_domains {
443 *counts.entry(domain.clone()).or_insert(0) += 1;
444 }
445 }
446
447 for domain in table.variables.values() {
448 *counts.entry(domain.clone()).or_insert(0) += 1;
449 }
450
451 counts
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Fixture with two domains, two predicates over `Person` and one variable
    /// bound to `Person`. `Organization` is deliberately left unreferenced.
    fn create_test_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .unwrap();
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .unwrap();
        table.bind_variable("x", "Person").unwrap();
        table
    }

    #[test]
    fn test_batch_add_domains() {
        let mut table = SymbolTable::new();
        let domains = vec![
            DomainInfo::new("Person", 100),
            DomainInfo::new("Organization", 50),
        ];

        BatchOperations::add_domains(&mut table, domains).unwrap();
        assert_eq!(table.domains.len(), 2);
    }

    #[test]
    fn test_batch_add_predicates() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        let predicates = vec![
            PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]),
            PredicateInfo::new("age", vec!["Person".to_string()]),
        ];

        BatchOperations::add_predicates(&mut table, predicates).unwrap();
        assert_eq!(table.predicates.len(), 2);
    }

    #[test]
    fn test_batch_bind_variables() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .unwrap();

        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), "Person".to_string());
        bindings.insert("y".to_string(), "Organization".to_string());

        BatchOperations::bind_variables(&mut table, bindings).unwrap();
        assert_eq!(table.variables.len(), 2);
        // `Organization` is referenced only through the variable binding.
        assert!(ValidationUtils::is_domain_used(&table, "Organization"));
    }

    #[test]
    fn test_conversion_summary() {
        let table = create_test_table();
        let summary = ConversionUtils::to_summary(&table);
        assert!(summary.contains("Domains: 2"));
        assert!(summary.contains("Predicates: 2"));
        assert!(summary.contains("Variables: 1"));
    }

    #[test]
    fn test_query_predicates_using_domain() {
        let table = create_test_table();
        let predicates = QueryUtils::find_predicates_using_domain(&table, "Person");
        assert_eq!(predicates.len(), 2);

        let mut names: Vec<String> = predicates.iter().map(|p| p.name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["age", "knows"]);
    }

    #[test]
    fn test_query_predicates_using_unreferenced_domain_is_empty() {
        let table = create_test_table();
        assert!(QueryUtils::find_predicates_using_domain(&table, "Organization").is_empty());
        assert!(QueryUtils::find_predicates_using_domain(&table, "Nonexistent").is_empty());
    }

    #[test]
    fn test_find_unused_domains() {
        let table = create_test_table();
        assert_eq!(
            QueryUtils::find_unused_domains(&table),
            vec!["Organization".to_string()]
        );
    }

    #[test]
    fn test_query_by_arity() {
        let table = create_test_table();
        let unary = QueryUtils::find_predicates_by_arity(&table, 1);
        let binary = QueryUtils::find_predicates_by_arity(&table, 2);
        assert_eq!(unary.len(), 1);
        assert_eq!(binary.len(), 1);
        assert!(QueryUtils::find_predicates_by_arity(&table, 3).is_empty());
    }

    #[test]
    fn test_validation_is_valid() {
        let table = create_test_table();
        assert!(ValidationUtils::is_valid(&table));
    }

    #[test]
    fn test_is_domain_used() {
        let table = create_test_table();
        assert!(ValidationUtils::is_domain_used(&table, "Person"));
        assert!(!ValidationUtils::is_domain_used(&table, "Organization"));
        assert!(!ValidationUtils::is_domain_used(&table, "Nonexistent"));
    }

    #[test]
    fn test_statistics_total_cardinality() {
        let table = create_test_table();
        let total = StatisticsUtils::total_domain_cardinality(&table);
        assert_eq!(total, 150);
    }

    #[test]
    fn test_statistics_average_arity() {
        let table = create_test_table();
        let avg = StatisticsUtils::average_predicate_arity(&table);
        assert_eq!(avg, 1.5);
    }

    #[test]
    fn test_empty_table_edge_cases() {
        let table = SymbolTable::new();
        assert_eq!(StatisticsUtils::average_predicate_arity(&table), 0.0);
        assert_eq!(StatisticsUtils::total_domain_cardinality(&table), 0);
        assert!(StatisticsUtils::domain_usage_counts(&table).is_empty());
        assert!(QueryUtils::find_unused_domains(&table).is_empty());
        assert!(QueryUtils::group_predicates_by_arity(&table).is_empty());
        assert!(ValidationUtils::is_valid(&table));
        assert!(ConversionUtils::extract_domain_names(&table).is_empty());
    }

    #[test]
    fn test_domain_usage_counts() {
        let table = create_test_table();
        let counts = StatisticsUtils::domain_usage_counts(&table);
        assert_eq!(counts.get("Person"), Some(&4)); // 2 in knows + 1 in age + 1 variable
        assert_eq!(counts.get("Organization"), None); // never referenced
    }

    #[test]
    fn test_group_by_arity() {
        let table = create_test_table();
        let groups = QueryUtils::group_predicates_by_arity(&table);
        assert_eq!(groups.len(), 2);
        assert!(groups.contains_key(&1));
        assert!(groups.contains_key(&2));
    }

    #[test]
    fn test_extract_names() {
        let table = create_test_table();
        let domain_names = ConversionUtils::extract_domain_names(&table);
        let predicate_names = ConversionUtils::extract_predicate_names(&table);

        assert_eq!(domain_names, vec!["Organization", "Person"]);
        assert_eq!(predicate_names, vec!["age", "knows"]);
    }
}