1use anyhow::Result;
43use std::collections::{HashMap, HashSet};
44use std::time::{Duration, Instant};
45
46use crate::{PredicateInfo, SymbolTable};
47
48#[derive(Clone, Debug, PartialEq, Eq, Hash)]
50pub enum PredicateQuery {
51 ByName(String),
53 ByArity(usize),
55 BySignature(Vec<String>),
57 ByDomain(String),
59 ByPattern(PredicatePattern),
61 And(Vec<PredicateQuery>),
63 Or(Vec<PredicateQuery>),
65}
66
67impl PredicateQuery {
68 pub fn by_name(name: impl Into<String>) -> Self {
69 Self::ByName(name.into())
70 }
71
72 pub fn by_arity(arity: usize) -> Self {
73 Self::ByArity(arity)
74 }
75
76 pub fn by_signature(domains: Vec<String>) -> Self {
77 Self::BySignature(domains)
78 }
79
80 pub fn by_domain(domain: impl Into<String>) -> Self {
81 Self::ByDomain(domain.into())
82 }
83
84 pub fn by_pattern(pattern: PredicatePattern) -> Self {
85 Self::ByPattern(pattern)
86 }
87
88 pub fn and(queries: Vec<PredicateQuery>) -> Self {
89 Self::And(queries)
90 }
91
92 pub fn or(queries: Vec<PredicateQuery>) -> Self {
93 Self::Or(queries)
94 }
95}
96
/// A set of optional constraints a predicate must satisfy to be selected.
///
/// Every constraint that is set must hold; unset constraints are ignored.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PredicatePattern {
    /// Wildcard pattern for the predicate name (`*` = any run, `?` = any one character).
    pub name_pattern: Option<String>,
    /// Smallest accepted number of arguments (inclusive).
    pub min_arity: Option<usize>,
    /// Largest accepted number of arguments (inclusive).
    pub max_arity: Option<usize>,
    /// Domains that must all appear among the predicate's argument domains.
    pub required_domains: Vec<String>,
    /// Domains that must not appear among the predicate's argument domains.
    pub excluded_domains: Vec<String>,
}
111
112impl PredicatePattern {
113 pub fn new() -> Self {
114 Self {
115 name_pattern: None,
116 min_arity: None,
117 max_arity: None,
118 required_domains: Vec::new(),
119 excluded_domains: Vec::new(),
120 }
121 }
122
123 pub fn with_name_pattern(mut self, pattern: impl Into<String>) -> Self {
124 self.name_pattern = Some(pattern.into());
125 self
126 }
127
128 pub fn with_arity_range(mut self, min: usize, max: usize) -> Self {
129 self.min_arity = Some(min);
130 self.max_arity = Some(max);
131 self
132 }
133
134 pub fn with_required_domain(mut self, domain: impl Into<String>) -> Self {
135 self.required_domains.push(domain.into());
136 self
137 }
138
139 pub fn with_excluded_domain(mut self, domain: impl Into<String>) -> Self {
140 self.excluded_domains.push(domain.into());
141 self
142 }
143
144 pub fn matches(&self, name: &str, predicate: &PredicateInfo) -> bool {
146 if let Some(pattern) = &self.name_pattern {
148 if !matches_wildcard(name, pattern) {
149 return false;
150 }
151 }
152
153 let arity = predicate.arg_domains.len();
155 if let Some(min) = self.min_arity {
156 if arity < min {
157 return false;
158 }
159 }
160 if let Some(max) = self.max_arity {
161 if arity > max {
162 return false;
163 }
164 }
165
166 let domain_set: HashSet<_> = predicate.arg_domains.iter().collect();
168 for required in &self.required_domains {
169 if !domain_set.contains(required) {
170 return false;
171 }
172 }
173
174 for excluded in &self.excluded_domains {
176 if domain_set.contains(excluded) {
177 return false;
178 }
179 }
180
181 true
182 }
183}
184
185impl Default for PredicatePattern {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Tests whether the whole of `text` matches the wildcard `pattern`.
///
/// `*` matches any run of characters (including none), `?` matches exactly one
/// character, and every other character matches only itself. Matching is done
/// per `char`, not per byte.
fn matches_wildcard(text: &str, pattern: &str) -> bool {
    let glob: Vec<char> = pattern.chars().collect();

    // reachable[j] is true when the text consumed so far matches glob[..j].
    // Before any text is consumed only a leading run of `*` can match.
    let mut reachable = vec![false; glob.len() + 1];
    reachable[0] = true;
    for (j, &g) in glob.iter().enumerate() {
        reachable[j + 1] = g == '*' && reachable[j];
    }

    for ch in text.chars() {
        // A non-empty text prefix never matches the empty pattern, so next[0] stays false.
        let mut next = vec![false; glob.len() + 1];
        for (j, &g) in glob.iter().enumerate() {
            next[j + 1] = match g {
                // `*` either absorbs `ch` or matches the empty run.
                '*' => reachable[j + 1] || next[j],
                '?' => reachable[j],
                literal => literal == ch && reachable[j],
            };
        }
        reachable = next;
    }

    reachable[glob.len()]
}
218
/// Running statistics about executed queries, keyed by a query-type label.
#[derive(Clone, Debug)]
pub struct QueryStatistics {
    /// How many times each query type has been recorded.
    query_counts: HashMap<String, usize>,
    /// Selectivity (results / total predicates) of the most recent query of each type.
    selectivity: HashMap<String, f64>,
    /// Running mean execution time of each query type.
    avg_execution_time: HashMap<String, Duration>,
    /// Number of recorded queries across all types.
    total_queries: usize,
}

impl QueryStatistics {
    /// Creates an empty statistics collector.
    pub fn new() -> Self {
        Self {
            query_counts: HashMap::new(),
            selectivity: HashMap::new(),
            avg_execution_time: HashMap::new(),
            total_queries: 0,
        }
    }

    /// Records one execution of a query of type `query_type`.
    ///
    /// * `duration` - how long the query took; folded into the running mean.
    /// * `result_count` / `total_predicates` - used to compute the selectivity,
    ///   which replaces any previously stored value for this query type. A
    ///   `total_predicates` of zero yields a selectivity of `0.0`.
    pub fn record_query(
        &mut self,
        query_type: impl Into<String>,
        duration: Duration,
        result_count: usize,
        total_predicates: usize,
    ) {
        let query_type = query_type.into();

        // Bump the counter and keep the new value without a second lookup.
        let count = {
            let slot = self.query_counts.entry(query_type.clone()).or_insert(0);
            *slot += 1;
            *slot
        };
        self.total_queries += 1;

        let selectivity = if total_predicates > 0 {
            result_count as f64 / total_predicates as f64
        } else {
            0.0
        };
        self.selectivity.insert(query_type.clone(), selectivity);

        let previous_avg = self
            .avg_execution_time
            .get(&query_type)
            .copied()
            .unwrap_or(Duration::ZERO);
        let new_avg = Self::updated_mean(previous_avg, duration, count);
        self.avg_execution_time.insert(query_type, new_avg);
    }

    /// Folds `sample` into a running mean that previously covered `count - 1` samples.
    ///
    /// The arithmetic is done on `u128` nanoseconds so that `count` is never
    /// narrowed to `u32` and the intermediate product cannot panic on overflow.
    fn updated_mean(previous: Duration, sample: Duration, count: usize) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        // `count` is at least 1 when called from `record_query`; guard the division anyway.
        let samples = count.max(1) as u128;
        let total_nanos = previous
            .as_nanos()
            .saturating_mul(samples - 1)
            .saturating_add(sample.as_nanos());
        let mean_nanos = total_nanos / samples;

        let secs = (mean_nanos / NANOS_PER_SEC).min(u64::MAX as u128) as u64;
        // The remainder is always below one billion, so it fits in a u32.
        let subsec_nanos = (mean_nanos % NANOS_PER_SEC) as u32;
        Duration::new(secs, subsec_nanos)
    }

    /// Returns up to `limit` query types ordered by descending count.
    ///
    /// Types with equal counts are ordered by name so the result is deterministic.
    pub fn top_queries(&self, limit: usize) -> Vec<(String, usize)> {
        let mut queries: Vec<_> = self
            .query_counts
            .iter()
            .map(|(name, count)| (name.clone(), *count))
            .collect();
        queries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        queries.truncate(limit);
        queries
    }

    /// Returns the last recorded selectivity for `query_type`, or `1.0` if none was recorded.
    pub fn get_selectivity(&self, query_type: &str) -> f64 {
        self.selectivity.get(query_type).copied().unwrap_or(1.0)
    }

    /// Returns the mean execution time for `query_type`, or zero if none was recorded.
    pub fn get_avg_time(&self, query_type: &str) -> Duration {
        self.avg_execution_time
            .get(query_type)
            .copied()
            .unwrap_or(Duration::ZERO)
    }
}

impl Default for QueryStatistics {
    fn default() -> Self {
        Self::new()
    }
}
305
/// Access strategy a query plan is labelled with; used for cost estimation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum IndexStrategy {
    /// Visit every predicate in the table.
    FullScan,
    /// Direct lookup by predicate name.
    NameHash,
    /// Lookup restricted by arity.
    ArityRange,
    /// Direct lookup by full argument signature.
    SignatureHash,
    /// Lookup through a domain-to-predicates index.
    DomainInverted,
    /// Combination of the strategies chosen for the parts of a compound query.
    Composite(Vec<IndexStrategy>),
}
322
323impl IndexStrategy {
324 pub fn estimate_cost(&self, predicates_count: usize, _stats: &QueryStatistics) -> f64 {
326 match self {
327 IndexStrategy::FullScan => predicates_count as f64,
328 IndexStrategy::NameHash => 1.0, IndexStrategy::ArityRange => (predicates_count as f64).sqrt(), IndexStrategy::SignatureHash => 1.0, IndexStrategy::DomainInverted => (predicates_count as f64).log2(), IndexStrategy::Composite(strategies) => {
333 strategies
335 .iter()
336 .map(|s| s.estimate_cost(predicates_count, _stats))
337 .min_by(|a, b| a.partial_cmp(b).unwrap())
338 .unwrap_or(predicates_count as f64)
339 }
340 }
341 }
342}
343
344#[derive(Clone, Debug)]
346pub struct QueryPlan {
347 query: PredicateQuery,
348 strategy: IndexStrategy,
349 estimated_cost: f64,
350 estimated_results: usize,
351}
352
353impl QueryPlan {
354 pub fn new(query: PredicateQuery, strategy: IndexStrategy) -> Self {
355 Self {
356 query,
357 strategy,
358 estimated_cost: 0.0,
359 estimated_results: 0,
360 }
361 }
362
363 pub fn with_cost(mut self, cost: f64) -> Self {
364 self.estimated_cost = cost;
365 self
366 }
367
368 pub fn with_estimated_results(mut self, count: usize) -> Self {
369 self.estimated_results = count;
370 self
371 }
372
373 pub fn execute(&self, table: &SymbolTable) -> Result<Vec<(String, PredicateInfo)>> {
375 match &self.query {
376 PredicateQuery::ByName(name) => {
377 if let Some(pred) = table.predicates.get(name) {
378 Ok(vec![(name.clone(), pred.clone())])
379 } else {
380 Ok(Vec::new())
381 }
382 }
383 PredicateQuery::ByArity(arity) => {
384 let results: Vec<_> = table
385 .predicates
386 .iter()
387 .filter(|(_, pred)| pred.arg_domains.len() == *arity)
388 .map(|(name, pred)| (name.clone(), pred.clone()))
389 .collect();
390 Ok(results)
391 }
392 PredicateQuery::BySignature(signature) => {
393 let results: Vec<_> = table
394 .predicates
395 .iter()
396 .filter(|(_, pred)| pred.arg_domains == *signature)
397 .map(|(name, pred)| (name.clone(), pred.clone()))
398 .collect();
399 Ok(results)
400 }
401 PredicateQuery::ByDomain(domain) => {
402 let results: Vec<_> = table
403 .predicates
404 .iter()
405 .filter(|(_, pred)| pred.arg_domains.contains(domain))
406 .map(|(name, pred)| (name.clone(), pred.clone()))
407 .collect();
408 Ok(results)
409 }
410 PredicateQuery::ByPattern(pattern) => {
411 let results: Vec<_> = table
412 .predicates
413 .iter()
414 .filter(|(name, pred)| pattern.matches(name, pred))
415 .map(|(name, pred)| (name.clone(), pred.clone()))
416 .collect();
417 Ok(results)
418 }
419 PredicateQuery::And(queries) => {
420 if queries.is_empty() {
421 return Ok(Vec::new());
422 }
423
424 let mut results: HashSet<String> = self
426 .execute_subquery(&queries[0], table)?
427 .into_iter()
428 .map(|(name, _)| name)
429 .collect();
430
431 for query in &queries[1..] {
433 let subresults: HashSet<String> = self
434 .execute_subquery(query, table)?
435 .into_iter()
436 .map(|(name, _)| name)
437 .collect();
438 results.retain(|name| subresults.contains(name));
439 }
440
441 Ok(results
442 .into_iter()
443 .filter_map(|name| {
444 table
445 .predicates
446 .get(&name)
447 .map(|pred| (name.clone(), pred.clone()))
448 })
449 .collect())
450 }
451 PredicateQuery::Or(queries) => {
452 let mut results_map: HashMap<String, PredicateInfo> = HashMap::new();
453
454 for query in queries {
455 let subresults = self.execute_subquery(query, table)?;
456 for (name, pred) in subresults {
457 results_map.insert(name, pred);
458 }
459 }
460
461 Ok(results_map.into_iter().collect())
462 }
463 }
464 }
465
466 fn execute_subquery(
467 &self,
468 query: &PredicateQuery,
469 table: &SymbolTable,
470 ) -> Result<Vec<(String, PredicateInfo)>> {
471 let subplan = QueryPlan::new(query.clone(), self.strategy.clone());
472 subplan.execute(table)
473 }
474
475 pub fn query(&self) -> &PredicateQuery {
476 &self.query
477 }
478
479 pub fn strategy(&self) -> &IndexStrategy {
480 &self.strategy
481 }
482
483 pub fn estimated_cost(&self) -> f64 {
484 self.estimated_cost
485 }
486}
487
488pub struct QueryPlanner<'a> {
490 table: &'a SymbolTable,
491 statistics: QueryStatistics,
492 plan_cache: HashMap<PredicateQuery, QueryPlan>,
493}
494
495impl<'a> QueryPlanner<'a> {
496 pub fn new(table: &'a SymbolTable) -> Self {
497 Self {
498 table,
499 statistics: QueryStatistics::new(),
500 plan_cache: HashMap::new(),
501 }
502 }
503
504 pub fn with_statistics(mut self, statistics: QueryStatistics) -> Self {
505 self.statistics = statistics;
506 self
507 }
508
509 pub fn plan(&mut self, query: &PredicateQuery) -> Result<QueryPlan> {
511 if let Some(cached) = self.plan_cache.get(query) {
513 return Ok(cached.clone());
514 }
515
516 let plan = self.generate_plan(query)?;
517 self.plan_cache.insert(query.clone(), plan.clone());
518 Ok(plan)
519 }
520
521 fn generate_plan(&self, query: &PredicateQuery) -> Result<QueryPlan> {
523 let strategy = self.select_strategy(query);
524 let cost = strategy.estimate_cost(self.table.predicates.len(), &self.statistics);
525
526 let plan = QueryPlan::new(query.clone(), strategy).with_cost(cost);
527
528 Ok(plan)
529 }
530
531 fn select_strategy(&self, query: &PredicateQuery) -> IndexStrategy {
533 Self::select_strategy_static(query)
534 }
535
536 fn select_strategy_static(query: &PredicateQuery) -> IndexStrategy {
538 match query {
539 PredicateQuery::ByName(_) => IndexStrategy::NameHash,
540 PredicateQuery::ByArity(_) => IndexStrategy::ArityRange,
541 PredicateQuery::BySignature(_) => IndexStrategy::SignatureHash,
542 PredicateQuery::ByDomain(_) => IndexStrategy::DomainInverted,
543 PredicateQuery::ByPattern(_) => {
544 IndexStrategy::FullScan
546 }
547 PredicateQuery::And(queries) => {
548 let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
550 IndexStrategy::Composite(strategies)
551 }
552 PredicateQuery::Or(queries) => {
553 let strategies: Vec<_> = queries.iter().map(Self::select_strategy_static).collect();
555 IndexStrategy::Composite(strategies)
556 }
557 }
558 }
559
560 pub fn execute(&mut self, query: &PredicateQuery) -> Result<Vec<(String, PredicateInfo)>> {
562 let start = Instant::now();
563 let plan = self.plan(query)?;
564 let results = plan.execute(self.table)?;
565 let duration = start.elapsed();
566
567 let query_type = format!("{:?}", query)
568 .split('(')
569 .next()
570 .unwrap_or("Unknown")
571 .to_string();
572 self.statistics.record_query(
573 query_type,
574 duration,
575 results.len(),
576 self.table.predicates.len(),
577 );
578
579 Ok(results)
580 }
581
582 pub fn statistics(&self) -> &QueryStatistics {
584 &self.statistics
585 }
586
587 pub fn clear_cache(&mut self) {
589 self.plan_cache.clear();
590 }
591
592 pub fn cache_size(&self) -> usize {
594 self.plan_cache.len()
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Builds an owned domain-name list from string literals.
    fn domains(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Table with domains `Person` and `Location` and the predicates
    /// `knows(Person, Person)`, `at(Person, Location)`, `friends(Person, Person)`.
    fn setup_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Location", 50)).unwrap();

        let predicates = vec![
            PredicateInfo::new("knows", domains(&["Person", "Person"])),
            PredicateInfo::new("at", domains(&["Person", "Location"])),
            PredicateInfo::new("friends", domains(&["Person", "Person"])),
        ];
        for predicate in predicates {
            table.add_predicate(predicate).unwrap();
        }

        table
    }

    /// Executes `query` through a fresh planner over the standard test table.
    fn run(query: PredicateQuery) -> Vec<(String, PredicateInfo)> {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);
        planner.execute(&query).unwrap()
    }

    #[test]
    fn test_query_by_name() {
        let found = run(PredicateQuery::by_name("knows"));

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, "knows");
    }

    #[test]
    fn test_query_by_arity() {
        let found = run(PredicateQuery::by_arity(2));

        // Every predicate in the test table is binary.
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_query_by_signature() {
        let found = run(PredicateQuery::by_signature(domains(&["Person", "Person"])));

        // `knows` and `friends`.
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_query_by_domain() {
        let found = run(PredicateQuery::by_domain("Location"));

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, "at");
    }

    #[test]
    fn test_query_and() {
        let both = PredicateQuery::and(vec![
            PredicateQuery::by_arity(2),
            PredicateQuery::by_domain("Location"),
        ]);
        let found = run(both);

        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_query_or() {
        let either = PredicateQuery::or(vec![
            PredicateQuery::by_name("knows"),
            PredicateQuery::by_name("at"),
        ]);
        let found = run(either);

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_predicate_pattern() {
        let pattern = PredicatePattern::new()
            .with_name_pattern("know*")
            .with_arity_range(2, 3)
            .with_required_domain("Person");

        let knows = PredicateInfo::new("knows", domains(&["Person", "Person"]));
        let at = PredicateInfo::new("at", domains(&["Person", "Location"]));

        assert!(pattern.matches("knows", &knows));
        assert!(!pattern.matches("at", &at));
    }

    #[test]
    fn test_query_by_pattern() {
        let pattern = PredicatePattern::new()
            .with_name_pattern("*friend*")
            .with_required_domain("Person");
        let found = run(PredicateQuery::by_pattern(pattern));

        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_wildcard_matching() {
        assert!(matches_wildcard("hello", "h*"));
        assert!(matches_wildcard("hello", "he??o"));
        assert!(matches_wildcard("hello", "*"));
        assert!(matches_wildcard("hello", "hello"));
        assert!(!matches_wildcard("hello", "h*x"));
        assert!(matches_wildcard("test123", "test*"));
    }

    #[test]
    fn test_statistics() {
        let mut stats = QueryStatistics::new();

        stats.record_query("ByName", Duration::from_millis(10), 1, 100);
        stats.record_query("ByName", Duration::from_millis(20), 1, 100);
        stats.record_query("ByArity", Duration::from_millis(50), 10, 100);

        assert_eq!(stats.total_queries, 3);
        assert_eq!(stats.get_selectivity("ByName"), 0.01);
        assert_eq!(stats.get_selectivity("ByArity"), 0.1);

        let top = stats.top_queries(2);
        assert_eq!(top[0].0, "ByName");
        assert_eq!(top[0].1, 2);
    }

    #[test]
    fn test_plan_caching() {
        let table = setup_table();
        let mut planner = QueryPlanner::new(&table);
        let query = PredicateQuery::by_name("knows");

        planner.plan(&query).unwrap();
        assert_eq!(planner.cache_size(), 1);

        // Planning the same query again reuses the cached plan.
        planner.plan(&query).unwrap();
        assert_eq!(planner.cache_size(), 1);

        planner.clear_cache();
        assert_eq!(planner.cache_size(), 0);
    }

    #[test]
    fn test_index_strategy_cost() {
        let stats = QueryStatistics::new();

        assert_eq!(IndexStrategy::FullScan.estimate_cost(1000, &stats), 1000.0);
        assert_eq!(IndexStrategy::NameHash.estimate_cost(1000, &stats), 1.0);
    }
}