1use anyhow::Result;
54use std::collections::HashMap;
55
56use crate::{Embedding, SchemaEmbedder, SchemaStatistics, SymbolTable};
57
58fn cosine_similarity(a: &Embedding, b: &Embedding) -> f64 {
60 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
61 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
62 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
63
64 if norm_a == 0.0 || norm_b == 0.0 {
65 0.0
66 } else {
67 dot_product / (norm_a * norm_b)
68 }
69}
70
/// Ranking strategy accepted by `SchemaRecommender::recommend`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationStrategy {
    /// Rank by cosine similarity between schema embeddings.
    Similarity,
    /// Rank by the fraction of the query's structural patterns a schema shares.
    Pattern,
    /// Rank by a heuristic fit for the named use case
    /// (`"simple"`, `"large"`, `"relational"`; any other name scores `0.5`).
    UseCase(String),
    /// Merge the similarity and pattern rankings, averaging scores present in both.
    Hybrid,
    /// Rank by recorded usage counts; the query schema is ignored.
    Collaborative,
}
85
/// One ranked recommendation together with the signals that produced it.
#[derive(Clone, Debug)]
pub struct SchemaScore {
    /// Identifier of the recommended schema.
    pub schema_id: String,
    /// Relevance score. [`SchemaScore::new`] clamps it into `[0.0, 1.0]`;
    /// code that mutates the field directly is not constrained.
    pub score: f64,
    /// Human-readable explanation of the score.
    pub reasoning: String,
    /// Named contributing signals (for example `"embedding_similarity"`).
    pub factors: HashMap<String, f64>,
}

impl SchemaScore {
    /// Builds a score with no factors, clamping `score` into `[0.0, 1.0]`.
    ///
    /// A NaN `score` is kept as NaN, because `f64::clamp` returns NaN for NaN input.
    pub fn new(schema_id: impl Into<String>, score: f64, reasoning: impl Into<String>) -> Self {
        let schema_id: String = schema_id.into();
        let reasoning: String = reasoning.into();
        let score = score.clamp(0.0, 1.0);
        Self {
            schema_id,
            score,
            reasoning,
            factors: HashMap::new(),
        }
    }

    /// Records factor `name`, overwriting any earlier value with the same name.
    pub fn with_factor(self, name: impl Into<String>, value: f64) -> Self {
        let mut scored = self;
        scored.factors.insert(name.into(), value);
        scored
    }
}
114
/// Per-user signals used by `SchemaRecommender::recommend_with_context`
/// to re-rank recommendations.
#[derive(Clone, Debug, Default)]
pub struct RecommendationContext {
    /// Preference weights; a key applies to every schema id containing it as a substring.
    pub preferences: HashMap<String, f64>,
    /// Previously used schema ids, in the order they were added.
    pub history: Vec<String>,
    /// Explicit ratings keyed by schema id.
    pub ratings: HashMap<String, f64>,
    /// Interest tags; not read by `SchemaRecommender`.
    pub interests: Vec<String>,
}

impl RecommendationContext {
    /// Creates an empty context.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the preference weight for `key`, replacing any previous weight.
    pub fn with_preference(self, key: impl Into<String>, value: f64) -> Self {
        let mut ctx = self;
        ctx.preferences.insert(key.into(), value);
        ctx
    }

    /// Appends `schema_id` to the end of the history.
    pub fn with_history(self, schema_id: impl Into<String>) -> Self {
        let mut ctx = self;
        ctx.history.push(schema_id.into());
        ctx
    }

    /// Sets the rating for `schema_id`, replacing any previous rating.
    pub fn with_rating(self, schema_id: impl Into<String>, rating: f64) -> Self {
        let mut ctx = self;
        ctx.ratings.insert(schema_id.into(), rating);
        ctx
    }

    /// Appends an interest tag.
    pub fn with_interest(self, tag: impl Into<String>) -> Self {
        let mut ctx = self;
        ctx.interests.push(tag.into());
        ctx
    }
}
153
154#[derive(Clone, Debug)]
156pub struct PatternMatcher {
157 patterns: HashMap<String, Vec<String>>,
158}
159
160impl PatternMatcher {
161 pub fn new() -> Self {
162 Self {
163 patterns: HashMap::new(),
164 }
165 }
166
167 pub fn add_pattern(&mut self, name: impl Into<String>, schema_ids: Vec<String>) {
168 self.patterns.insert(name.into(), schema_ids);
169 }
170
171 pub fn match_pattern(&self, schema: &SymbolTable) -> Vec<String> {
172 let mut matches = Vec::new();
173
174 let domain_count = schema.domains.len();
176 let predicate_count = schema.predicates.len();
177
178 for pattern_name in self.patterns.keys() {
179 let size_match = (pattern_name.contains("small") && domain_count < 5)
181 || (pattern_name.contains("medium") && (5..15).contains(&domain_count))
182 || (pattern_name.contains("large") && domain_count >= 15);
183
184 let complexity_match = (pattern_name.contains("simple") && predicate_count < 10)
185 || (pattern_name.contains("complex") && predicate_count >= 10);
186
187 if size_match || complexity_match {
188 matches.push(pattern_name.clone());
189 }
190 }
191
192 matches
193 }
194}
195
196impl Default for PatternMatcher {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202pub struct SchemaRecommender {
204 schemas: HashMap<String, SymbolTable>,
205 embedder: SchemaEmbedder,
206 pattern_matcher: PatternMatcher,
207 usage_counts: HashMap<String, usize>,
208 schema_stats: HashMap<String, SchemaStatistics>,
209}
210
211impl SchemaRecommender {
212 pub fn new() -> Self {
214 Self {
215 schemas: HashMap::new(),
216 embedder: SchemaEmbedder::new(),
217 pattern_matcher: PatternMatcher::new(),
218 usage_counts: HashMap::new(),
219 schema_stats: HashMap::new(),
220 }
221 }
222
223 pub fn add_schema(&mut self, id: impl Into<String>, schema: SymbolTable) {
225 let id = id.into();
226 let stats = SchemaStatistics::compute(&schema);
227 self.schema_stats.insert(id.clone(), stats);
228 self.schemas.insert(id, schema);
229 }
230
231 pub fn remove_schema(&mut self, id: &str) -> Option<SymbolTable> {
233 self.schema_stats.remove(id);
234 self.usage_counts.remove(id);
235 self.schemas.remove(id)
236 }
237
238 pub fn record_usage(&mut self, schema_id: &str) {
240 *self.usage_counts.entry(schema_id.to_string()).or_insert(0) += 1;
241 }
242
243 pub fn recommend(
245 &self,
246 query: &SymbolTable,
247 strategy: RecommendationStrategy,
248 limit: usize,
249 ) -> Result<Vec<SchemaScore>> {
250 match strategy {
251 RecommendationStrategy::Similarity => self.recommend_by_similarity(query, limit),
252 RecommendationStrategy::Pattern => self.recommend_by_pattern(query, limit),
253 RecommendationStrategy::UseCase(use_case) => {
254 self.recommend_by_use_case(query, &use_case, limit)
255 }
256 RecommendationStrategy::Hybrid => self.recommend_hybrid(query, limit),
257 RecommendationStrategy::Collaborative => self.recommend_collaborative(query, limit),
258 }
259 }
260
261 pub fn recommend_with_context(
263 &self,
264 query: &SymbolTable,
265 context: &RecommendationContext,
266 limit: usize,
267 ) -> Result<Vec<SchemaScore>> {
268 let mut base_recommendations = self.recommend_hybrid(query, limit * 2)?;
269
270 for rec in &mut base_recommendations {
272 if let Some(rating) = context.ratings.get(&rec.schema_id) {
274 rec.score = (rec.score + rating) / 2.0;
275 rec.factors.insert("user_rating".to_string(), *rating);
276 }
277
278 if let Some(pos) = context.history.iter().position(|id| id == &rec.schema_id) {
280 let recency_boost = 1.0 - (pos as f64 / context.history.len() as f64) * 0.3;
281 rec.score *= recency_boost;
282 rec.factors.insert("recency".to_string(), recency_boost);
283 }
284
285 for (pref_key, pref_value) in &context.preferences {
287 if rec.schema_id.contains(pref_key) {
288 rec.score = (rec.score + pref_value) / 2.0;
289 rec.factors
290 .insert(format!("preference_{}", pref_key), *pref_value);
291 }
292 }
293 }
294
295 base_recommendations.sort_by(|a, b| {
297 b.score
298 .partial_cmp(&a.score)
299 .unwrap_or(std::cmp::Ordering::Equal)
300 });
301 base_recommendations.truncate(limit);
302
303 Ok(base_recommendations)
304 }
305
306 fn recommend_by_similarity(
307 &self,
308 query: &SymbolTable,
309 limit: usize,
310 ) -> Result<Vec<SchemaScore>> {
311 let query_embedding = self.embedder.embed_schema(query);
312 let mut similarities = Vec::new();
313
314 for (id, schema) in &self.schemas {
316 let schema_embedding = self.embedder.embed_schema(schema);
317 let similarity = cosine_similarity(&query_embedding, &schema_embedding);
318 similarities.push((id.clone(), similarity));
319 }
320
321 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
323 similarities.truncate(limit);
324
325 Ok(similarities
326 .into_iter()
327 .map(|(id, similarity)| {
328 SchemaScore::new(
329 id.clone(),
330 similarity,
331 format!("Similar schema (cosine similarity: {:.2})", similarity),
332 )
333 .with_factor("embedding_similarity", similarity)
334 })
335 .collect())
336 }
337
338 fn recommend_by_pattern(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
339 let patterns = self.pattern_matcher.match_pattern(query);
340 let mut scores = Vec::new();
341
342 for (id, schema) in &self.schemas {
343 let schema_patterns = self.pattern_matcher.match_pattern(schema);
344 let overlap: usize = patterns
345 .iter()
346 .filter(|p| schema_patterns.contains(p))
347 .count();
348
349 if overlap > 0 {
350 let score = overlap as f64 / patterns.len().max(1) as f64;
351 scores.push(
352 SchemaScore::new(
353 id.clone(),
354 score,
355 format!("Matches {} common patterns", overlap),
356 )
357 .with_factor("pattern_overlap", score),
358 );
359 }
360 }
361
362 scores.sort_by(|a, b| {
363 b.score
364 .partial_cmp(&a.score)
365 .unwrap_or(std::cmp::Ordering::Equal)
366 });
367 scores.truncate(limit);
368
369 Ok(scores)
370 }
371
372 fn recommend_by_use_case(
373 &self,
374 query: &SymbolTable,
375 use_case: &str,
376 limit: usize,
377 ) -> Result<Vec<SchemaScore>> {
378 let mut scores = Vec::new();
379 let query_stats = SchemaStatistics::compute(query);
380
381 for id in self.schemas.keys() {
382 if let Some(stats) = self.schema_stats.get(id) {
383 let score = self.compute_use_case_score(use_case, &query_stats, stats);
384 if score > 0.0 {
385 scores.push(
386 SchemaScore::new(
387 id.clone(),
388 score,
389 format!("Suitable for {} use case", use_case),
390 )
391 .with_factor("use_case_match", score),
392 );
393 }
394 }
395 }
396
397 scores.sort_by(|a, b| {
398 b.score
399 .partial_cmp(&a.score)
400 .unwrap_or(std::cmp::Ordering::Equal)
401 });
402 scores.truncate(limit);
403
404 Ok(scores)
405 }
406
407 fn recommend_hybrid(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
408 let similarity_recs = self.recommend_by_similarity(query, limit * 2)?;
410 let pattern_recs = self.recommend_by_pattern(query, limit * 2)?;
411
412 let mut combined: HashMap<String, SchemaScore> = HashMap::new();
413
414 for rec in similarity_recs {
416 combined.insert(rec.schema_id.clone(), rec);
417 }
418
419 for rec in pattern_recs {
420 combined
421 .entry(rec.schema_id.clone())
422 .and_modify(|existing| {
423 existing.score = (existing.score + rec.score) / 2.0;
424 existing.reasoning.push_str(&format!("; {}", rec.reasoning));
425 for (k, v) in rec.factors.clone() {
426 existing.factors.insert(k, v);
427 }
428 })
429 .or_insert(rec);
430 }
431
432 let mut results: Vec<SchemaScore> = combined.into_values().collect();
433 results.sort_by(|a, b| {
434 b.score
435 .partial_cmp(&a.score)
436 .unwrap_or(std::cmp::Ordering::Equal)
437 });
438 results.truncate(limit);
439
440 Ok(results)
441 }
442
443 fn recommend_collaborative(
444 &self,
445 _query: &SymbolTable,
446 limit: usize,
447 ) -> Result<Vec<SchemaScore>> {
448 let mut scores: Vec<SchemaScore> = self
449 .usage_counts
450 .iter()
451 .map(|(id, count)| {
452 let max_count = self.usage_counts.values().max().unwrap_or(&1);
453 let score = *count as f64 / *max_count as f64;
454 SchemaScore::new(
455 id.clone(),
456 score,
457 format!("Popular schema (used {} times)", count),
458 )
459 .with_factor("usage_count", *count as f64)
460 })
461 .collect();
462
463 scores.sort_by(|a, b| {
464 b.score
465 .partial_cmp(&a.score)
466 .unwrap_or(std::cmp::Ordering::Equal)
467 });
468 scores.truncate(limit);
469
470 Ok(scores)
471 }
472
473 fn compute_use_case_score(
474 &self,
475 use_case: &str,
476 query_stats: &SchemaStatistics,
477 candidate_stats: &SchemaStatistics,
478 ) -> f64 {
479 match use_case.to_lowercase().as_str() {
480 "simple" => {
481 let complexity_diff =
483 (query_stats.complexity_score() - candidate_stats.complexity_score()).abs();
484 f64::max(0.0, 1.0 - complexity_diff / 10.0)
485 }
486 "large" => {
487 if candidate_stats.domain_count > 10 {
489 0.8
490 } else {
491 0.3
492 }
493 }
494 "relational" => {
495 let predicate_ratio = candidate_stats.predicate_count as f64
497 / candidate_stats.domain_count.max(1) as f64;
498 (predicate_ratio / 3.0).min(1.0)
499 }
500 _ => 0.5, }
502 }
503
504 pub fn stats(&self) -> RecommenderStats {
506 RecommenderStats {
507 total_schemas: self.schemas.len(),
508 total_patterns: self.pattern_matcher.patterns.len(),
509 total_usage_records: self.usage_counts.values().sum(),
510 most_used_schema: self
511 .usage_counts
512 .iter()
513 .max_by_key(|(_, count)| *count)
514 .map(|(id, _)| id.clone()),
515 }
516 }
517}
518
519impl Default for SchemaRecommender {
520 fn default() -> Self {
521 Self::new()
522 }
523}
524
/// Snapshot of counters returned by `SchemaRecommender::stats`.
#[derive(Clone, Debug)]
pub struct RecommenderStats {
    /// Number of registered schemas.
    pub total_schemas: usize,
    /// Number of patterns registered with the pattern matcher.
    pub total_patterns: usize,
    /// Sum of all recorded usage counts.
    pub total_usage_records: usize,
    /// Id with the highest usage count, or `None` if no usage was recorded.
    pub most_used_schema: Option<String>,
}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Builds a schema containing `domain_count` domains named `{name}Domain{i}`.
    fn create_test_schema(name: &str, domain_count: usize) -> SymbolTable {
        let mut schema = SymbolTable::new();
        for i in 0..domain_count {
            schema
                .add_domain(DomainInfo::new(format!("{}Domain{}", name, i), 100))
                .expect("distinct generated domain names should be accepted");
        }
        schema
    }

    /// Asserts that `recs` is ordered by non-increasing score.
    fn assert_sorted_desc(recs: &[SchemaScore]) {
        assert!(
            recs.windows(2).all(|w| w[0].score >= w[1].score),
            "recommendations are not sorted by descending score"
        );
    }

    #[test]
    fn test_schema_score_creation() {
        let score = SchemaScore::new("test", 0.85, "High similarity");
        assert_eq!(score.schema_id, "test");
        assert_eq!(score.score, 0.85);
        assert_eq!(score.reasoning, "High similarity");
        assert!(score.factors.is_empty());
    }

    #[test]
    fn test_schema_score_clamps_into_unit_interval() {
        assert_eq!(SchemaScore::new("hi", 1.5, "").score, 1.0);
        assert_eq!(SchemaScore::new("lo", -0.25, "").score, 0.0);
    }

    #[test]
    fn test_schema_score_with_factors() {
        let score = SchemaScore::new("test", 0.8, "reason")
            .with_factor("similarity", 0.9)
            .with_factor("popularity", 0.7);

        assert_eq!(score.factors.len(), 2);
        assert_eq!(score.factors.get("similarity"), Some(&0.9));
        assert_eq!(score.factors.get("popularity"), Some(&0.7));
    }

    #[test]
    fn test_recommendation_context() {
        let context = RecommendationContext::new()
            .with_preference("users", 0.9)
            .with_history("schema1")
            .with_rating("schema2", 0.8)
            .with_interest("database");

        assert_eq!(context.preferences.get("users"), Some(&0.9));
        assert_eq!(context.history.len(), 1);
        assert_eq!(context.ratings.get("schema2"), Some(&0.8));
        assert_eq!(context.interests.len(), 1);
    }

    #[test]
    fn test_pattern_matcher() {
        let mut matcher = PatternMatcher::new();
        matcher.add_pattern("small_schema", vec!["s1".to_string()]);

        let schema = create_test_schema("Test", 3);
        let matches = matcher.match_pattern(&schema);

        assert!(!matches.is_empty());
    }

    #[test]
    fn test_recommender_add_remove() {
        let mut recommender = SchemaRecommender::new();
        let schema = create_test_schema("Test", 5);

        recommender.add_schema("test1", schema);
        recommender.record_usage("test1");
        assert_eq!(recommender.schemas.len(), 1);
        assert_eq!(recommender.schema_stats.len(), 1);

        let removed = recommender.remove_schema("test1");
        assert!(removed.is_some());
        assert_eq!(recommender.schemas.len(), 0);
        // Removal also drops the cached statistics and the usage count.
        assert!(recommender.schema_stats.is_empty());
        assert!(recommender.usage_counts.is_empty());
        assert!(recommender.remove_schema("test1").is_none());
    }

    #[test]
    fn test_recommend_by_similarity() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));
        recommender.add_schema("schema3", create_test_schema("C", 3));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Similarity, 2)
            .expect("similarity recommendation should not fail");

        assert!(!recs.is_empty());
        assert!(recs.len() <= 2);
        assert_sorted_desc(&recs);
        for rec in &recs {
            assert!(recommender.schemas.contains_key(&rec.schema_id));
            assert!(rec.factors.contains_key("embedding_similarity"));
        }
    }

    #[test]
    fn test_recommend_by_pattern() {
        let mut recommender = SchemaRecommender::new();

        recommender.pattern_matcher.add_pattern(
            "small_simple",
            vec!["small1".to_string(), "small2".to_string()],
        );

        recommender.add_schema("small1", create_test_schema("S1", 2));
        recommender.add_schema("small2", create_test_schema("S2", 3));
        recommender.add_schema("large1", create_test_schema("L1", 20));

        let query = create_test_schema("Query", 2);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Pattern, 2)
            .expect("pattern recommendation should not fail");

        assert!(!recs.is_empty());
        assert!(recs.len() <= 2);
        // Only one pattern is registered, so every hit overlaps completely.
        for rec in &recs {
            assert_eq!(rec.score, 1.0);
            assert_eq!(rec.factors.get("pattern_overlap"), Some(&1.0));
        }
    }

    #[test]
    fn test_recommend_collaborative() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("popular", create_test_schema("P", 5));
        recommender.add_schema("unpopular", create_test_schema("U", 5));

        recommender.record_usage("popular");
        recommender.record_usage("popular");
        recommender.record_usage("popular");
        recommender.record_usage("unpopular");

        let query = create_test_schema("Query", 5);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Collaborative, 2)
            .expect("collaborative recommendation should not fail");

        assert_eq!(recs.len(), 2);
        assert_eq!(recs[0].schema_id, "popular");
        assert_eq!(recs[0].score, 1.0);
        assert_eq!(recs[0].factors.get("usage_count"), Some(&3.0));
        assert_eq!(recs[1].schema_id, "unpopular");
        assert!((recs[1].score - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_recommend_hybrid() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Hybrid, 2)
            .expect("hybrid recommendation should not fail");

        assert!(!recs.is_empty());
        assert!(recs.len() <= 2);
        assert_sorted_desc(&recs);
    }

    #[test]
    fn test_recommend_with_context() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));

        let context = RecommendationContext::new()
            .with_rating("schema1", 0.9)
            .with_history("schema2");

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend_with_context(&query, &context, 2)
            .expect("contextual recommendation should not fail");

        assert!(!recs.is_empty());
        assert_sorted_desc(&recs);

        let rated = recs
            .iter()
            .find(|r| r.schema_id == "schema1")
            .expect("schema1 should be recommended");
        assert_eq!(rated.factors.get("user_rating"), Some(&0.9));

        let seen = recs
            .iter()
            .find(|r| r.schema_id == "schema2")
            .expect("schema2 should be recommended");
        // schema2 is the only history entry (index 0), so its multiplier is 1.0.
        assert_eq!(seen.factors.get("recency"), Some(&1.0));
    }

    #[test]
    fn test_recommender_stats() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("s1", create_test_schema("A", 3));
        recommender.add_schema("s2", create_test_schema("B", 5));
        recommender.record_usage("s1");
        recommender.record_usage("s1");

        let stats = recommender.stats();
        assert_eq!(stats.total_schemas, 2);
        assert_eq!(stats.total_patterns, 0);
        assert_eq!(stats.total_usage_records, 2);
        assert_eq!(stats.most_used_schema, Some("s1".to_string()));
    }

    #[test]
    fn test_use_case_recommendations() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("simple", create_test_schema("S", 3));
        recommender.add_schema("complex", create_test_schema("C", 15));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(
                &query,
                RecommendationStrategy::UseCase("large".to_string()),
                2,
            )
            .expect("use-case recommendation should not fail");

        assert_eq!(recs.len(), 2);
        // Only the 15-domain schema exceeds the 10-domain threshold for "large".
        assert_eq!(recs[0].schema_id, "complex");
        assert_eq!(recs[0].score, 0.8);
        assert_eq!(recs[1].score, 0.3);
    }

    #[test]
    fn test_record_usage() {
        let mut recommender = SchemaRecommender::new();
        recommender.add_schema("test", create_test_schema("T", 5));

        recommender.record_usage("test");
        recommender.record_usage("test");

        assert_eq!(recommender.usage_counts.get("test"), Some(&2));
    }
}