Skip to main content

tensorlogic_adapters/
signature_matcher.rs

1//! Optimized predicate signature matching.
2//!
3//! This module provides fast lookup and matching of predicates based on their
4//! signatures (arity and domain types). It uses indexed data structures for
5//! O(1) lookups instead of linear scans.
6
7use std::collections::HashMap;
8
9use crate::PredicateInfo;
10
11/// Indexed structure for fast predicate signature matching.
12///
13/// This provides O(1) lookup of predicates by:
14/// - Arity (number of arguments)
15/// - Exact signature (ordered domain types)
16/// - Domain patterns (unordered domain types)
17///
18/// # Example
19///
20/// ```rust
21/// use tensorlogic_adapters::{PredicateInfo, SignatureMatcher};
22///
23/// let mut matcher = SignatureMatcher::new();
24///
25/// let knows = PredicateInfo::new(
26///     "knows",
27///     vec!["Person".to_string(), "Person".to_string()]
28/// );
29/// matcher.add_predicate(&knows);
30///
31/// // Find by arity
32/// let arity_2 = matcher.find_by_arity(2);
33/// assert_eq!(arity_2.len(), 1);
34///
35/// // Find by exact signature
36/// let signature = vec!["Person".to_string(), "Person".to_string()];
37/// let matches = matcher.find_by_signature(&signature);
38/// assert_eq!(matches.len(), 1);
39/// assert_eq!(matches[0], "knows");
40/// ```
41#[derive(Clone, Debug, Default)]
42pub struct SignatureMatcher {
43    /// Index by arity: arity -> [predicate_names]
44    by_arity: HashMap<usize, Vec<String>>,
45    /// Index by exact signature: signature -> [predicate_names]
46    by_signature: HashMap<Vec<String>, Vec<String>>,
47    /// Index by sorted signature (for unordered matching): sorted_sig -> [predicate_names]
48    by_domain_set: HashMap<Vec<String>, Vec<String>>,
49    /// Full predicate information
50    predicates: HashMap<String, PredicateInfo>,
51}
52
53impl SignatureMatcher {
54    /// Create a new signature matcher.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Add a predicate to the matcher indices.
60    pub fn add_predicate(&mut self, pred: &PredicateInfo) {
61        let name = pred.name.clone();
62        let arity = pred.arg_domains.len();
63        let signature = pred.arg_domains.clone();
64
65        // Index by arity
66        self.by_arity.entry(arity).or_default().push(name.clone());
67
68        // Index by exact signature
69        self.by_signature
70            .entry(signature.clone())
71            .or_default()
72            .push(name.clone());
73
74        // Index by sorted domain set (for unordered matching)
75        let mut sorted_sig = signature.clone();
76        sorted_sig.sort();
77        self.by_domain_set
78            .entry(sorted_sig)
79            .or_default()
80            .push(name.clone());
81
82        // Store full predicate info
83        self.predicates.insert(name, pred.clone());
84    }
85
86    /// Remove a predicate from all indices.
87    pub fn remove_predicate(&mut self, name: &str) {
88        if let Some(pred) = self.predicates.remove(name) {
89            let arity = pred.arg_domains.len();
90            let signature = pred.arg_domains.clone();
91
92            // Remove from arity index
93            if let Some(names) = self.by_arity.get_mut(&arity) {
94                names.retain(|n| n != name);
95                if names.is_empty() {
96                    self.by_arity.remove(&arity);
97                }
98            }
99
100            // Remove from signature index
101            if let Some(names) = self.by_signature.get_mut(&signature) {
102                names.retain(|n| n != name);
103                if names.is_empty() {
104                    self.by_signature.remove(&signature);
105                }
106            }
107
108            // Remove from domain set index
109            let mut sorted_sig = signature;
110            sorted_sig.sort();
111            if let Some(names) = self.by_domain_set.get_mut(&sorted_sig) {
112                names.retain(|n| n != name);
113                if names.is_empty() {
114                    self.by_domain_set.remove(&sorted_sig);
115                }
116            }
117        }
118    }
119
120    /// Find all predicates with the given arity.
121    ///
122    /// # Example
123    ///
124    /// ```rust
125    /// use tensorlogic_adapters::{PredicateInfo, SignatureMatcher};
126    ///
127    /// let mut matcher = SignatureMatcher::new();
128    /// matcher.add_predicate(&PredicateInfo::new("knows", vec!["Person".into(), "Person".into()]));
129    /// matcher.add_predicate(&PredicateInfo::new("age", vec!["Person".into()]));
130    ///
131    /// let unary = matcher.find_by_arity(1);
132    /// assert_eq!(unary.len(), 1);
133    /// assert!(unary.contains(&"age".to_string()));
134    /// ```
135    pub fn find_by_arity(&self, arity: usize) -> Vec<String> {
136        self.by_arity.get(&arity).cloned().unwrap_or_default()
137    }
138
139    /// Find all predicates with the exact signature (ordered domain types).
140    ///
141    /// # Example
142    ///
143    /// ```rust
144    /// use tensorlogic_adapters::{PredicateInfo, SignatureMatcher};
145    ///
146    /// let mut matcher = SignatureMatcher::new();
147    /// matcher.add_predicate(&PredicateInfo::new("at", vec!["Person".into(), "Location".into()]));
148    ///
149    /// let sig = vec!["Person".to_string(), "Location".to_string()];
150    /// let matches = matcher.find_by_signature(&sig);
151    /// assert_eq!(matches, vec!["at"]);
152    /// ```
153    pub fn find_by_signature(&self, signature: &[String]) -> Vec<String> {
154        self.by_signature
155            .get(signature)
156            .cloned()
157            .unwrap_or_default()
158    }
159
160    /// Find all predicates with the given domain types (unordered).
161    ///
162    /// This is useful for finding predicates that operate on a set of domains
163    /// regardless of argument order.
164    ///
165    /// # Example
166    ///
167    /// ```rust
168    /// use tensorlogic_adapters::{PredicateInfo, SignatureMatcher};
169    ///
170    /// let mut matcher = SignatureMatcher::new();
171    /// matcher.add_predicate(&PredicateInfo::new("knows", vec!["Person".into(), "Person".into()]));
172    ///
173    /// let domains = vec!["Person".to_string()];
174    /// let matches = matcher.find_by_domain_set(&domains);
175    /// // "knows" has signature [Person, Person], which when deduplicated matches [Person]
176    /// // Note: This requires exact match of sorted signature
177    /// ```
178    pub fn find_by_domain_set(&self, domains: &[String]) -> Vec<String> {
179        let mut sorted = domains.to_vec();
180        sorted.sort();
181        self.by_domain_set.get(&sorted).cloned().unwrap_or_default()
182    }
183
184    /// Get full predicate information by name.
185    pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
186        self.predicates.get(name)
187    }
188
189    /// Check if a predicate exists.
190    pub fn contains(&self, name: &str) -> bool {
191        self.predicates.contains_key(name)
192    }
193
194    /// Get the total number of predicates indexed.
195    pub fn len(&self) -> usize {
196        self.predicates.len()
197    }
198
199    /// Check if the matcher is empty.
200    pub fn is_empty(&self) -> bool {
201        self.predicates.is_empty()
202    }
203
204    /// Get all predicate names.
205    pub fn predicate_names(&self) -> Vec<String> {
206        self.predicates.keys().cloned().collect()
207    }
208
209    /// Clear all indices.
210    pub fn clear(&mut self) {
211        self.by_arity.clear();
212        self.by_signature.clear();
213        self.by_domain_set.clear();
214        self.predicates.clear();
215    }
216
217    /// Get statistics about the index sizes.
218    pub fn stats(&self) -> MatcherStats {
219        MatcherStats {
220            total_predicates: self.predicates.len(),
221            unique_arities: self.by_arity.len(),
222            unique_signatures: self.by_signature.len(),
223            unique_domain_sets: self.by_domain_set.len(),
224        }
225    }
226
227    /// Build a matcher from a collection of predicates.
228    pub fn from_predicates<'a>(predicates: impl IntoIterator<Item = &'a PredicateInfo>) -> Self {
229        let mut matcher = Self::new();
230        for pred in predicates {
231            matcher.add_predicate(pred);
232        }
233        matcher
234    }
235}
236
/// Statistics about the signature matcher indices.
///
/// Each `unique_*` field counts the non-empty buckets of the corresponding
/// index at the moment the statistics were taken.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MatcherStats {
    /// Total number of predicates indexed.
    pub total_predicates: usize,
    /// Number of distinct arities among indexed predicates.
    pub unique_arities: usize,
    /// Number of distinct exact (ordered) signatures.
    pub unique_signatures: usize,
    /// Number of distinct domain multisets: signatures compared with argument
    /// order ignored but repeated domains kept.
    pub unique_domain_sets: usize,
}
249
250impl MatcherStats {
251    /// Calculate the average index size.
252    pub fn avg_index_size(&self) -> f64 {
253        if self.unique_signatures == 0 {
254            0.0
255        } else {
256            self.total_predicates as f64 / self.unique_signatures as f64
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_and_find_by_arity() {
        let mut matcher = SignatureMatcher::new();

        let knows = PredicateInfo::new("knows", vec!["Person".into(), "Person".into()]);
        let age = PredicateInfo::new("age", vec!["Person".into()]);

        matcher.add_predicate(&knows);
        matcher.add_predicate(&age);

        let unary = matcher.find_by_arity(1);
        assert_eq!(unary.len(), 1);
        assert!(unary.contains(&"age".to_string()));

        let binary = matcher.find_by_arity(2);
        assert_eq!(binary.len(), 1);
        assert!(binary.contains(&"knows".to_string()));
    }

    #[test]
    fn test_find_by_exact_signature() {
        let mut matcher = SignatureMatcher::new();

        let at = PredicateInfo::new("at", vec!["Person".into(), "Location".into()]);
        matcher.add_predicate(&at);

        let sig = vec!["Person".to_string(), "Location".to_string()];
        let matches = matcher.find_by_signature(&sig);
        assert_eq!(matches, vec!["at"]);

        // Different order should not match
        let sig_reversed = vec!["Location".to_string(), "Person".to_string()];
        let no_matches = matcher.find_by_signature(&sig_reversed);
        assert!(no_matches.is_empty());
    }

    #[test]
    fn test_find_by_domain_set_ignores_order_but_not_multiplicity() {
        let mut matcher = SignatureMatcher::new();
        matcher.add_predicate(&PredicateInfo::new(
            "at",
            vec!["Person".into(), "Location".into()],
        ));
        matcher.add_predicate(&PredicateInfo::new(
            "knows",
            vec!["Person".into(), "Person".into()],
        ));

        // Reversed argument order still matches the unordered index.
        let reversed = vec!["Location".to_string(), "Person".to_string()];
        assert_eq!(matcher.find_by_domain_set(&reversed), vec!["at"]);

        // Repeated domains are significant: [Person] is not [Person, Person].
        assert!(matcher
            .find_by_domain_set(&["Person".to_string()])
            .is_empty());
        let pair = vec!["Person".to_string(), "Person".to_string()];
        assert_eq!(matcher.find_by_domain_set(&pair), vec!["knows"]);
    }

    #[test]
    fn test_remove_predicate() {
        let mut matcher = SignatureMatcher::new();

        let knows = PredicateInfo::new("knows", vec!["Person".into(), "Person".into()]);
        matcher.add_predicate(&knows);

        assert_eq!(matcher.len(), 1);
        assert!(matcher.contains("knows"));

        matcher.remove_predicate("knows");
        assert_eq!(matcher.len(), 0);
        assert!(!matcher.contains("knows"));

        // Should be empty
        assert!(matcher.find_by_arity(2).is_empty());
    }

    #[test]
    fn test_remove_predicate_drops_empty_buckets() {
        let mut matcher = SignatureMatcher::new();
        matcher.add_predicate(&PredicateInfo::new("p1", vec!["A".into(), "B".into()]));
        matcher.add_predicate(&PredicateInfo::new("p2", vec!["C".into()]));

        matcher.remove_predicate("p1");
        // Removing an unknown name is a no-op.
        matcher.remove_predicate("missing");

        let stats = matcher.stats();
        assert_eq!(stats.total_predicates, 1);
        assert_eq!(stats.unique_arities, 1);
        assert_eq!(stats.unique_signatures, 1);
        assert_eq!(stats.unique_domain_sets, 1);
        assert!(matcher.get_predicate("p1").is_none());
        assert!(matcher.get_predicate("p2").is_some());
    }

    #[test]
    fn test_multiple_predicates_same_signature() {
        let mut matcher = SignatureMatcher::new();

        let p1 = PredicateInfo::new("pred1", vec!["Person".into(), "Person".into()]);
        let p2 = PredicateInfo::new("pred2", vec!["Person".into(), "Person".into()]);

        matcher.add_predicate(&p1);
        matcher.add_predicate(&p2);

        let sig = vec!["Person".to_string(), "Person".to_string()];
        let matches = matcher.find_by_signature(&sig);
        assert_eq!(matches.len(), 2);
        assert!(matches.contains(&"pred1".to_string()));
        assert!(matches.contains(&"pred2".to_string()));
    }

    #[test]
    fn test_from_predicates() {
        let preds = vec![
            PredicateInfo::new("knows", vec!["Person".into(), "Person".into()]),
            PredicateInfo::new("age", vec!["Person".into()]),
        ];

        let matcher = SignatureMatcher::from_predicates(&preds);
        assert_eq!(matcher.len(), 2);
        assert!(matcher.contains("knows"));
        assert!(matcher.contains("age"));

        let mut names = matcher.predicate_names();
        names.sort();
        assert_eq!(names, vec!["age", "knows"]);
    }

    #[test]
    fn test_stats() {
        let mut matcher = SignatureMatcher::new();

        matcher.add_predicate(&PredicateInfo::new("p1", vec!["A".into(), "B".into()]));
        matcher.add_predicate(&PredicateInfo::new("p2", vec!["A".into(), "B".into()]));
        matcher.add_predicate(&PredicateInfo::new("p3", vec!["C".into()]));

        let stats = matcher.stats();
        assert_eq!(stats.total_predicates, 3);
        assert_eq!(stats.unique_arities, 2); // arity 1 and 2
        assert_eq!(stats.unique_signatures, 2); // [A,B] and [C]
        assert_eq!(stats.unique_domain_sets, 2); // {A,B} and {C}
        assert_eq!(stats.avg_index_size(), 1.5); // 3 predicates / 2 signatures
    }

    #[test]
    fn test_empty_stats_average_is_zero() {
        let matcher = SignatureMatcher::new();
        assert_eq!(matcher.stats().avg_index_size(), 0.0);
    }

    #[test]
    fn test_clear() {
        let mut matcher = SignatureMatcher::new();
        matcher.add_predicate(&PredicateInfo::new("p1", vec!["A".into()]));

        assert_eq!(matcher.len(), 1);
        matcher.clear();
        assert_eq!(matcher.len(), 0);
        assert!(matcher.is_empty());
        assert!(matcher.find_by_arity(1).is_empty());
    }
}