1use anyhow::Result;
43use std::collections::{HashMap, HashSet};
44use std::time::{Duration, Instant};
45
46use crate::{PredicateInfo, SymbolTable};
47
48#[derive(Clone, Debug, PartialEq, Eq, Hash)]
50pub enum PredicateQuery {
51 ByName(String),
53 ByArity(usize),
55 BySignature(Vec<String>),
57 ByDomain(String),
59 ByPattern(PredicatePattern),
61 And(Vec<PredicateQuery>),
63 Or(Vec<PredicateQuery>),
65}
66
67impl PredicateQuery {
68 pub fn by_name(name: impl Into<String>) -> Self {
69 Self::ByName(name.into())
70 }
71
72 pub fn by_arity(arity: usize) -> Self {
73 Self::ByArity(arity)
74 }
75
76 pub fn by_signature(domains: Vec<String>) -> Self {
77 Self::BySignature(domains)
78 }
79
80 pub fn by_domain(domain: impl Into<String>) -> Self {
81 Self::ByDomain(domain.into())
82 }
83
84 pub fn by_pattern(pattern: PredicatePattern) -> Self {
85 Self::ByPattern(pattern)
86 }
87
88 pub fn and(queries: Vec<PredicateQuery>) -> Self {
89 Self::And(queries)
90 }
91
92 pub fn or(queries: Vec<PredicateQuery>) -> Self {
93 Self::Or(queries)
94 }
95}
96
97#[derive(Clone, Debug, PartialEq, Eq, Hash)]
99pub struct PredicatePattern {
100 pub name_pattern: Option<String>,
102 pub min_arity: Option<usize>,
104 pub max_arity: Option<usize>,
106 pub required_domains: Vec<String>,
108 pub excluded_domains: Vec<String>,
110}
111
112impl PredicatePattern {
113 pub fn new() -> Self {
114 Self {
115 name_pattern: None,
116 min_arity: None,
117 max_arity: None,
118 required_domains: Vec::new(),
119 excluded_domains: Vec::new(),
120 }
121 }
122
123 pub fn with_name_pattern(mut self, pattern: impl Into<String>) -> Self {
124 self.name_pattern = Some(pattern.into());
125 self
126 }
127
128 pub fn with_arity_range(mut self, min: usize, max: usize) -> Self {
129 self.min_arity = Some(min);
130 self.max_arity = Some(max);
131 self
132 }
133
134 pub fn with_required_domain(mut self, domain: impl Into<String>) -> Self {
135 self.required_domains.push(domain.into());
136 self
137 }
138
139 pub fn with_excluded_domain(mut self, domain: impl Into<String>) -> Self {
140 self.excluded_domains.push(domain.into());
141 self
142 }
143
144 pub fn matches(&self, name: &str, predicate: &PredicateInfo) -> bool {
146 if let Some(pattern) = &self.name_pattern {
148 if !matches_wildcard(name, pattern) {
149 return false;
150 }
151 }
152
153 let arity = predicate.arg_domains.len();
155 if let Some(min) = self.min_arity {
156 if arity < min {
157 return false;
158 }
159 }
160 if let Some(max) = self.max_arity {
161 if arity > max {
162 return false;
163 }
164 }
165
166 let domain_set: HashSet<_> = predicate.arg_domains.iter().collect();
168 for required in &self.required_domains {
169 if !domain_set.contains(required) {
170 return false;
171 }
172 }
173
174 for excluded in &self.excluded_domains {
176 if domain_set.contains(excluded) {
177 return false;
178 }
179 }
180
181 true
182 }
183}
184
185impl Default for PredicatePattern {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Glob-style match of `text` against `pattern`.
///
/// In `pattern`, `*` stands for any run of characters (including none) and
/// `?` for exactly one character; everything else must match literally. The
/// whole of `text` has to be consumed. Comparison is per `char`.
fn matches_wildcard(text: &str, pattern: &str) -> bool {
    let glob: Vec<char> = pattern.chars().collect();

    // `row[j]` answers: does the text consumed so far match `glob[..j]`?
    // It starts out describing the empty text, where only leading `*`s match.
    let mut row = vec![false; glob.len() + 1];
    row[0] = true;
    for (j, &g) in glob.iter().enumerate() {
        row[j + 1] = g == '*' && row[j];
    }

    for ch in text.chars() {
        // `diag` carries the previous row's value one column to the left.
        let mut diag = row[0];
        // Non-empty text never matches the empty pattern prefix.
        row[0] = false;
        for (j, &g) in glob.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if g == '*' {
                // `*` either swallows `ch` (above) or matches nothing (left).
                above || row[j]
            } else {
                (g == '?' || g == ch) && diag
            };
            diag = above;
        }
    }

    row[glob.len()]
}
218
/// Execution statistics grouped by a caller-supplied query-type label.
#[derive(Clone, Debug, Default)]
pub struct QueryStatistics {
    /// Number of recorded executions per query type.
    query_counts: HashMap<String, usize>,
    /// Selectivity (results / total predicates) of the most recent execution
    /// of each query type.
    selectivity: HashMap<String, f64>,
    /// Running mean execution time per query type.
    avg_execution_time: HashMap<String, Duration>,
    /// Number of recorded executions across all query types.
    total_queries: usize,
}

impl QueryStatistics {
    /// Creates an empty statistics collector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one execution of a query of type `query_type`.
    ///
    /// * `duration` – how long the execution took; folded into the running mean.
    /// * `result_count` / `total_predicates` – used to compute the selectivity,
    ///   which replaces any previously stored value for this query type. A
    ///   `total_predicates` of zero yields a selectivity of `0.0`.
    ///
    /// The mean is updated incrementally in nanoseconds, so it cannot overflow
    /// or panic regardless of how many samples or how large the durations are.
    pub fn record_query(
        &mut self,
        query_type: impl Into<String>,
        duration: Duration,
        result_count: usize,
        total_predicates: usize,
    ) {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let query_type = query_type.into();

        let count = {
            let slot = self.query_counts.entry(query_type.clone()).or_insert(0);
            *slot += 1;
            *slot
        };
        self.total_queries += 1;

        let selectivity = if total_predicates > 0 {
            result_count as f64 / total_predicates as f64
        } else {
            0.0
        };
        self.selectivity.insert(query_type.clone(), selectivity);

        // mean_n = mean_{n-1} + (x - mean_{n-1}) / n, rounded down to whole
        // nanoseconds. Working on the difference keeps every intermediate
        // value no larger than the inputs.
        let samples = count as u128;
        let previous = self
            .avg_execution_time
            .get(&query_type)
            .map_or(0, Duration::as_nanos);
        let latest = duration.as_nanos();
        let mean = if latest >= previous {
            previous + (latest - previous) / samples
        } else {
            // Subtract the ceiling so the result is still rounded down.
            previous - (previous - latest + samples - 1) / samples
        };

        // `mean` lies between `previous` and `latest`, so it fits a `Duration`.
        let mean = Duration::new((mean / NANOS_PER_SEC) as u64, (mean % NANOS_PER_SEC) as u32);
        self.avg_execution_time.insert(query_type, mean);
    }

    /// Returns up to `limit` query types with their execution counts, most
    /// frequent first; equal counts are ordered by query-type name so the
    /// result is deterministic.
    pub fn top_queries(&self, limit: usize) -> Vec<(String, usize)> {
        let mut queries: Vec<(String, usize)> = self
            .query_counts
            .iter()
            .map(|(name, count)| (name.clone(), *count))
            .collect();
        queries.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        queries.truncate(limit);
        queries
    }

    /// Returns the last recorded selectivity for `query_type`, or `1.0` when
    /// nothing has been recorded for it.
    pub fn get_selectivity(&self, query_type: &str) -> f64 {
        self.selectivity.get(query_type).copied().unwrap_or(1.0)
    }

    /// Returns the mean execution time for `query_type`, or [`Duration::ZERO`]
    /// when nothing has been recorded for it.
    pub fn get_avg_time(&self, query_type: &str) -> Duration {
        self.avg_execution_time
            .get(query_type)
            .copied()
            .unwrap_or(Duration::ZERO)
    }

    /// Returns the number of executions recorded across all query types.
    pub fn total_queries(&self) -> usize {
        self.total_queries
    }
}
305
306#[derive(Clone, Debug, PartialEq, Eq)]
308pub enum IndexStrategy {
309 FullScan,
311 NameHash,
313 ArityRange,
315 SignatureHash,
317 DomainInverted,
319 Composite(Vec<IndexStrategy>),
321}
322
323impl IndexStrategy {
324 pub fn estimate_cost(&self, predicates_count: usize, _stats: &QueryStatistics) -> f64 {
326 match self {
327 IndexStrategy::FullScan => predicates_count as f64,
328 IndexStrategy::NameHash => 1.0, IndexStrategy::ArityRange => (predicates_count as f64).sqrt(), IndexStrategy::SignatureHash => 1.0, IndexStrategy::DomainInverted => (predicates_count as f64).log2(), IndexStrategy::Composite(strategies) => {
333 strategies
335 .iter()
336 .map(|s| s.estimate_cost(predicates_count, _stats))
337 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
338 .unwrap_or(predicates_count as f64)
339 }
340 }
341 }
342}
343
344#[derive(Clone, Debug)]
346pub struct QueryPlan {
347 query: PredicateQuery,
348 strategy: IndexStrategy,
349 estimated_cost: f64,
350 estimated_results: usize,
351}
352
353impl QueryPlan {
354 pub fn new(query: PredicateQuery, strategy: IndexStrategy) -> Self {
355 Self {
356 query,
357 strategy,
358 estimated_cost: 0.0,
359 estimated_results: 0,
360 }
361 }
362
363 pub fn with_cost(mut self, cost: f64) -> Self {
364 self.estimated_cost = cost;
365 self
366 }
367
368 pub fn with_estimated_results(mut self, count: usize) -> Self {
369 self.estimated_results = count;
370 self
371 }
372
373 pub fn execute(&self, table: &SymbolTable) -> Result<Vec<(String, PredicateInfo)>> {
375 match &self.query {
376 PredicateQuery::ByName(name) => {
377 if let Some(pred) = table.predicates.get(name) {
378 Ok(vec![(name.clone(), pred.clone())])
379 } else {
380 Ok(Vec::new())
381 }
382 }
383 PredicateQuery::ByArity(arity) => {
384 let results: Vec<_> = table
385 .predicates
386 .iter()
387 .filter(|(_, pred)| pred.arg_domains.len() == *arity)
388 .map(|(name, pred)| (name.clone(), pred.clone()))
389 .collect();
390 Ok(results)
391 }
392 PredicateQuery::BySignature(signature) => {
393 let results: Vec<_> = table
394 .predicates
395 .iter()
396 .filter(|(_, pred)| pred.arg_domains == *signature)
397 .map(|(name, pred)| (name.clone(), pred.clone()))
398 .collect();
399 Ok(results)
400 }
401 PredicateQuery::ByDomain(domain) => {
402 let results: Vec<_> = table
403 .predicates
404 .iter()
405 .filter(|(_, pred)| pred.arg_domains.contains(domain))
406 .map(|(name, pred)| (name.clone(), pred.clone()))
407 .collect();
408 Ok(results)
409 }
410 PredicateQuery::ByPattern(pattern) => {
411 let results: Vec<_> = table
412 .predicates
413 .iter()
414 .filter(|(name, pred)| pattern.matches(name, pred))
415 .map(|(name, pred)| (name.clone(), pred.clone()))
416 .collect();
417 Ok(results)
418 }
419 PredicateQuery::And(queries) => {
420 if queries.is_empty() {
421 return Ok(Vec::new());
422 }
423
424 let mut results: HashSet<String> = self
426 .execute_subquery(&queries[0], table)?
427 .into_iter()
428 .map(|(name, _)| name)
429 .collect();
430
431 for query in &queries[1..] {
433 let subresults: HashSet<String> = self
434 .execute_subquery(query, table)?
435 .into_iter()
436 .map(|(name, _)| name)
437 .collect();
438 results.retain(|name| subresults.contains(name));
439 }
440
441 Ok(results
442 .into_iter()
443 .filter_map(|name| {
444 table
445 .predicates
446 .get(&name)
447 .map(|pred| (name.clone(), pred.clone()))
448 })
449 .collect())
450 }
451 PredicateQuery::Or(queries) => {
452 let mut results_map: HashMap<String, PredicateInfo> = HashMap::new();
453
454 for query in queries {
455 let subresults = self.execute_subquery(query, table)?;
456 for (name, pred) in subresults {
457 results_map.insert(name, pred);
458 }
459 }
460
461 Ok(results_map.into_iter().collect())
462 }
463 }
464 }
465
466 fn execute_subquery(
467 &self,
468 query: &PredicateQuery,
469 table: &SymbolTable,
470 ) -> Result<Vec<(String, PredicateInfo)>> {
471 let subplan = QueryPlan::new(query.clone(), self.strategy.clone());
472 subplan.execute(table)
473 }
474
475 pub fn query(&self) -> &PredicateQuery {
476 &self.query
477 }
478
479 pub fn strategy(&self) -> &IndexStrategy {
480 &self.strategy
481 }
482
483 pub fn estimated_cost(&self) -> f64 {
484 self.estimated_cost
485 }
486}
487
488pub struct QueryPlanner<'a> {
490 table: &'a SymbolTable,
491 statistics: QueryStatistics,
492 plan_cache: HashMap<PredicateQuery, QueryPlan>,
493}
494
495impl<'a> QueryPlanner<'a> {
496 pub fn new(table: &'a SymbolTable) -> Self {
497 Self {
498 table,
499 statistics: QueryStatistics::new(),
500 plan_cache: HashMap::new(),
501 }
502 }
503
504 pub fn with_statistics(mut self, statistics: QueryStatistics) -> Self {
505 self.statistics = statistics;
506 self
507 }
508
509 pub fn plan(&mut self, query: &PredicateQuery) -> Result<QueryPlan> {
511 if let Some(cached) = self.plan_cache.get(query) {
513 return Ok(cached.clone());
514 }
515
516 let plan = self.generate_plan(query)?;
517 self.plan_cache.insert(query.clone(), plan.clone());
518 Ok(plan)
519 }
520
521 fn generate_plan(&self, query: &PredicateQuery) -> Result<QueryPlan> {
523 let strategy = self.select_strategy(query);
524 let cost = strategy.estimate_cost(self.table.predicates.len(), &self.statistics);
525
526 let plan = QueryPlan::new(query.clone(), strategy).with_cost(cost);
527
528 Ok(plan)
529 }
530
531 fn select_strategy(&self, query: &PredicateQuery) -> IndexStrategy {
533 Self::select_strategy_static(query)
534 }
535
536 fn select_strategy_static(query: &PredicateQuery) -> IndexStrategy {
538 match query {
539 PredicateQuery::ByName(_) => IndexStrategy::NameHash,
540 PredicateQuery::ByArity(_) => IndexStrategy::ArityRange,
541 PredicateQuery::BySignature(_) => IndexStrategy::SignatureHash,
542 PredicateQuery::ByDomain(_) => IndexStrategy::DomainInverted,
543 PredicateQuery::ByPattern(_) => {
544 IndexStrategy::FullScan
546 }
547 PredicateQuery::And(queries) => {
548 let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
550 IndexStrategy::Composite(strategies)
551 }
552 PredicateQuery::Or(queries) => {
553 let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
555 IndexStrategy::Composite(strategies)
556 }
557 }
558 }
559
560 pub fn execute(&mut self, query: &PredicateQuery) -> Result<Vec<(String, PredicateInfo)>> {
562 let start = Instant::now();
563 let plan = self.plan(query)?;
564 let results = plan.execute(self.table)?;
565 let duration = start.elapsed();
566
567 let query_type = format!("{:?}", query)
568 .split('(')
569 .next()
570 .unwrap_or("Unknown")
571 .to_string();
572 self.statistics.record_query(
573 query_type,
574 duration,
575 results.len(),
576 self.table.predicates.len(),
577 );
578
579 Ok(results)
580 }
581
582 pub fn statistics(&self) -> &QueryStatistics {
584 &self.statistics
585 }
586
587 pub fn clear_cache(&mut self) {
589 self.plan_cache.clear();
590 }
591
592 pub fn cache_size(&self) -> usize {
594 self.plan_cache.len()
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Builds a two-argument predicate over the given domains.
    fn binary(name: &str, first: &str, second: &str) -> PredicateInfo {
        PredicateInfo::new(name, vec![first.to_string(), second.to_string()])
    }

    /// Table with domains `Person` and `Location` and the predicates
    /// `knows(Person, Person)`, `at(Person, Location)`, `friends(Person, Person)`.
    fn setup_table() -> SymbolTable {
        let mut table = SymbolTable::new();

        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        table
            .add_domain(DomainInfo::new("Location", 50))
            .expect("unwrap");

        let predicates = vec![
            binary("knows", "Person", "Person"),
            binary("at", "Person", "Location"),
            binary("friends", "Person", "Person"),
        ];
        for predicate in predicates {
            table.add_predicate(predicate).expect("unwrap");
        }

        table
    }

    /// Executes `query` through a fresh planner over the fixture table.
    fn run_query(query: &PredicateQuery) -> Vec<(String, PredicateInfo)> {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);
        planner.execute(query).expect("unwrap")
    }

    #[test]
    fn test_query_by_name() {
        let found = run_query(&PredicateQuery::by_name("knows"));

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, "knows");
    }

    #[test]
    fn test_query_by_arity() {
        let found = run_query(&PredicateQuery::by_arity(2));

        // All three fixture predicates are binary.
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_query_by_signature() {
        let signature = vec!["Person".to_string(), "Person".to_string()];
        let found = run_query(&PredicateQuery::by_signature(signature));

        // `knows` and `friends`.
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_query_by_domain() {
        let found = run_query(&PredicateQuery::by_domain("Location"));

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, "at");
    }

    #[test]
    fn test_query_and() {
        let both = PredicateQuery::and(vec![
            PredicateQuery::by_arity(2),
            PredicateQuery::by_domain("Location"),
        ]);
        let found = run_query(&both);

        // Only `at` is binary and uses `Location`.
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_query_or() {
        let either = PredicateQuery::or(vec![
            PredicateQuery::by_name("knows"),
            PredicateQuery::by_name("at"),
        ]);
        let found = run_query(&either);

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_predicate_pattern() {
        let pattern = PredicatePattern::new()
            .with_name_pattern("know*")
            .with_arity_range(2, 3)
            .with_required_domain("Person");

        let knows = binary("knows", "Person", "Person");
        assert!(pattern.matches("knows", &knows));

        let at = binary("at", "Person", "Location");
        assert!(!pattern.matches("at", &at));
    }

    #[test]
    fn test_query_by_pattern() {
        let pattern = PredicatePattern::new()
            .with_name_pattern("*friend*")
            .with_required_domain("Person");
        let found = run_query(&PredicateQuery::by_pattern(pattern));

        // Only `friends` matches the glob.
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_wildcard_matching() {
        let should_match = [
            ("hello", "h*"),
            ("hello", "he??o"),
            ("hello", "*"),
            ("hello", "hello"),
            ("test123", "test*"),
        ];
        for (text, pattern) in should_match.iter() {
            assert!(matches_wildcard(text, pattern));
        }

        assert!(!matches_wildcard("hello", "h*x"));
    }

    #[test]
    fn test_statistics() {
        let mut stats = QueryStatistics::new();

        stats.record_query("ByName", Duration::from_millis(10), 1, 100);
        stats.record_query("ByName", Duration::from_millis(20), 1, 100);
        stats.record_query("ByArity", Duration::from_millis(50), 10, 100);

        assert_eq!(stats.total_queries, 3);
        assert_eq!(stats.get_selectivity("ByName"), 0.01);
        assert_eq!(stats.get_selectivity("ByArity"), 0.1);

        let top = stats.top_queries(2);
        let (most_frequent, executions) = &top[0];
        assert_eq!(most_frequent, "ByName");
        assert_eq!(*executions, 2);
    }

    #[test]
    fn test_plan_caching() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);
        let query = PredicateQuery::by_name("knows");

        planner.plan(&query).expect("unwrap");
        assert_eq!(planner.cache_size(), 1);

        // Planning the same query again must reuse the cached entry.
        planner.plan(&query).expect("unwrap");
        assert_eq!(planner.cache_size(), 1);

        planner.clear_cache();
        assert_eq!(planner.cache_size(), 0);
    }

    #[test]
    fn test_index_strategy_cost() {
        let stats = QueryStatistics::new();

        assert_eq!(IndexStrategy::FullScan.estimate_cost(1000, &stats), 1000.0);
        assert_eq!(IndexStrategy::NameHash.estimate_cost(1000, &stats), 1.0);
    }
}