ruvector_core/advanced_features/
mmr.rs

//! Maximal Marginal Relevance (MMR) for Diversity-Aware Search
//!
//! Implements the MMR algorithm to balance relevance and diversity in search results:
//! MMR = λ × Similarity(query, doc) - (1-λ) × max Similarity(doc, selected_docs)
5
6use crate::error::{Result, RuvectorError};
7use crate::types::{DistanceMetric, SearchResult};
8use serde::{Deserialize, Serialize};
9
10/// Configuration for MMR search
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MMRConfig {
13    /// Lambda parameter: balance between relevance (1.0) and diversity (0.0)
14    /// - λ = 1.0: Pure relevance (standard similarity search)
15    /// - λ = 0.5: Equal balance
16    /// - λ = 0.0: Pure diversity
17    pub lambda: f32,
18    /// Distance metric for similarity computation
19    pub metric: DistanceMetric,
20    /// Fetch multiplier for initial candidates (fetch k * multiplier results)
21    pub fetch_multiplier: f32,
22}
23
24impl Default for MMRConfig {
25    fn default() -> Self {
26        Self {
27            lambda: 0.5,
28            metric: DistanceMetric::Cosine,
29            fetch_multiplier: 2.0,
30        }
31    }
32}
33
34/// MMR search implementation
35#[derive(Debug, Clone)]
36pub struct MMRSearch {
37    /// Configuration
38    pub config: MMRConfig,
39}
40
41impl MMRSearch {
42    /// Create a new MMR search instance
43    pub fn new(config: MMRConfig) -> Result<Self> {
44        if !(0.0..=1.0).contains(&config.lambda) {
45            return Err(RuvectorError::InvalidParameter(format!(
46                "Lambda must be in [0, 1], got {}",
47                config.lambda
48            )));
49        }
50
51        Ok(Self { config })
52    }
53
54    /// Perform MMR-based reranking of search results
55    ///
56    /// # Arguments
57    /// * `query` - Query vector
58    /// * `candidates` - Initial search results (sorted by relevance)
59    /// * `k` - Number of diverse results to return
60    ///
61    /// # Returns
62    /// Reranked results optimizing for both relevance and diversity
63    pub fn rerank(
64        &self,
65        query: &[f32],
66        mut candidates: Vec<SearchResult>,
67        k: usize,
68    ) -> Result<Vec<SearchResult>> {
69        if candidates.is_empty() {
70            return Ok(Vec::new());
71        }
72
73        if k == 0 {
74            return Ok(Vec::new());
75        }
76
77        if k >= candidates.len() {
78            return Ok(candidates);
79        }
80
81        let mut selected: Vec<SearchResult> = Vec::with_capacity(k);
82        let mut remaining = candidates;
83
84        // Iteratively select documents maximizing MMR
85        for _ in 0..k {
86            if remaining.is_empty() {
87                break;
88            }
89
90            // Compute MMR score for each remaining candidate
91            let mut best_idx = 0;
92            let mut best_mmr = f32::NEG_INFINITY;
93
94            for (idx, candidate) in remaining.iter().enumerate() {
95                let mmr_score = self.compute_mmr_score(query, candidate, &selected)?;
96
97                if mmr_score > best_mmr {
98                    best_mmr = mmr_score;
99                    best_idx = idx;
100                }
101            }
102
103            // Move best candidate to selected set
104            let best = remaining.remove(best_idx);
105            selected.push(best);
106        }
107
108        Ok(selected)
109    }
110
111    /// Compute MMR score for a candidate
112    fn compute_mmr_score(
113        &self,
114        query: &[f32],
115        candidate: &SearchResult,
116        selected: &[SearchResult],
117    ) -> Result<f32> {
118        let candidate_vec = candidate.vector.as_ref().ok_or_else(|| {
119            RuvectorError::InvalidParameter("Candidate vector not available".to_string())
120        })?;
121
122        // Relevance: similarity to query (convert distance to similarity)
123        let relevance = self.distance_to_similarity(candidate.score);
124
125        // Diversity: max similarity to already selected documents
126        let max_similarity = if selected.is_empty() {
127            0.0
128        } else {
129            selected
130                .iter()
131                .filter_map(|s| s.vector.as_ref())
132                .map(|selected_vec| {
133                    let dist = compute_distance(candidate_vec, selected_vec, self.config.metric);
134                    self.distance_to_similarity(dist)
135                })
136                .max_by(|a, b| a.partial_cmp(b).unwrap())
137                .unwrap_or(0.0)
138        };
139
140        // MMR = λ × relevance - (1-λ) × max_similarity
141        let mmr = self.config.lambda * relevance - (1.0 - self.config.lambda) * max_similarity;
142
143        Ok(mmr)
144    }
145
146    /// Convert distance to similarity (higher is better)
147    fn distance_to_similarity(&self, distance: f32) -> f32 {
148        match self.config.metric {
149            DistanceMetric::Cosine => 1.0 - distance,
150            DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
151            DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
152            DistanceMetric::DotProduct => -distance, // Dot product is already similarity-like
153        }
154    }
155
156    /// Perform end-to-end MMR search
157    ///
158    /// # Arguments
159    /// * `query` - Query vector
160    /// * `k` - Number of diverse results to return
161    /// * `search_fn` - Function to perform initial similarity search
162    ///
163    /// # Returns
164    /// Diverse search results
165    pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
166    where
167        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
168    {
169        // Fetch more candidates than needed
170        let fetch_k = (k as f32 * self.config.fetch_multiplier).ceil() as usize;
171        let candidates = search_fn(query, fetch_k)?;
172
173        // Rerank using MMR
174        self.rerank(query, candidates, k)
175    }
176}
177
178// Helper function
179fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
180    match metric {
181        DistanceMetric::Euclidean => euclidean_distance(a, b),
182        DistanceMetric::Cosine => cosine_distance(a, b),
183        DistanceMetric::Manhattan => manhattan_distance(a, b),
184        DistanceMetric::DotProduct => dot_product_distance(a, b),
185    }
186}
187
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, the
/// surplus tail of the longer one is ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
        .sum();

    squared_sum.sqrt()
}
198
/// Cosine distance (`1 - cos θ`) between `a` and `b`.
///
/// Returns `1.0` when either vector has zero norm, since the angle is
/// undefined in that case. Components are paired positionally for the dot
/// product; each norm is taken over the full slice.
///
/// The cosine is clamped to `[-1, 1]` before being subtracted: floating-point
/// rounding can otherwise push it just past those bounds, which would yield a
/// slightly negative distance for parallel vectors (or one above `2.0` for
/// anti-parallel ones).
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    let cosine = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    1.0 - cosine
}
210
/// Manhattan (L1) distance: the sum of absolute component-wise differences.
///
/// Components are paired positionally; a longer slice's surplus tail is ignored.
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| f32::abs(lhs - rhs))
        .sum::<f32>()
}
214
/// Dot-product "distance": the negated dot product, so that a larger dot
/// product (more similar) gives a smaller distance.
///
/// Components are paired positionally; a longer slice's surplus tail is ignored.
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    let similarity = a
        .iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>();

    -similarity
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a result that carries its vector and no metadata.
    fn create_search_result(id: &str, score: f32, vector: Vec<f32>) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vector),
            metadata: None,
        }
    }

    /// Euclidean-metric reranker with the given λ and a 2× fetch multiplier.
    fn euclidean_mmr(lambda: f32) -> MMRSearch {
        let config = MMRConfig {
            lambda,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };
        MMRSearch::new(config).unwrap()
    }

    /// Whether any result in `results` has the given id.
    fn contains_id(results: &[SearchResult], id: &str) -> bool {
        results.iter().any(|r| r.id == id)
    }

    #[test]
    fn test_mmr_config_validation() {
        let with_lambda = |lambda: f32| MMRConfig {
            lambda,
            ..Default::default()
        };

        assert!(MMRSearch::new(with_lambda(0.5)).is_ok());
        assert!(MMRSearch::new(with_lambda(1.5)).is_err());
    }

    #[test]
    fn test_mmr_reranking() {
        let mmr = euclidean_mmr(0.5);
        let query = [1.0_f32, 0.0, 0.0];

        // Candidates range from near-duplicates of the query to unrelated vectors.
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]), // closest to the query
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // close to doc1 and the query
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.5]), // unlike doc1
            create_search_result("doc4", 0.6, vec![0.0, 1.0, 0.0]), // far from everything else
        ];

        let results = mmr.rerank(&query, candidates, 3).unwrap();

        assert_eq!(results.len(), 3);
        // The most relevant document leads the ranking.
        assert_eq!(results[0].id, "doc1");
        // Diversity must pull in at least one of the dissimilar documents.
        assert!(contains_id(&results, "doc3") || contains_id(&results, "doc4"));
    }

    #[test]
    fn test_mmr_pure_relevance() {
        // λ = 1.0 disables the diversity term entirely.
        let mmr = euclidean_mmr(1.0);
        let query = [1.0_f32, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.85, 0.1, 0.05]),
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.0]),
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        // The relevance order of the input is kept.
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
    }

    #[test]
    fn test_mmr_pure_diversity() {
        // λ = 0.0 disables the relevance term entirely.
        let mmr = euclidean_mmr(0.0);
        let query = [1.0_f32, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // near-duplicate of doc1
            create_search_result("doc3", 0.5, vec![0.0, 1.0, 0.0]),  // far from both
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        assert_eq!(results.len(), 2);
        // The two near-duplicates must not both be selected.
        let has_both_similar = contains_id(&results, "doc1") && contains_id(&results, "doc2");
        assert!(!has_both_similar);
    }

    #[test]
    fn test_mmr_empty_candidates() {
        let mmr = MMRSearch::new(MMRConfig::default()).unwrap();
        let query = [1.0_f32, 0.0, 0.0];

        let results = mmr.rerank(&query, Vec::new(), 5).unwrap();
        assert!(results.is_empty());
    }
}
336}