1use anyhow::Result;
54use std::collections::HashMap;
55
56use crate::{Embedding, SchemaEmbedder, SchemaStatistics, SymbolTable};
57
58fn cosine_similarity(a: &Embedding, b: &Embedding) -> f64 {
60 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
61 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
62 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
63
64 if norm_a == 0.0 || norm_b == 0.0 {
65 0.0
66 } else {
67 dot_product / (norm_a * norm_b)
68 }
69}
70
/// Selects which ranking algorithm [`SchemaRecommender::recommend`] dispatches to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationStrategy {
    /// Rank stored schemas by cosine similarity between their embedding and the query's.
    Similarity,
    /// Rank stored schemas by how many named patterns they share with the query.
    Pattern,
    /// Rank stored schemas for a named use case; the payload is the use-case name.
    UseCase(String),
    /// Merge the similarity and pattern rankings.
    Hybrid,
    /// Rank schemas by their recorded usage counts.
    Collaborative,
}
85
/// One recommendation: a schema id, its relevance score, and the reasons behind it.
#[derive(Clone, Debug)]
pub struct SchemaScore {
    /// Identifier of the recommended schema.
    pub schema_id: String,
    /// Relevance score; [`SchemaScore::new`] always stores a value in `[0.0, 1.0]`.
    pub score: f64,
    /// Human-readable explanation of why the schema was recommended.
    pub reasoning: String,
    /// Named numeric contributions that went into the score.
    pub factors: HashMap<String, f64>,
}

impl SchemaScore {
    /// Creates a score with no factors.
    ///
    /// `score` is clamped into `[0.0, 1.0]`. A NaN input is stored as `0.0`:
    /// `f64::clamp` returns NaN unchanged, and a NaN score cannot be ordered against
    /// other scores when recommendations are sorted.
    pub fn new(schema_id: impl Into<String>, score: f64, reasoning: impl Into<String>) -> Self {
        let score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };

        Self {
            schema_id: schema_id.into(),
            score,
            reasoning: reasoning.into(),
            factors: HashMap::new(),
        }
    }

    /// Records a named factor and returns the updated score (builder style).
    ///
    /// A later call with the same `name` overwrites the earlier value.
    pub fn with_factor(mut self, name: impl Into<String>, value: f64) -> Self {
        self.factors.insert(name.into(), value);
        self
    }
}
114
/// Per-user signals consumed by [`SchemaRecommender::recommend_with_context`].
#[derive(Debug, Clone, Default)]
pub struct RecommendationContext {
    /// Weight per preference key; a key applies to every schema whose id contains it.
    pub preferences: HashMap<String, f64>,
    /// Schema ids the user has interacted with, in insertion order.
    pub history: Vec<String>,
    /// Explicit rating per schema id.
    pub ratings: HashMap<String, f64>,
    /// Free-form interest tags.
    pub interests: Vec<String>,
}

impl RecommendationContext {
    /// Creates an empty context.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the weight for a preference key, replacing any earlier weight for that key.
    pub fn with_preference(mut self, key: impl Into<String>, value: f64) -> Self {
        let key: String = key.into();
        self.preferences.insert(key, value);
        self
    }

    /// Appends a schema id to the end of the history.
    pub fn with_history(mut self, schema_id: impl Into<String>) -> Self {
        let entry: String = schema_id.into();
        self.history.push(entry);
        self
    }

    /// Sets the rating for a schema id, replacing any earlier rating for that id.
    pub fn with_rating(mut self, schema_id: impl Into<String>, rating: f64) -> Self {
        let target: String = schema_id.into();
        self.ratings.insert(target, rating);
        self
    }

    /// Appends an interest tag.
    pub fn with_interest(mut self, tag: impl Into<String>) -> Self {
        let label: String = tag.into();
        self.interests.push(label);
        self
    }
}
153
154#[derive(Clone, Debug)]
156pub struct PatternMatcher {
157 patterns: HashMap<String, Vec<String>>,
158}
159
160impl PatternMatcher {
161 pub fn new() -> Self {
162 Self {
163 patterns: HashMap::new(),
164 }
165 }
166
167 pub fn add_pattern(&mut self, name: impl Into<String>, schema_ids: Vec<String>) {
168 self.patterns.insert(name.into(), schema_ids);
169 }
170
171 pub fn match_pattern(&self, schema: &SymbolTable) -> Vec<String> {
172 let mut matches = Vec::new();
173
174 let domain_count = schema.domains.len();
176 let predicate_count = schema.predicates.len();
177
178 for pattern_name in self.patterns.keys() {
179 let size_match = (pattern_name.contains("small") && domain_count < 5)
181 || (pattern_name.contains("medium") && (5..15).contains(&domain_count))
182 || (pattern_name.contains("large") && domain_count >= 15);
183
184 let complexity_match = (pattern_name.contains("simple") && predicate_count < 10)
185 || (pattern_name.contains("complex") && predicate_count >= 10);
186
187 if size_match || complexity_match {
188 matches.push(pattern_name.clone());
189 }
190 }
191
192 matches
193 }
194}
195
196impl Default for PatternMatcher {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202pub struct SchemaRecommender {
204 schemas: HashMap<String, SymbolTable>,
205 embedder: SchemaEmbedder,
206 pattern_matcher: PatternMatcher,
207 usage_counts: HashMap<String, usize>,
208 schema_stats: HashMap<String, SchemaStatistics>,
209}
210
211impl SchemaRecommender {
212 pub fn new() -> Self {
214 Self {
215 schemas: HashMap::new(),
216 embedder: SchemaEmbedder::new(),
217 pattern_matcher: PatternMatcher::new(),
218 usage_counts: HashMap::new(),
219 schema_stats: HashMap::new(),
220 }
221 }
222
223 pub fn add_schema(&mut self, id: impl Into<String>, schema: SymbolTable) {
225 let id = id.into();
226 let stats = SchemaStatistics::compute(&schema);
227 self.schema_stats.insert(id.clone(), stats);
228 self.schemas.insert(id, schema);
229 }
230
231 pub fn remove_schema(&mut self, id: &str) -> Option<SymbolTable> {
233 self.schema_stats.remove(id);
234 self.usage_counts.remove(id);
235 self.schemas.remove(id)
236 }
237
238 pub fn record_usage(&mut self, schema_id: &str) {
240 *self.usage_counts.entry(schema_id.to_string()).or_insert(0) += 1;
241 }
242
243 pub fn recommend(
245 &self,
246 query: &SymbolTable,
247 strategy: RecommendationStrategy,
248 limit: usize,
249 ) -> Result<Vec<SchemaScore>> {
250 match strategy {
251 RecommendationStrategy::Similarity => self.recommend_by_similarity(query, limit),
252 RecommendationStrategy::Pattern => self.recommend_by_pattern(query, limit),
253 RecommendationStrategy::UseCase(use_case) => {
254 self.recommend_by_use_case(query, &use_case, limit)
255 }
256 RecommendationStrategy::Hybrid => self.recommend_hybrid(query, limit),
257 RecommendationStrategy::Collaborative => self.recommend_collaborative(query, limit),
258 }
259 }
260
261 pub fn recommend_with_context(
263 &self,
264 query: &SymbolTable,
265 context: &RecommendationContext,
266 limit: usize,
267 ) -> Result<Vec<SchemaScore>> {
268 let mut base_recommendations = self.recommend_hybrid(query, limit * 2)?;
269
270 for rec in &mut base_recommendations {
272 if let Some(rating) = context.ratings.get(&rec.schema_id) {
274 rec.score = (rec.score + rating) / 2.0;
275 rec.factors.insert("user_rating".to_string(), *rating);
276 }
277
278 if let Some(pos) = context.history.iter().position(|id| id == &rec.schema_id) {
280 let recency_boost = 1.0 - (pos as f64 / context.history.len() as f64) * 0.3;
281 rec.score *= recency_boost;
282 rec.factors.insert("recency".to_string(), recency_boost);
283 }
284
285 for (pref_key, pref_value) in &context.preferences {
287 if rec.schema_id.contains(pref_key) {
288 rec.score = (rec.score + pref_value) / 2.0;
289 rec.factors
290 .insert(format!("preference_{}", pref_key), *pref_value);
291 }
292 }
293 }
294
295 base_recommendations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
297 base_recommendations.truncate(limit);
298
299 Ok(base_recommendations)
300 }
301
302 fn recommend_by_similarity(
303 &self,
304 query: &SymbolTable,
305 limit: usize,
306 ) -> Result<Vec<SchemaScore>> {
307 let query_embedding = self.embedder.embed_schema(query);
308 let mut similarities = Vec::new();
309
310 for (id, schema) in &self.schemas {
312 let schema_embedding = self.embedder.embed_schema(schema);
313 let similarity = cosine_similarity(&query_embedding, &schema_embedding);
314 similarities.push((id.clone(), similarity));
315 }
316
317 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
319 similarities.truncate(limit);
320
321 Ok(similarities
322 .into_iter()
323 .map(|(id, similarity)| {
324 SchemaScore::new(
325 id.clone(),
326 similarity,
327 format!("Similar schema (cosine similarity: {:.2})", similarity),
328 )
329 .with_factor("embedding_similarity", similarity)
330 })
331 .collect())
332 }
333
334 fn recommend_by_pattern(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
335 let patterns = self.pattern_matcher.match_pattern(query);
336 let mut scores = Vec::new();
337
338 for (id, schema) in &self.schemas {
339 let schema_patterns = self.pattern_matcher.match_pattern(schema);
340 let overlap: usize = patterns
341 .iter()
342 .filter(|p| schema_patterns.contains(p))
343 .count();
344
345 if overlap > 0 {
346 let score = overlap as f64 / patterns.len().max(1) as f64;
347 scores.push(
348 SchemaScore::new(
349 id.clone(),
350 score,
351 format!("Matches {} common patterns", overlap),
352 )
353 .with_factor("pattern_overlap", score),
354 );
355 }
356 }
357
358 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
359 scores.truncate(limit);
360
361 Ok(scores)
362 }
363
364 fn recommend_by_use_case(
365 &self,
366 query: &SymbolTable,
367 use_case: &str,
368 limit: usize,
369 ) -> Result<Vec<SchemaScore>> {
370 let mut scores = Vec::new();
371 let query_stats = SchemaStatistics::compute(query);
372
373 for id in self.schemas.keys() {
374 if let Some(stats) = self.schema_stats.get(id) {
375 let score = self.compute_use_case_score(use_case, &query_stats, stats);
376 if score > 0.0 {
377 scores.push(
378 SchemaScore::new(
379 id.clone(),
380 score,
381 format!("Suitable for {} use case", use_case),
382 )
383 .with_factor("use_case_match", score),
384 );
385 }
386 }
387 }
388
389 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
390 scores.truncate(limit);
391
392 Ok(scores)
393 }
394
395 fn recommend_hybrid(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
396 let similarity_recs = self.recommend_by_similarity(query, limit * 2)?;
398 let pattern_recs = self.recommend_by_pattern(query, limit * 2)?;
399
400 let mut combined: HashMap<String, SchemaScore> = HashMap::new();
401
402 for rec in similarity_recs {
404 combined.insert(rec.schema_id.clone(), rec);
405 }
406
407 for rec in pattern_recs {
408 combined
409 .entry(rec.schema_id.clone())
410 .and_modify(|existing| {
411 existing.score = (existing.score + rec.score) / 2.0;
412 existing.reasoning.push_str(&format!("; {}", rec.reasoning));
413 for (k, v) in rec.factors.clone() {
414 existing.factors.insert(k, v);
415 }
416 })
417 .or_insert(rec);
418 }
419
420 let mut results: Vec<SchemaScore> = combined.into_values().collect();
421 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
422 results.truncate(limit);
423
424 Ok(results)
425 }
426
427 fn recommend_collaborative(
428 &self,
429 _query: &SymbolTable,
430 limit: usize,
431 ) -> Result<Vec<SchemaScore>> {
432 let mut scores: Vec<SchemaScore> = self
433 .usage_counts
434 .iter()
435 .map(|(id, count)| {
436 let max_count = self.usage_counts.values().max().unwrap_or(&1);
437 let score = *count as f64 / *max_count as f64;
438 SchemaScore::new(
439 id.clone(),
440 score,
441 format!("Popular schema (used {} times)", count),
442 )
443 .with_factor("usage_count", *count as f64)
444 })
445 .collect();
446
447 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
448 scores.truncate(limit);
449
450 Ok(scores)
451 }
452
453 fn compute_use_case_score(
454 &self,
455 use_case: &str,
456 query_stats: &SchemaStatistics,
457 candidate_stats: &SchemaStatistics,
458 ) -> f64 {
459 match use_case.to_lowercase().as_str() {
460 "simple" => {
461 let complexity_diff =
463 (query_stats.complexity_score() - candidate_stats.complexity_score()).abs();
464 f64::max(0.0, 1.0 - complexity_diff / 10.0)
465 }
466 "large" => {
467 if candidate_stats.domain_count > 10 {
469 0.8
470 } else {
471 0.3
472 }
473 }
474 "relational" => {
475 let predicate_ratio = candidate_stats.predicate_count as f64
477 / candidate_stats.domain_count.max(1) as f64;
478 (predicate_ratio / 3.0).min(1.0)
479 }
480 _ => 0.5, }
482 }
483
484 pub fn stats(&self) -> RecommenderStats {
486 RecommenderStats {
487 total_schemas: self.schemas.len(),
488 total_patterns: self.pattern_matcher.patterns.len(),
489 total_usage_records: self.usage_counts.values().sum(),
490 most_used_schema: self
491 .usage_counts
492 .iter()
493 .max_by_key(|(_, count)| *count)
494 .map(|(id, _)| id.clone()),
495 }
496 }
497}
498
499impl Default for SchemaRecommender {
500 fn default() -> Self {
501 Self::new()
502 }
503}
504
/// Snapshot of a recommender's contents, as returned by [`SchemaRecommender::stats`].
#[derive(Debug, Clone)]
pub struct RecommenderStats {
    /// Number of registered schemas.
    pub total_schemas: usize,
    /// Number of registered patterns.
    pub total_patterns: usize,
    /// Sum of all recorded usage counts.
    pub total_usage_records: usize,
    /// Id with the highest usage count, if any usage has been recorded.
    pub most_used_schema: Option<String>,
}
513
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Builds a schema with `domain_count` domains named `<name>Domain<i>`.
    fn create_test_schema(name: &str, domain_count: usize) -> SymbolTable {
        let mut table = SymbolTable::new();
        for index in 0..domain_count {
            let domain = DomainInfo::new(format!("{}Domain{}", name, index), 100);
            table.add_domain(domain).unwrap();
        }
        table
    }

    /// The constructor stores id, score and reasoning as given.
    #[test]
    fn test_schema_score_creation() {
        let created = SchemaScore::new("test", 0.85, "High similarity");

        assert_eq!(created.schema_id, "test");
        assert_eq!(created.score, 0.85);
        assert_eq!(created.reasoning, "High similarity");
    }

    /// Factors added through the builder are retrievable by name.
    #[test]
    fn test_schema_score_with_factors() {
        let created = SchemaScore::new("test", 0.8, "reason")
            .with_factor("similarity", 0.9)
            .with_factor("popularity", 0.7);

        assert_eq!(created.factors.len(), 2);
        assert_eq!(created.factors.get("similarity"), Some(&0.9));
    }

    /// Every builder method of the context records its value.
    #[test]
    fn test_recommendation_context() {
        let ctx = RecommendationContext::new()
            .with_preference("users", 0.9)
            .with_history("schema1")
            .with_rating("schema2", 0.8)
            .with_interest("database");

        assert_eq!(ctx.preferences.get("users"), Some(&0.9));
        assert_eq!(ctx.history.len(), 1);
        assert_eq!(ctx.ratings.get("schema2"), Some(&0.8));
        assert_eq!(ctx.interests.len(), 1);
    }

    /// A "small" pattern matches a schema with three domains.
    #[test]
    fn test_pattern_matcher() {
        let mut matcher = PatternMatcher::new();
        matcher.add_pattern("small_schema", vec!["s1".to_string()]);

        let three_domains = create_test_schema("Test", 3);
        let matched = matcher.match_pattern(&three_domains);

        assert!(!matched.is_empty());
    }

    /// Adding and then removing a schema leaves the recommender empty.
    #[test]
    fn test_recommender_add_remove() {
        let mut rec = SchemaRecommender::new();

        rec.add_schema("test1", create_test_schema("Test", 5));
        assert_eq!(rec.schemas.len(), 1);

        let taken = rec.remove_schema("test1");
        assert!(taken.is_some());
        assert_eq!(rec.schemas.len(), 0);
    }

    /// The similarity strategy returns a non-empty list no longer than the limit.
    #[test]
    fn test_recommend_by_similarity() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("schema1", create_test_schema("A", 3));
        rec.add_schema("schema2", create_test_schema("B", 5));
        rec.add_schema("schema3", create_test_schema("C", 3));

        let query = create_test_schema("Query", 3);
        let results = rec
            .recommend(&query, RecommendationStrategy::Similarity, 2)
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    /// The pattern strategy never returns more results than the limit.
    #[test]
    fn test_recommend_by_pattern() {
        let mut rec = SchemaRecommender::new();
        rec.pattern_matcher.add_pattern(
            "small_simple",
            vec!["small1".to_string(), "small2".to_string()],
        );

        rec.add_schema("small1", create_test_schema("S1", 2));
        rec.add_schema("small2", create_test_schema("S2", 3));
        rec.add_schema("large1", create_test_schema("L1", 20));

        let query = create_test_schema("Query", 2);
        let results = rec
            .recommend(&query, RecommendationStrategy::Pattern, 2)
            .unwrap();

        assert!(results.len() <= 2);
    }

    /// The most frequently used schema is ranked first.
    #[test]
    fn test_recommend_collaborative() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("popular", create_test_schema("P", 5));
        rec.add_schema("unpopular", create_test_schema("U", 5));

        for _ in 0..3 {
            rec.record_usage("popular");
        }
        rec.record_usage("unpopular");

        let query = create_test_schema("Query", 5);
        let results = rec
            .recommend(&query, RecommendationStrategy::Collaborative, 2)
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].schema_id, "popular");
    }

    /// The hybrid strategy yields at least one result for a populated recommender.
    #[test]
    fn test_recommend_hybrid() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("schema1", create_test_schema("A", 3));
        rec.add_schema("schema2", create_test_schema("B", 5));

        let query = create_test_schema("Query", 3);
        let results = rec
            .recommend(&query, RecommendationStrategy::Hybrid, 2)
            .unwrap();

        assert!(!results.is_empty());
    }

    /// Context-aware recommendation yields at least one result.
    #[test]
    fn test_recommend_with_context() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("schema1", create_test_schema("A", 3));
        rec.add_schema("schema2", create_test_schema("B", 5));

        let ctx = RecommendationContext::new()
            .with_rating("schema1", 0.9)
            .with_history("schema2");

        let query = create_test_schema("Query", 3);
        let results = rec.recommend_with_context(&query, &ctx, 2).unwrap();

        assert!(!results.is_empty());
    }

    /// Stats report schema count, total usage and the most used schema.
    #[test]
    fn test_recommender_stats() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("s1", create_test_schema("A", 3));
        rec.add_schema("s2", create_test_schema("B", 5));
        for _ in 0..2 {
            rec.record_usage("s1");
        }

        let summary = rec.stats();
        assert_eq!(summary.total_schemas, 2);
        assert_eq!(summary.total_usage_records, 2);
        assert_eq!(summary.most_used_schema, Some("s1".to_string()));
    }

    /// The "large" use case yields at least one result.
    #[test]
    fn test_use_case_recommendations() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("simple", create_test_schema("S", 3));
        rec.add_schema("complex", create_test_schema("C", 15));

        let query = create_test_schema("Query", 3);
        let strategy = RecommendationStrategy::UseCase("large".to_string());
        let results = rec.recommend(&query, strategy, 2).unwrap();

        assert!(!results.is_empty());
    }

    /// Each `record_usage` call increments the stored counter by one.
    #[test]
    fn test_record_usage() {
        let mut rec = SchemaRecommender::new();
        rec.add_schema("test", create_test_schema("T", 5));

        for _ in 0..2 {
            rec.record_usage("test");
        }

        assert_eq!(rec.usage_counts.get("test"), Some(&2));
    }
}