Skip to main content

tensorlogic_adapters/
query_planner.rs

1//! Query planning and optimization for predicate lookups.
2//!
3//! This module provides intelligent query planning for efficient predicate
4//! resolution, leveraging statistics, indexing, and cost-based optimization.
5//!
6//! # Overview
7//!
8//! When resolving predicates in a large schema, different lookup strategies
9//! have vastly different performance characteristics. The query planner:
10//!
//! - Collects statistics about predicate access patterns (per query type:
//!   execution count, selectivity, and execution time)
//! - Selects an index strategy for each query based on its shape
//! - Generates execution plans annotated with an estimated cost
//! - Caches plans so that repeating a query skips re-planning
15//!
16//! # Architecture
17//!
//! - **QueryStatistics**: Tracks access patterns and selectivity
//! - **IndexStrategy**: Index types (name/signature hash, arity range, domain
//!   inverted, composite) together with their cost estimates
//! - **QueryPlan**: A query paired with its chosen strategy and estimated cost
//! - **QueryPlanner**: Generates execution plans and caches them per query
23//!
24//! # Example
25//!
26//! ```rust
27//! use tensorlogic_adapters::{SymbolTable, PredicateInfo, QueryPlanner, PredicateQuery};
28//!
29//! let mut table = SymbolTable::new();
30//! // ... populate table ...
31//!
32//! let mut planner = QueryPlanner::new(&table);
33//!
34//! // Plan a query for binary predicates over Person domain
35//! let query = PredicateQuery::by_signature(vec!["Person".to_string(), "Person".to_string()]);
36//! let plan = planner.plan(&query).unwrap();
37//!
38//! // Execute the plan
39//! let results = plan.execute(&table).unwrap();
40//! ```
41
42use anyhow::Result;
43use std::collections::{HashMap, HashSet};
44use std::time::{Duration, Instant};
45
46use crate::{PredicateInfo, SymbolTable};
47
48/// Query for predicates
49#[derive(Clone, Debug, PartialEq, Eq, Hash)]
50pub enum PredicateQuery {
51    /// Find predicate by exact name
52    ByName(String),
53    /// Find predicates by arity
54    ByArity(usize),
55    /// Find predicates by exact signature
56    BySignature(Vec<String>),
57    /// Find predicates containing a specific domain
58    ByDomain(String),
59    /// Find predicates matching a pattern
60    ByPattern(PredicatePattern),
61    /// Conjunction of queries
62    And(Vec<PredicateQuery>),
63    /// Disjunction of queries
64    Or(Vec<PredicateQuery>),
65}
66
67impl PredicateQuery {
68    pub fn by_name(name: impl Into<String>) -> Self {
69        Self::ByName(name.into())
70    }
71
72    pub fn by_arity(arity: usize) -> Self {
73        Self::ByArity(arity)
74    }
75
76    pub fn by_signature(domains: Vec<String>) -> Self {
77        Self::BySignature(domains)
78    }
79
80    pub fn by_domain(domain: impl Into<String>) -> Self {
81        Self::ByDomain(domain.into())
82    }
83
84    pub fn by_pattern(pattern: PredicatePattern) -> Self {
85        Self::ByPattern(pattern)
86    }
87
88    pub fn and(queries: Vec<PredicateQuery>) -> Self {
89        Self::And(queries)
90    }
91
92    pub fn or(queries: Vec<PredicateQuery>) -> Self {
93        Self::Or(queries)
94    }
95}
96
97/// Pattern for predicate matching
98#[derive(Clone, Debug, PartialEq, Eq, Hash)]
99pub struct PredicatePattern {
100    /// Name pattern (supports wildcards)
101    pub name_pattern: Option<String>,
102    /// Minimum arity
103    pub min_arity: Option<usize>,
104    /// Maximum arity
105    pub max_arity: Option<usize>,
106    /// Required domains (at any position)
107    pub required_domains: Vec<String>,
108    /// Excluded domains
109    pub excluded_domains: Vec<String>,
110}
111
112impl PredicatePattern {
113    pub fn new() -> Self {
114        Self {
115            name_pattern: None,
116            min_arity: None,
117            max_arity: None,
118            required_domains: Vec::new(),
119            excluded_domains: Vec::new(),
120        }
121    }
122
123    pub fn with_name_pattern(mut self, pattern: impl Into<String>) -> Self {
124        self.name_pattern = Some(pattern.into());
125        self
126    }
127
128    pub fn with_arity_range(mut self, min: usize, max: usize) -> Self {
129        self.min_arity = Some(min);
130        self.max_arity = Some(max);
131        self
132    }
133
134    pub fn with_required_domain(mut self, domain: impl Into<String>) -> Self {
135        self.required_domains.push(domain.into());
136        self
137    }
138
139    pub fn with_excluded_domain(mut self, domain: impl Into<String>) -> Self {
140        self.excluded_domains.push(domain.into());
141        self
142    }
143
144    /// Check if a predicate matches this pattern
145    pub fn matches(&self, name: &str, predicate: &PredicateInfo) -> bool {
146        // Check name pattern
147        if let Some(pattern) = &self.name_pattern {
148            if !matches_wildcard(name, pattern) {
149                return false;
150            }
151        }
152
153        // Check arity range
154        let arity = predicate.arg_domains.len();
155        if let Some(min) = self.min_arity {
156            if arity < min {
157                return false;
158            }
159        }
160        if let Some(max) = self.max_arity {
161            if arity > max {
162                return false;
163            }
164        }
165
166        // Check required domains
167        let domain_set: HashSet<_> = predicate.arg_domains.iter().collect();
168        for required in &self.required_domains {
169            if !domain_set.contains(required) {
170                return false;
171            }
172        }
173
174        // Check excluded domains
175        for excluded in &self.excluded_domains {
176            if domain_set.contains(excluded) {
177                return false;
178            }
179        }
180
181        true
182    }
183}
184
185impl Default for PredicatePattern {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
/// Simple wildcard matching (supports * and ?)
///
/// `*` matches any run of characters (including an empty one) and `?` matches
/// exactly one character. Every other pattern character must match literally.
/// Matching works on `char`s, and the whole of `text` must be consumed.
fn matches_wildcard(text: &str, pattern: &str) -> bool {
    let text: Vec<char> = text.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();
    let width = pat.len() + 1;

    // `row[j]` answers: does the text consumed so far match `pat[..j]`?
    // Initially no text is consumed, which only a run of leading `*`s can match.
    let mut row = vec![false; width];
    row[0] = true;
    for (j, &p) in pat.iter().enumerate() {
        if p == '*' {
            row[j + 1] = row[j];
        }
    }

    for &t in &text {
        let mut next = vec![false; width];
        for (j, &p) in pat.iter().enumerate() {
            next[j + 1] = match p {
                // `*` either absorbs `t` (row[j + 1]) or matches nothing (next[j]).
                '*' => row[j + 1] || next[j],
                '?' => row[j],
                literal => literal == t && row[j],
            };
        }
        row = next;
    }

    row[pat.len()]
}
218
/// Statistics about predicate access patterns, aggregated per query type.
///
/// A "query type" is a free-form label supplied to
/// [`QueryStatistics::record_query`]. The planner in this module uses the
/// `PredicateQuery` variant name.
#[derive(Clone, Debug)]
pub struct QueryStatistics {
    /// Number of times each query type has been executed
    query_counts: HashMap<String, usize>,
    /// Running mean selectivity (fraction of all predicates returned) per query type
    selectivity: HashMap<String, f64>,
    /// Running mean execution time per query type
    avg_execution_time: HashMap<String, Duration>,
    /// Total executions across all query types
    total_queries: usize,
}

impl QueryStatistics {
    /// Create an empty statistics collector.
    pub fn new() -> Self {
        Self {
            query_counts: HashMap::new(),
            selectivity: HashMap::new(),
            avg_execution_time: HashMap::new(),
            total_queries: 0,
        }
    }

    /// Record one execution of `query_type`.
    ///
    /// * `duration` is the time the execution took.
    /// * `result_count` is the number of predicates the query returned.
    /// * `total_predicates` is the number of predicates in the table. The
    ///   sample's selectivity is `result_count / total_predicates`, or `0.0`
    ///   for an empty table.
    ///
    /// Selectivity and execution time are folded into per-type running means.
    pub fn record_query(
        &mut self,
        query_type: impl Into<String>,
        duration: Duration,
        result_count: usize,
        total_predicates: usize,
    ) {
        let query_type = query_type.into();

        let count = {
            let slot = self.query_counts.entry(query_type.clone()).or_insert(0);
            *slot += 1;
            *slot
        };
        self.total_queries += 1;

        let observed = if total_predicates > 0 {
            result_count as f64 / total_predicates as f64
        } else {
            0.0
        };
        // Incremental mean (the previous code kept only the latest sample).
        // The first sample yields exactly `observed`.
        let mean_selectivity = self.selectivity.entry(query_type.clone()).or_insert(0.0);
        *mean_selectivity += (observed - *mean_selectivity) / count as f64;

        // Work in u128 nanoseconds. Neither the multiplication can overflow nor
        // can a `usize -> u32` cast truncate (which used to underflow `count - 1`
        // and divide by zero once `count` wrapped).
        let n = count as u128;
        let prev_ns = self
            .avg_execution_time
            .get(&query_type)
            .map_or(0, Duration::as_nanos);
        let mean_ns = (prev_ns * (n - 1) + duration.as_nanos()) / n;
        let mean = Duration::from_nanos(u64::try_from(mean_ns).unwrap_or(u64::MAX));
        self.avg_execution_time.insert(query_type, mean);
    }

    /// Return up to `limit` query types with their execution counts, most
    /// frequent first. Ties are ordered by query-type name, so the result is
    /// deterministic.
    pub fn top_queries(&self, limit: usize) -> Vec<(String, usize)> {
        let mut ranked: Vec<(String, usize)> = self
            .query_counts
            .iter()
            .map(|(name, &count)| (name.clone(), count))
            .collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked.truncate(limit);
        ranked
    }

    /// Get the average selectivity for a query type.
    ///
    /// Returns `1.0` (every predicate returned) if the type was never recorded.
    pub fn get_selectivity(&self, query_type: &str) -> f64 {
        self.selectivity.get(query_type).copied().unwrap_or(1.0)
    }

    /// Get the average execution time for a query type.
    ///
    /// Returns [`Duration::ZERO`] if the type was never recorded.
    pub fn get_avg_time(&self, query_type: &str) -> Duration {
        self.avg_execution_time
            .get(query_type)
            .copied()
            .unwrap_or(Duration::ZERO)
    }

    /// Total number of executions recorded across all query types.
    pub fn total_queries(&self) -> usize {
        self.total_queries
    }
}

impl Default for QueryStatistics {
    fn default() -> Self {
        Self::new()
    }
}
305
306/// Index strategy for predicate lookups
307#[derive(Clone, Debug, PartialEq, Eq)]
308pub enum IndexStrategy {
309    /// No index, full scan
310    FullScan,
311    /// Hash index on predicate name
312    NameHash,
313    /// Range index on arity
314    ArityRange,
315    /// Hash index on signature
316    SignatureHash,
317    /// Inverted index on domains
318    DomainInverted,
319    /// Composite index
320    Composite(Vec<IndexStrategy>),
321}
322
323impl IndexStrategy {
324    /// Estimate the cost of using this strategy
325    pub fn estimate_cost(&self, predicates_count: usize, _stats: &QueryStatistics) -> f64 {
326        match self {
327            IndexStrategy::FullScan => predicates_count as f64,
328            IndexStrategy::NameHash => 1.0, // O(1) lookup
329            IndexStrategy::ArityRange => (predicates_count as f64).sqrt(), // O(sqrt(n)) estimate
330            IndexStrategy::SignatureHash => 1.0, // O(1) lookup
331            IndexStrategy::DomainInverted => (predicates_count as f64).log2(), // O(log n) estimate
332            IndexStrategy::Composite(strategies) => {
333                // Cost is minimum of component strategies
334                strategies
335                    .iter()
336                    .map(|s| s.estimate_cost(predicates_count, _stats))
337                    .min_by(|a, b| a.partial_cmp(b).unwrap())
338                    .unwrap_or(predicates_count as f64)
339            }
340        }
341    }
342}
343
344/// Query execution plan
345#[derive(Clone, Debug)]
346pub struct QueryPlan {
347    query: PredicateQuery,
348    strategy: IndexStrategy,
349    estimated_cost: f64,
350    estimated_results: usize,
351}
352
353impl QueryPlan {
354    pub fn new(query: PredicateQuery, strategy: IndexStrategy) -> Self {
355        Self {
356            query,
357            strategy,
358            estimated_cost: 0.0,
359            estimated_results: 0,
360        }
361    }
362
363    pub fn with_cost(mut self, cost: f64) -> Self {
364        self.estimated_cost = cost;
365        self
366    }
367
368    pub fn with_estimated_results(mut self, count: usize) -> Self {
369        self.estimated_results = count;
370        self
371    }
372
373    /// Execute the plan
374    pub fn execute(&self, table: &SymbolTable) -> Result<Vec<(String, PredicateInfo)>> {
375        match &self.query {
376            PredicateQuery::ByName(name) => {
377                if let Some(pred) = table.predicates.get(name) {
378                    Ok(vec![(name.clone(), pred.clone())])
379                } else {
380                    Ok(Vec::new())
381                }
382            }
383            PredicateQuery::ByArity(arity) => {
384                let results: Vec<_> = table
385                    .predicates
386                    .iter()
387                    .filter(|(_, pred)| pred.arg_domains.len() == *arity)
388                    .map(|(name, pred)| (name.clone(), pred.clone()))
389                    .collect();
390                Ok(results)
391            }
392            PredicateQuery::BySignature(signature) => {
393                let results: Vec<_> = table
394                    .predicates
395                    .iter()
396                    .filter(|(_, pred)| pred.arg_domains == *signature)
397                    .map(|(name, pred)| (name.clone(), pred.clone()))
398                    .collect();
399                Ok(results)
400            }
401            PredicateQuery::ByDomain(domain) => {
402                let results: Vec<_> = table
403                    .predicates
404                    .iter()
405                    .filter(|(_, pred)| pred.arg_domains.contains(domain))
406                    .map(|(name, pred)| (name.clone(), pred.clone()))
407                    .collect();
408                Ok(results)
409            }
410            PredicateQuery::ByPattern(pattern) => {
411                let results: Vec<_> = table
412                    .predicates
413                    .iter()
414                    .filter(|(name, pred)| pattern.matches(name, pred))
415                    .map(|(name, pred)| (name.clone(), pred.clone()))
416                    .collect();
417                Ok(results)
418            }
419            PredicateQuery::And(queries) => {
420                if queries.is_empty() {
421                    return Ok(Vec::new());
422                }
423
424                // Execute first query
425                let mut results: HashSet<String> = self
426                    .execute_subquery(&queries[0], table)?
427                    .into_iter()
428                    .map(|(name, _)| name)
429                    .collect();
430
431                // Intersect with remaining queries
432                for query in &queries[1..] {
433                    let subresults: HashSet<String> = self
434                        .execute_subquery(query, table)?
435                        .into_iter()
436                        .map(|(name, _)| name)
437                        .collect();
438                    results.retain(|name| subresults.contains(name));
439                }
440
441                Ok(results
442                    .into_iter()
443                    .filter_map(|name| {
444                        table
445                            .predicates
446                            .get(&name)
447                            .map(|pred| (name.clone(), pred.clone()))
448                    })
449                    .collect())
450            }
451            PredicateQuery::Or(queries) => {
452                let mut results_map: HashMap<String, PredicateInfo> = HashMap::new();
453
454                for query in queries {
455                    let subresults = self.execute_subquery(query, table)?;
456                    for (name, pred) in subresults {
457                        results_map.insert(name, pred);
458                    }
459                }
460
461                Ok(results_map.into_iter().collect())
462            }
463        }
464    }
465
466    fn execute_subquery(
467        &self,
468        query: &PredicateQuery,
469        table: &SymbolTable,
470    ) -> Result<Vec<(String, PredicateInfo)>> {
471        let subplan = QueryPlan::new(query.clone(), self.strategy.clone());
472        subplan.execute(table)
473    }
474
475    pub fn query(&self) -> &PredicateQuery {
476        &self.query
477    }
478
479    pub fn strategy(&self) -> &IndexStrategy {
480        &self.strategy
481    }
482
483    pub fn estimated_cost(&self) -> f64 {
484        self.estimated_cost
485    }
486}
487
488/// Query planner for optimizing predicate lookups
489pub struct QueryPlanner<'a> {
490    table: &'a SymbolTable,
491    statistics: QueryStatistics,
492    plan_cache: HashMap<PredicateQuery, QueryPlan>,
493}
494
495impl<'a> QueryPlanner<'a> {
496    pub fn new(table: &'a SymbolTable) -> Self {
497        Self {
498            table,
499            statistics: QueryStatistics::new(),
500            plan_cache: HashMap::new(),
501        }
502    }
503
504    pub fn with_statistics(mut self, statistics: QueryStatistics) -> Self {
505        self.statistics = statistics;
506        self
507    }
508
509    /// Plan a query
510    pub fn plan(&mut self, query: &PredicateQuery) -> Result<QueryPlan> {
511        // Check cache first
512        if let Some(cached) = self.plan_cache.get(query) {
513            return Ok(cached.clone());
514        }
515
516        let plan = self.generate_plan(query)?;
517        self.plan_cache.insert(query.clone(), plan.clone());
518        Ok(plan)
519    }
520
521    /// Generate an optimal plan for a query
522    fn generate_plan(&self, query: &PredicateQuery) -> Result<QueryPlan> {
523        let strategy = self.select_strategy(query);
524        let cost = strategy.estimate_cost(self.table.predicates.len(), &self.statistics);
525
526        let plan = QueryPlan::new(query.clone(), strategy).with_cost(cost);
527
528        Ok(plan)
529    }
530
531    /// Select the best index strategy for a query
532    fn select_strategy(&self, query: &PredicateQuery) -> IndexStrategy {
533        Self::select_strategy_static(query)
534    }
535
536    /// Static strategy selection (to avoid recursion on self)
537    fn select_strategy_static(query: &PredicateQuery) -> IndexStrategy {
538        match query {
539            PredicateQuery::ByName(_) => IndexStrategy::NameHash,
540            PredicateQuery::ByArity(_) => IndexStrategy::ArityRange,
541            PredicateQuery::BySignature(_) => IndexStrategy::SignatureHash,
542            PredicateQuery::ByDomain(_) => IndexStrategy::DomainInverted,
543            PredicateQuery::ByPattern(_) => {
544                // Pattern queries typically require full scan
545                IndexStrategy::FullScan
546            }
547            PredicateQuery::And(queries) => {
548                // Use the most selective strategy
549                let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
550                IndexStrategy::Composite(strategies)
551            }
552            PredicateQuery::Or(queries) => {
553                // Use composite strategy
554                let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
555                IndexStrategy::Composite(strategies)
556            }
557        }
558    }
559
560    /// Execute a query and record statistics
561    pub fn execute(&mut self, query: &PredicateQuery) -> Result<Vec<(String, PredicateInfo)>> {
562        let start = Instant::now();
563        let plan = self.plan(query)?;
564        let results = plan.execute(self.table)?;
565        let duration = start.elapsed();
566
567        let query_type = format!("{:?}", query)
568            .split('(')
569            .next()
570            .unwrap_or("Unknown")
571            .to_string();
572        self.statistics.record_query(
573            query_type,
574            duration,
575            results.len(),
576            self.table.predicates.len(),
577        );
578
579        Ok(results)
580    }
581
582    /// Get query statistics
583    pub fn statistics(&self) -> &QueryStatistics {
584        &self.statistics
585    }
586
587    /// Clear the plan cache
588    pub fn clear_cache(&mut self) {
589        self.plan_cache.clear();
590    }
591
592    /// Get cache size
593    pub fn cache_size(&self) -> usize {
594        self.plan_cache.len()
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    fn setup_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Location", 50)).unwrap();

        let knows = PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]);
        table.add_predicate(knows).unwrap();

        let at = PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]);
        table.add_predicate(at).unwrap();

        let friends =
            PredicateInfo::new("friends", vec!["Person".to_string(), "Person".to_string()]);
        table.add_predicate(friends).unwrap();

        table
    }

    /// Sorted predicate names from a result set. Asserting on exact names
    /// (not just counts) catches wrong matches, and sorting removes the
    /// dependence on the unspecified result order.
    fn sorted_names(results: &[(String, PredicateInfo)]) -> Vec<&str> {
        let mut names: Vec<&str> = results.iter().map(|(name, _)| name.as_str()).collect();
        names.sort_unstable();
        names
    }

    #[test]
    fn test_query_by_name() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::by_name("knows");
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["knows"]);
    }

    #[test]
    fn test_query_by_name_missing() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let results = planner.execute(&PredicateQuery::by_name("unknown")).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_query_by_arity() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let results = planner.execute(&PredicateQuery::by_arity(2)).unwrap();
        assert_eq!(sorted_names(&results), ["at", "friends", "knows"]);

        let none = planner.execute(&PredicateQuery::by_arity(3)).unwrap();
        assert!(none.is_empty());
    }

    #[test]
    fn test_query_by_signature() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::by_signature(vec!["Person".to_string(), "Person".to_string()]);
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["friends", "knows"]);
    }

    #[test]
    fn test_query_by_domain() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let results = planner.execute(&PredicateQuery::by_domain("Location")).unwrap();
        assert_eq!(sorted_names(&results), ["at"]);
    }

    #[test]
    fn test_query_and() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::and(vec![
            PredicateQuery::by_arity(2),
            PredicateQuery::by_domain("Location"),
        ]);
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["at"]);
    }

    #[test]
    fn test_query_and_with_empty_branch() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::and(vec![
            PredicateQuery::by_name("missing"),
            PredicateQuery::by_arity(2),
        ]);
        assert!(planner.execute(&query).unwrap().is_empty());
    }

    #[test]
    fn test_query_or() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::or(vec![
            PredicateQuery::by_name("knows"),
            PredicateQuery::by_name("at"),
        ]);
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["at", "knows"]);
    }

    #[test]
    fn test_query_or_deduplicates() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::or(vec![
            PredicateQuery::by_name("knows"),
            PredicateQuery::by_domain("Person"),
        ]);
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["at", "friends", "knows"]);
    }

    #[test]
    fn test_empty_boolean_queries() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        assert!(planner.execute(&PredicateQuery::and(vec![])).unwrap().is_empty());
        assert!(planner.execute(&PredicateQuery::or(vec![])).unwrap().is_empty());
    }

    #[test]
    fn test_predicate_pattern() {
        let pattern = PredicatePattern::new()
            .with_name_pattern("know*")
            .with_arity_range(2, 3)
            .with_required_domain("Person");

        let knows = PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]);
        assert!(pattern.matches("knows", &knows));

        let at = PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]);
        assert!(!pattern.matches("at", &at));
    }

    #[test]
    fn test_pattern_excluded_domain_and_arity_bounds() {
        let knows = PredicateInfo::new("knows", vec!["Person".to_string(), "Person".to_string()]);
        let at = PredicateInfo::new("at", vec!["Person".to_string(), "Location".to_string()]);

        let no_location = PredicatePattern::new().with_excluded_domain("Location");
        assert!(no_location.matches("knows", &knows));
        assert!(!no_location.matches("at", &at));

        let ternary_only = PredicatePattern::new().with_arity_range(3, 3);
        assert!(!ternary_only.matches("knows", &knows));

        assert!(PredicatePattern::default().matches("at", &at));
    }

    #[test]
    fn test_query_by_pattern() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let pattern = PredicatePattern::new()
            .with_name_pattern("*friend*")
            .with_required_domain("Person");

        let query = PredicateQuery::by_pattern(pattern);
        let results = planner.execute(&query).unwrap();

        assert_eq!(sorted_names(&results), ["friends"]);
    }

    #[test]
    fn test_wildcard_matching() {
        assert!(matches_wildcard("hello", "h*"));
        assert!(matches_wildcard("hello", "he??o"));
        assert!(matches_wildcard("hello", "*"));
        assert!(matches_wildcard("hello", "hello"));
        assert!(!matches_wildcard("hello", "h*x"));
        assert!(matches_wildcard("test123", "test*"));
    }

    #[test]
    fn test_wildcard_edge_cases() {
        assert!(matches_wildcard("", ""));
        assert!(matches_wildcard("", "*"));
        assert!(matches_wildcard("", "**"));
        assert!(!matches_wildcard("", "?"));
        assert!(!matches_wildcard("abc", ""));
        assert!(matches_wildcard("abc", "*b*"));
        assert!(!matches_wildcard("abc", "*d*"));
        assert!(matches_wildcard("abc", "a?c"));
        assert!(!matches_wildcard("ac", "a?c"));
        // `?` consumes one `char`, not one byte.
        assert!(matches_wildcard("héllo", "h?llo"));
    }

    #[test]
    fn test_statistics() {
        let mut stats = QueryStatistics::new();

        stats.record_query("ByName", Duration::from_millis(10), 1, 100);
        stats.record_query("ByName", Duration::from_millis(20), 1, 100);
        stats.record_query("ByArity", Duration::from_millis(50), 10, 100);

        assert_eq!(stats.total_queries, 3);
        assert_eq!(stats.get_selectivity("ByName"), 0.01);
        assert_eq!(stats.get_selectivity("ByArity"), 0.1);
        assert_eq!(stats.get_avg_time("ByName"), Duration::from_millis(15));

        let top = stats.top_queries(2);
        assert_eq!(top[0].0, "ByName");
        assert_eq!(top[0].1, 2);
    }

    #[test]
    fn test_planner_records_statistics() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        planner.execute(&PredicateQuery::by_name("knows")).unwrap();
        planner.execute(&PredicateQuery::by_name("at")).unwrap();
        planner.execute(&PredicateQuery::by_arity(2)).unwrap();

        let stats = planner.statistics();
        assert_eq!(stats.total_queries, 3);
        let expected: Vec<(String, usize)> = vec![("ByName".to_string(), 2)];
        assert_eq!(stats.top_queries(1), expected);
        // All three predicates are binary, so the arity query returned 3/3.
        assert_eq!(stats.get_selectivity("ByArity"), 1.0);
    }

    #[test]
    fn test_plan_caching() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let query = PredicateQuery::by_name("knows");

        planner.plan(&query).unwrap();
        assert_eq!(planner.cache_size(), 1);

        planner.plan(&query).unwrap();
        assert_eq!(planner.cache_size(), 1); // Should reuse cached plan

        planner.clear_cache();
        assert_eq!(planner.cache_size(), 0);
    }

    #[test]
    fn test_plan_strategy_selection() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);

        let plan = planner.plan(&PredicateQuery::by_name("knows")).unwrap();
        assert_eq!(plan.strategy(), &IndexStrategy::NameHash);
        assert_eq!(plan.estimated_cost(), 1.0);

        let and = PredicateQuery::and(vec![
            PredicateQuery::by_arity(2),
            PredicateQuery::by_name("at"),
        ]);
        let plan = planner.plan(&and).unwrap();
        assert_eq!(plan.query(), &and);
        assert_eq!(
            plan.strategy(),
            &IndexStrategy::Composite(vec![IndexStrategy::ArityRange, IndexStrategy::NameHash])
        );

        let pattern = PredicateQuery::by_pattern(PredicatePattern::new());
        assert_eq!(planner.plan(&pattern).unwrap().strategy(), &IndexStrategy::FullScan);
    }

    #[test]
    fn test_index_strategy_cost() {
        let stats = QueryStatistics::new();

        let full_scan = IndexStrategy::FullScan;
        let hash = IndexStrategy::NameHash;

        assert_eq!(full_scan.estimate_cost(1000, &stats), 1000.0);
        assert_eq!(hash.estimate_cost(1000, &stats), 1.0);
    }

    #[test]
    fn test_composite_cost_is_cheapest_component() {
        let stats = QueryStatistics::new();

        let composite =
            IndexStrategy::Composite(vec![IndexStrategy::FullScan, IndexStrategy::NameHash]);
        assert_eq!(composite.estimate_cost(1000, &stats), 1.0);
        assert_eq!(IndexStrategy::Composite(Vec::new()).estimate_cost(42, &stats), 42.0);
    }
}