Skip to main content

tensorlogic_adapters/
recommendation.rs

1//! Schema recommendation system.
2//!
3//! This module provides intelligent schema recommendations based on similarity,
4//! patterns, use cases, and collaborative filtering techniques.
5//!
6//! # Overview
7//!
8//! The recommendation system helps users discover relevant schemas by:
9//! - Finding similar schemas based on embeddings
10//! - Identifying common patterns across schema collections
11//! - Recommending schemas for specific use cases
12//! - Learning from user interactions and preferences
13//!
14//! # Architecture
15//!
16//! - **SchemaRecommender**: Main recommendation engine
17//! - **RecommendationStrategy**: Multiple recommendation approaches
18//! - **SchemaScore**: Scored recommendation with reasoning
19//! - **RecommendationContext**: User context and preferences
20//! - **PatternMatcher**: Pattern-based schema matching
21//!
22//! # Example
23//!
24//! ```rust
25//! use tensorlogic_adapters::{
26//!     SchemaRecommender, RecommendationStrategy, SymbolTable, DomainInfo
27//! };
28//!
29//! let mut recommender = SchemaRecommender::new();
30//!
31//! // Add schemas to the recommendation pool
32//! let mut schema1 = SymbolTable::new();
33//! schema1.add_domain(DomainInfo::new("Person", 100)).unwrap();
34//! recommender.add_schema("users", schema1);
35//!
36//! let mut schema2 = SymbolTable::new();
37//! schema2.add_domain(DomainInfo::new("Product", 200)).unwrap();
38//! recommender.add_schema("products", schema2);
39//!
40//! // Get recommendations
41//! let mut query = SymbolTable::new();
42//! query.add_domain(DomainInfo::new("User", 50)).unwrap();
43//!
44//! let recommendations = recommender.recommend(
45//!     &query,
46//!     RecommendationStrategy::Similarity,
47//!     5
48//! ).unwrap();
49//!
50//! assert!(!recommendations.is_empty());
51//! ```
52
53use anyhow::Result;
54use std::collections::HashMap;
55
56use crate::{Embedding, SchemaEmbedder, SchemaStatistics, SymbolTable};
57
58/// Compute cosine similarity between two embeddings.
59fn cosine_similarity(a: &Embedding, b: &Embedding) -> f64 {
60    let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
61    let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
62    let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
63
64    if norm_a == 0.0 || norm_b == 0.0 {
65        0.0
66    } else {
67        dot_product / (norm_a * norm_b)
68    }
69}
70
/// Selects the ranking algorithm that [`SchemaRecommender::recommend`] applies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationStrategy {
    /// Rank by cosine similarity of schema embeddings.
    Similarity,
    /// Rank by overlap of matched size/complexity patterns.
    Pattern,
    /// Rank by suitability for the named use case (e.g. `"simple"`, `"large"`, `"relational"`).
    UseCase(String),
    /// Blend of the similarity and pattern rankings.
    Hybrid,
    /// Rank by recorded usage counts.
    Collaborative,
}
85
/// A scored schema recommendation.
#[derive(Clone, Debug)]
pub struct SchemaScore {
    /// Schema identifier
    pub schema_id: String,
    /// Recommendation score (0.0 to 1.0)
    pub score: f64,
    /// Reasoning for the recommendation
    pub reasoning: String,
    /// Contributing factors to the score
    pub factors: HashMap<String, f64>,
}

impl SchemaScore {
    /// Create a recommendation for `schema_id` with a human-readable `reasoning`.
    ///
    /// `score` is clamped into `[0.0, 1.0]`. A NaN score is stored as `0.0`:
    /// `f64::clamp` passes NaN through unchanged, and a NaN score would make
    /// any comparison-based ranking of scores ill-defined.
    #[must_use]
    pub fn new(schema_id: impl Into<String>, score: f64, reasoning: impl Into<String>) -> Self {
        let score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };
        Self {
            schema_id: schema_id.into(),
            score,
            reasoning: reasoning.into(),
            factors: HashMap::new(),
        }
    }

    /// Record a named contributing factor, overwriting any previous value
    /// under the same name, and return `self` for chaining.
    #[must_use]
    pub fn with_factor(mut self, name: impl Into<String>, value: f64) -> Self {
        self.factors.insert(name.into(), value);
        self
    }
}
114
/// User-specific signals for re-ranking recommendations, built with the
/// chaining `with_*` methods.
#[derive(Debug, Default, Clone)]
pub struct RecommendationContext {
    /// Preference weights; `recommend_with_context` applies a weight to every
    /// schema whose id contains its key.
    pub preferences: HashMap<String, f64>,
    /// Schema ids the user has viewed, in the order they were added.
    pub history: Vec<String>,
    /// Explicit user ratings keyed by schema id.
    pub ratings: HashMap<String, f64>,
    /// Tags or categories of interest (not read by any ranking code in this module).
    pub interests: Vec<String>,
}

impl RecommendationContext {
    /// Create an empty context; identical to `RecommendationContext::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the preference weight for `key` (replacing any earlier value).
    pub fn with_preference(self, key: impl Into<String>, value: f64) -> Self {
        let mut ctx = self;
        ctx.preferences.insert(key.into(), value);
        ctx
    }

    /// Append `schema_id` to the viewing history.
    pub fn with_history(self, schema_id: impl Into<String>) -> Self {
        let mut ctx = self;
        ctx.history.push(schema_id.into());
        ctx
    }

    /// Set the user's rating for `schema_id` (replacing any earlier rating).
    pub fn with_rating(self, schema_id: impl Into<String>, rating: f64) -> Self {
        let mut ctx = self;
        ctx.ratings.insert(schema_id.into(), rating);
        ctx
    }

    /// Append an interest tag.
    pub fn with_interest(self, tag: impl Into<String>) -> Self {
        let mut ctx = self;
        ctx.interests.push(tag.into());
        ctx
    }
}
153
154/// Pattern matcher for schema recommendations.
155#[derive(Clone, Debug)]
156pub struct PatternMatcher {
157    patterns: HashMap<String, Vec<String>>,
158}
159
160impl PatternMatcher {
161    pub fn new() -> Self {
162        Self {
163            patterns: HashMap::new(),
164        }
165    }
166
167    pub fn add_pattern(&mut self, name: impl Into<String>, schema_ids: Vec<String>) {
168        self.patterns.insert(name.into(), schema_ids);
169    }
170
171    pub fn match_pattern(&self, schema: &SymbolTable) -> Vec<String> {
172        let mut matches = Vec::new();
173
174        // Simple pattern matching based on domain count and structure
175        let domain_count = schema.domains.len();
176        let predicate_count = schema.predicates.len();
177
178        for pattern_name in self.patterns.keys() {
179            // Match based on size heuristics or complexity
180            let size_match = (pattern_name.contains("small") && domain_count < 5)
181                || (pattern_name.contains("medium") && (5..15).contains(&domain_count))
182                || (pattern_name.contains("large") && domain_count >= 15);
183
184            let complexity_match = (pattern_name.contains("simple") && predicate_count < 10)
185                || (pattern_name.contains("complex") && predicate_count >= 10);
186
187            if size_match || complexity_match {
188                matches.push(pattern_name.clone());
189            }
190        }
191
192        matches
193    }
194}
195
196impl Default for PatternMatcher {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202/// Schema recommendation engine.
203pub struct SchemaRecommender {
204    schemas: HashMap<String, SymbolTable>,
205    embedder: SchemaEmbedder,
206    pattern_matcher: PatternMatcher,
207    usage_counts: HashMap<String, usize>,
208    schema_stats: HashMap<String, SchemaStatistics>,
209}
210
211impl SchemaRecommender {
212    /// Create a new recommendation engine.
213    pub fn new() -> Self {
214        Self {
215            schemas: HashMap::new(),
216            embedder: SchemaEmbedder::new(),
217            pattern_matcher: PatternMatcher::new(),
218            usage_counts: HashMap::new(),
219            schema_stats: HashMap::new(),
220        }
221    }
222
223    /// Add a schema to the recommendation pool.
224    pub fn add_schema(&mut self, id: impl Into<String>, schema: SymbolTable) {
225        let id = id.into();
226        let stats = SchemaStatistics::compute(&schema);
227        self.schema_stats.insert(id.clone(), stats);
228        self.schemas.insert(id, schema);
229    }
230
231    /// Remove a schema from the pool.
232    pub fn remove_schema(&mut self, id: &str) -> Option<SymbolTable> {
233        self.schema_stats.remove(id);
234        self.usage_counts.remove(id);
235        self.schemas.remove(id)
236    }
237
238    /// Record schema usage for collaborative filtering.
239    pub fn record_usage(&mut self, schema_id: &str) {
240        *self.usage_counts.entry(schema_id.to_string()).or_insert(0) += 1;
241    }
242
243    /// Get recommendations for a query schema.
244    pub fn recommend(
245        &self,
246        query: &SymbolTable,
247        strategy: RecommendationStrategy,
248        limit: usize,
249    ) -> Result<Vec<SchemaScore>> {
250        match strategy {
251            RecommendationStrategy::Similarity => self.recommend_by_similarity(query, limit),
252            RecommendationStrategy::Pattern => self.recommend_by_pattern(query, limit),
253            RecommendationStrategy::UseCase(use_case) => {
254                self.recommend_by_use_case(query, &use_case, limit)
255            }
256            RecommendationStrategy::Hybrid => self.recommend_hybrid(query, limit),
257            RecommendationStrategy::Collaborative => self.recommend_collaborative(query, limit),
258        }
259    }
260
261    /// Get recommendations with context.
262    pub fn recommend_with_context(
263        &self,
264        query: &SymbolTable,
265        context: &RecommendationContext,
266        limit: usize,
267    ) -> Result<Vec<SchemaScore>> {
268        let mut base_recommendations = self.recommend_hybrid(query, limit * 2)?;
269
270        // Adjust scores based on context
271        for rec in &mut base_recommendations {
272            // Boost based on user ratings
273            if let Some(rating) = context.ratings.get(&rec.schema_id) {
274                rec.score = (rec.score + rating) / 2.0;
275                rec.factors.insert("user_rating".to_string(), *rating);
276            }
277
278            // Boost based on history (recency)
279            if let Some(pos) = context.history.iter().position(|id| id == &rec.schema_id) {
280                let recency_boost = 1.0 - (pos as f64 / context.history.len() as f64) * 0.3;
281                rec.score *= recency_boost;
282                rec.factors.insert("recency".to_string(), recency_boost);
283            }
284
285            // Adjust based on preferences
286            for (pref_key, pref_value) in &context.preferences {
287                if rec.schema_id.contains(pref_key) {
288                    rec.score = (rec.score + pref_value) / 2.0;
289                    rec.factors
290                        .insert(format!("preference_{}", pref_key), *pref_value);
291                }
292            }
293        }
294
295        // Re-sort and limit
296        base_recommendations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
297        base_recommendations.truncate(limit);
298
299        Ok(base_recommendations)
300    }
301
302    fn recommend_by_similarity(
303        &self,
304        query: &SymbolTable,
305        limit: usize,
306    ) -> Result<Vec<SchemaScore>> {
307        let query_embedding = self.embedder.embed_schema(query);
308        let mut similarities = Vec::new();
309
310        // Compute similarity with each schema
311        for (id, schema) in &self.schemas {
312            let schema_embedding = self.embedder.embed_schema(schema);
313            let similarity = cosine_similarity(&query_embedding, &schema_embedding);
314            similarities.push((id.clone(), similarity));
315        }
316
317        // Sort by similarity (highest first)
318        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
319        similarities.truncate(limit);
320
321        Ok(similarities
322            .into_iter()
323            .map(|(id, similarity)| {
324                SchemaScore::new(
325                    id.clone(),
326                    similarity,
327                    format!("Similar schema (cosine similarity: {:.2})", similarity),
328                )
329                .with_factor("embedding_similarity", similarity)
330            })
331            .collect())
332    }
333
334    fn recommend_by_pattern(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
335        let patterns = self.pattern_matcher.match_pattern(query);
336        let mut scores = Vec::new();
337
338        for (id, schema) in &self.schemas {
339            let schema_patterns = self.pattern_matcher.match_pattern(schema);
340            let overlap: usize = patterns
341                .iter()
342                .filter(|p| schema_patterns.contains(p))
343                .count();
344
345            if overlap > 0 {
346                let score = overlap as f64 / patterns.len().max(1) as f64;
347                scores.push(
348                    SchemaScore::new(
349                        id.clone(),
350                        score,
351                        format!("Matches {} common patterns", overlap),
352                    )
353                    .with_factor("pattern_overlap", score),
354                );
355            }
356        }
357
358        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
359        scores.truncate(limit);
360
361        Ok(scores)
362    }
363
364    fn recommend_by_use_case(
365        &self,
366        query: &SymbolTable,
367        use_case: &str,
368        limit: usize,
369    ) -> Result<Vec<SchemaScore>> {
370        let mut scores = Vec::new();
371        let query_stats = SchemaStatistics::compute(query);
372
373        for id in self.schemas.keys() {
374            if let Some(stats) = self.schema_stats.get(id) {
375                let score = self.compute_use_case_score(use_case, &query_stats, stats);
376                if score > 0.0 {
377                    scores.push(
378                        SchemaScore::new(
379                            id.clone(),
380                            score,
381                            format!("Suitable for {} use case", use_case),
382                        )
383                        .with_factor("use_case_match", score),
384                    );
385                }
386            }
387        }
388
389        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
390        scores.truncate(limit);
391
392        Ok(scores)
393    }
394
395    fn recommend_hybrid(&self, query: &SymbolTable, limit: usize) -> Result<Vec<SchemaScore>> {
396        // Combine similarity and pattern matching
397        let similarity_recs = self.recommend_by_similarity(query, limit * 2)?;
398        let pattern_recs = self.recommend_by_pattern(query, limit * 2)?;
399
400        let mut combined: HashMap<String, SchemaScore> = HashMap::new();
401
402        // Merge recommendations
403        for rec in similarity_recs {
404            combined.insert(rec.schema_id.clone(), rec);
405        }
406
407        for rec in pattern_recs {
408            combined
409                .entry(rec.schema_id.clone())
410                .and_modify(|existing| {
411                    existing.score = (existing.score + rec.score) / 2.0;
412                    existing.reasoning.push_str(&format!("; {}", rec.reasoning));
413                    for (k, v) in rec.factors.clone() {
414                        existing.factors.insert(k, v);
415                    }
416                })
417                .or_insert(rec);
418        }
419
420        let mut results: Vec<SchemaScore> = combined.into_values().collect();
421        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
422        results.truncate(limit);
423
424        Ok(results)
425    }
426
427    fn recommend_collaborative(
428        &self,
429        _query: &SymbolTable,
430        limit: usize,
431    ) -> Result<Vec<SchemaScore>> {
432        let mut scores: Vec<SchemaScore> = self
433            .usage_counts
434            .iter()
435            .map(|(id, count)| {
436                let max_count = self.usage_counts.values().max().unwrap_or(&1);
437                let score = *count as f64 / *max_count as f64;
438                SchemaScore::new(
439                    id.clone(),
440                    score,
441                    format!("Popular schema (used {} times)", count),
442                )
443                .with_factor("usage_count", *count as f64)
444            })
445            .collect();
446
447        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
448        scores.truncate(limit);
449
450        Ok(scores)
451    }
452
453    fn compute_use_case_score(
454        &self,
455        use_case: &str,
456        query_stats: &SchemaStatistics,
457        candidate_stats: &SchemaStatistics,
458    ) -> f64 {
459        match use_case.to_lowercase().as_str() {
460            "simple" => {
461                // Prefer schemas with similar low complexity
462                let complexity_diff =
463                    (query_stats.complexity_score() - candidate_stats.complexity_score()).abs();
464                f64::max(0.0, 1.0 - complexity_diff / 10.0)
465            }
466            "large" => {
467                // Prefer schemas with many domains
468                if candidate_stats.domain_count > 10 {
469                    0.8
470                } else {
471                    0.3
472                }
473            }
474            "relational" => {
475                // Prefer schemas with many predicates
476                let predicate_ratio = candidate_stats.predicate_count as f64
477                    / candidate_stats.domain_count.max(1) as f64;
478                (predicate_ratio / 3.0).min(1.0)
479            }
480            _ => 0.5, // Default score
481        }
482    }
483
484    /// Get statistics about the recommendation pool.
485    pub fn stats(&self) -> RecommenderStats {
486        RecommenderStats {
487            total_schemas: self.schemas.len(),
488            total_patterns: self.pattern_matcher.patterns.len(),
489            total_usage_records: self.usage_counts.values().sum(),
490            most_used_schema: self
491                .usage_counts
492                .iter()
493                .max_by_key(|(_, count)| *count)
494                .map(|(id, _)| id.clone()),
495        }
496    }
497}
498
499impl Default for SchemaRecommender {
500    fn default() -> Self {
501        Self::new()
502    }
503}
504
/// Snapshot of counters describing a [`SchemaRecommender`], as returned by its `stats` method.
#[derive(Debug, Clone)]
pub struct RecommenderStats {
    /// Number of schemas currently in the recommendation pool.
    pub total_schemas: usize,
    /// Number of patterns registered with the pattern matcher.
    pub total_patterns: usize,
    /// Sum of all recorded usage counts.
    pub total_usage_records: usize,
    /// Id with the highest usage count, or `None` if no usage has been recorded.
    pub most_used_schema: Option<String>,
}
513
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    /// Build a schema containing `domain_count` distinct domains and nothing else.
    fn create_test_schema(name: &str, domain_count: usize) -> SymbolTable {
        let mut schema = SymbolTable::new();
        for i in 0..domain_count {
            schema
                .add_domain(DomainInfo::new(format!("{}Domain{}", name, i), 100))
                .unwrap();
        }
        schema
    }

    /// Assert that recommendations are ordered best first.
    fn assert_sorted_desc(recs: &[SchemaScore]) {
        assert!(
            recs.windows(2).all(|w| w[0].score >= w[1].score),
            "recommendations are not sorted by descending score"
        );
    }

    #[test]
    fn test_schema_score_creation() {
        let score = SchemaScore::new("test", 0.85, "High similarity");
        assert_eq!(score.schema_id, "test");
        assert_eq!(score.score, 0.85);
        assert_eq!(score.reasoning, "High similarity");
        assert!(score.factors.is_empty());
    }

    #[test]
    fn test_schema_score_clamps_into_unit_range() {
        assert_eq!(SchemaScore::new("hi", 1.5, "").score, 1.0);
        assert_eq!(SchemaScore::new("lo", -0.25, "").score, 0.0);
    }

    #[test]
    fn test_schema_score_with_factors() {
        let score = SchemaScore::new("test", 0.8, "reason")
            .with_factor("similarity", 0.9)
            .with_factor("popularity", 0.7);

        assert_eq!(score.factors.len(), 2);
        assert_eq!(score.factors.get("similarity"), Some(&0.9));
        assert_eq!(score.factors.get("popularity"), Some(&0.7));
    }

    #[test]
    fn test_recommendation_context() {
        let context = RecommendationContext::new()
            .with_preference("users", 0.9)
            .with_history("schema1")
            .with_rating("schema2", 0.8)
            .with_interest("database");

        assert_eq!(context.preferences.get("users"), Some(&0.9));
        assert_eq!(context.history, vec!["schema1".to_string()]);
        assert_eq!(context.ratings.get("schema2"), Some(&0.8));
        assert_eq!(context.interests, vec!["database".to_string()]);
    }

    #[test]
    fn test_pattern_matcher() {
        let mut matcher = PatternMatcher::new();
        matcher.add_pattern("small_schema", vec!["s1".to_string()]);

        let schema = create_test_schema("Test", 3);
        let matches = matcher.match_pattern(&schema);

        assert_eq!(matches, vec!["small_schema".to_string()]);
    }

    #[test]
    fn test_recommender_add_remove() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("test1", create_test_schema("Test", 5));
        recommender.record_usage("test1");
        assert_eq!(recommender.schemas.len(), 1);
        assert!(recommender.schema_stats.contains_key("test1"));

        let removed = recommender.remove_schema("test1");
        assert!(removed.is_some());
        assert!(recommender.schemas.is_empty());
        // Removal also drops the derived statistics and usage count.
        assert!(recommender.schema_stats.is_empty());
        assert!(recommender.usage_counts.is_empty());

        assert!(recommender.remove_schema("test1").is_none());
    }

    #[test]
    fn test_recommend_by_similarity() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));
        recommender.add_schema("schema3", create_test_schema("C", 3));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Similarity, 2)
            .unwrap();

        // Similarity scores every pool schema, so the limit is what caps the result.
        assert_eq!(recs.len(), 2);
        assert_sorted_desc(&recs);
        for rec in &recs {
            assert!((0.0..=1.0).contains(&rec.score));
            assert!(rec.factors.contains_key("embedding_similarity"));
            assert!(rec.reasoning.starts_with("Similar schema"));
        }
    }

    #[test]
    fn test_recommend_by_pattern() {
        let mut recommender = SchemaRecommender::new();

        recommender.pattern_matcher.add_pattern(
            "small_simple",
            vec!["small1".to_string(), "small2".to_string()],
        );

        recommender.add_schema("small1", create_test_schema("S1", 2));
        recommender.add_schema("small2", create_test_schema("S2", 3));
        recommender.add_schema("large1", create_test_schema("L1", 20));

        let query = create_test_schema("Query", 2);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Pattern, 2)
            .unwrap();

        // The query matches `small_simple` by size, and so do both small
        // schemas, so at least two candidates share the query's only pattern
        // and each such candidate scores 1.0.
        assert_eq!(recs.len(), 2);
        for rec in &recs {
            assert_eq!(rec.score, 1.0);
            assert_eq!(rec.factors.get("pattern_overlap"), Some(&1.0));
        }
    }

    #[test]
    fn test_recommend_collaborative() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("popular", create_test_schema("P", 5));
        recommender.add_schema("unpopular", create_test_schema("U", 5));

        recommender.record_usage("popular");
        recommender.record_usage("popular");
        recommender.record_usage("popular");
        recommender.record_usage("unpopular");

        let query = create_test_schema("Query", 5);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Collaborative, 2)
            .unwrap();

        assert_eq!(recs.len(), 2);
        assert_eq!(recs[0].schema_id, "popular");
        assert_eq!(recs[0].score, 1.0);
        assert_eq!(recs[0].factors.get("usage_count"), Some(&3.0));
        assert_eq!(recs[1].schema_id, "unpopular");
        assert!((recs[1].score - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_recommend_hybrid() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(&query, RecommendationStrategy::Hybrid, 2)
            .unwrap();

        assert_eq!(recs.len(), 2);
        assert_sorted_desc(&recs);
        for rec in &recs {
            assert!(rec.factors.contains_key("embedding_similarity"));
        }
    }

    #[test]
    fn test_recommend_with_context() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("schema1", create_test_schema("A", 3));
        recommender.add_schema("schema2", create_test_schema("B", 5));

        let context = RecommendationContext::new()
            .with_rating("schema1", 0.9)
            .with_history("schema2");

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend_with_context(&query, &context, 2)
            .unwrap();

        assert_eq!(recs.len(), 2);
        assert_sorted_desc(&recs);

        let rated = recs
            .iter()
            .find(|r| r.schema_id == "schema1")
            .expect("schema1 is in the pool");
        assert_eq!(rated.factors.get("user_rating"), Some(&0.9));

        // The only history entry sits at position 0, so its multiplier is exactly 1.0.
        let viewed = recs
            .iter()
            .find(|r| r.schema_id == "schema2")
            .expect("schema2 is in the pool");
        assert_eq!(viewed.factors.get("recency"), Some(&1.0));
    }

    #[test]
    fn test_recommender_stats() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("s1", create_test_schema("A", 3));
        recommender.add_schema("s2", create_test_schema("B", 5));
        recommender.record_usage("s1");
        recommender.record_usage("s1");

        let stats = recommender.stats();
        assert_eq!(stats.total_schemas, 2);
        assert_eq!(stats.total_patterns, 0);
        assert_eq!(stats.total_usage_records, 2);
        assert_eq!(stats.most_used_schema, Some("s1".to_string()));

        recommender.pattern_matcher.add_pattern("small", Vec::new());
        assert_eq!(recommender.stats().total_patterns, 1);
    }

    #[test]
    fn test_use_case_recommendations() {
        let mut recommender = SchemaRecommender::new();

        recommender.add_schema("simple", create_test_schema("S", 3));
        recommender.add_schema("complex", create_test_schema("C", 15));

        let query = create_test_schema("Query", 3);
        let recs = recommender
            .recommend(
                &query,
                RecommendationStrategy::UseCase("large".to_string()),
                2,
            )
            .unwrap();

        // "large" scores 0.8 for more than 10 domains and 0.3 otherwise.
        assert_eq!(recs.len(), 2);
        assert_eq!(recs[0].schema_id, "complex");
        assert_eq!(recs[0].score, 0.8);
        assert_eq!(recs[1].schema_id, "simple");
        assert_eq!(recs[1].score, 0.3);
    }

    #[test]
    fn test_record_usage() {
        let mut recommender = SchemaRecommender::new();
        recommender.add_schema("test", create_test_schema("T", 5));

        recommender.record_usage("test");
        recommender.record_usage("test");
        recommender.record_usage("ghost");

        assert_eq!(recommender.usage_counts.get("test"), Some(&2));
        assert_eq!(recommender.usage_counts.get("ghost"), Some(&1));
    }
}