ruvector_core/advanced_features/
mmr.rs1use crate::error::{Result, RuvectorError};
7use crate::types::{DistanceMetric, SearchResult};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MMRConfig {
13 pub lambda: f32,
18 pub metric: DistanceMetric,
20 pub fetch_multiplier: f32,
22}
23
24impl Default for MMRConfig {
25 fn default() -> Self {
26 Self {
27 lambda: 0.5,
28 metric: DistanceMetric::Cosine,
29 fetch_multiplier: 2.0,
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
36pub struct MMRSearch {
37 pub config: MMRConfig,
39}
40
41impl MMRSearch {
42 pub fn new(config: MMRConfig) -> Result<Self> {
44 if !(0.0..=1.0).contains(&config.lambda) {
45 return Err(RuvectorError::InvalidParameter(format!(
46 "Lambda must be in [0, 1], got {}",
47 config.lambda
48 )));
49 }
50
51 Ok(Self { config })
52 }
53
54 pub fn rerank(
64 &self,
65 query: &[f32],
66 mut candidates: Vec<SearchResult>,
67 k: usize,
68 ) -> Result<Vec<SearchResult>> {
69 if candidates.is_empty() {
70 return Ok(Vec::new());
71 }
72
73 if k == 0 {
74 return Ok(Vec::new());
75 }
76
77 if k >= candidates.len() {
78 return Ok(candidates);
79 }
80
81 let mut selected: Vec<SearchResult> = Vec::with_capacity(k);
82 let mut remaining = candidates;
83
84 for _ in 0..k {
86 if remaining.is_empty() {
87 break;
88 }
89
90 let mut best_idx = 0;
92 let mut best_mmr = f32::NEG_INFINITY;
93
94 for (idx, candidate) in remaining.iter().enumerate() {
95 let mmr_score = self.compute_mmr_score(query, candidate, &selected)?;
96
97 if mmr_score > best_mmr {
98 best_mmr = mmr_score;
99 best_idx = idx;
100 }
101 }
102
103 let best = remaining.remove(best_idx);
105 selected.push(best);
106 }
107
108 Ok(selected)
109 }
110
111 fn compute_mmr_score(
113 &self,
114 query: &[f32],
115 candidate: &SearchResult,
116 selected: &[SearchResult],
117 ) -> Result<f32> {
118 let candidate_vec = candidate.vector.as_ref().ok_or_else(|| {
119 RuvectorError::InvalidParameter("Candidate vector not available".to_string())
120 })?;
121
122 let relevance = self.distance_to_similarity(candidate.score);
124
125 let max_similarity = if selected.is_empty() {
127 0.0
128 } else {
129 selected
130 .iter()
131 .filter_map(|s| s.vector.as_ref())
132 .map(|selected_vec| {
133 let dist = compute_distance(candidate_vec, selected_vec, self.config.metric);
134 self.distance_to_similarity(dist)
135 })
136 .max_by(|a, b| a.partial_cmp(b).unwrap())
137 .unwrap_or(0.0)
138 };
139
140 let mmr = self.config.lambda * relevance - (1.0 - self.config.lambda) * max_similarity;
142
143 Ok(mmr)
144 }
145
146 fn distance_to_similarity(&self, distance: f32) -> f32 {
148 match self.config.metric {
149 DistanceMetric::Cosine => 1.0 - distance,
150 DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
151 DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
152 DistanceMetric::DotProduct => -distance, }
154 }
155
156 pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
166 where
167 F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
168 {
169 let fetch_k = (k as f32 * self.config.fetch_multiplier).ceil() as usize;
171 let candidates = search_fn(query, fetch_k)?;
172
173 self.rerank(query, candidates, k)
175 }
176}
177
178fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
180 match metric {
181 DistanceMetric::Euclidean => euclidean_distance(a, b),
182 DistanceMetric::Cosine => cosine_distance(a, b),
183 DistanceMetric::Manhattan => manhattan_distance(a, b),
184 DistanceMetric::DotProduct => dot_product_distance(a, b),
185 }
186}
187
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so when the slices differ in length the
/// extra trailing components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
        .sum();
    squared_sum.sqrt()
}
198
/// Cosine distance, `1 - cos(a, b)`, between two vectors.
///
/// Returns `1.0` when either vector has zero norm, where the cosine is
/// undefined. Components are paired with `zip` for the dot product, so
/// trailing components of a longer slice do not contribute to it.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    /// L2 norm of a single vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let magnitude_a = l2_norm(a);
    let magnitude_b = l2_norm(b);
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 1.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    1.0 - (inner / (magnitude_a * magnitude_b))
}
210
/// Manhattan (L1) distance: the sum of absolute component-wise differences.
///
/// Components are paired with `zip`, so extra trailing components of a longer
/// slice are ignored.
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let total: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| f32::abs(lhs - rhs))
        .sum();
    total
}
214
/// Dot-product "distance": the negated dot product, so that a larger dot
/// product yields a smaller distance.
///
/// Components are paired with `zip`, so extra trailing components of a longer
/// slice are ignored.
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SearchResult` with the given id, distance score and vector,
    /// and no metadata.
    fn create_search_result(id: &str, score: f32, vector: Vec<f32>) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vector),
            metadata: None,
        }
    }

    #[test]
    fn test_mmr_config_validation() {
        let config = MMRConfig {
            lambda: 0.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(config).is_ok());

        let invalid_config = MMRConfig {
            lambda: 1.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(invalid_config).is_err());

        // Values below the range and NaN must be rejected as well.
        let negative_config = MMRConfig {
            lambda: -0.1,
            ..Default::default()
        };
        assert!(MMRSearch::new(negative_config).is_err());

        let nan_config = MMRConfig {
            lambda: f32::NAN,
            ..Default::default()
        };
        assert!(MMRSearch::new(nan_config).is_err());
    }

    #[test]
    fn test_mmr_reranking() {
        let config = MMRConfig {
            lambda: 0.5,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        // doc1 and doc2 are near-duplicates; doc3 and doc4 are less relevant
        // but point in clearly different directions.
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]),
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.5]),
            create_search_result("doc4", 0.6, vec![0.0, 1.0, 0.0]),
        ];

        let results = mmr.rerank(&query, candidates, 3).unwrap();

        assert_eq!(results.len(), 3);
        // The most relevant document always comes first.
        assert_eq!(results[0].id, "doc1");
        // Picking 3 of 4 trivially includes doc3 or doc4, so check the real
        // diversity effect instead: both diverse documents are chosen and the
        // near-duplicate of doc1 is the one left out.
        assert!(results.iter().any(|r| r.id == "doc3"));
        assert!(results.iter().any(|r| r.id == "doc4"));
        assert!(results.iter().all(|r| r.id != "doc2"));
    }

    #[test]
    fn test_mmr_pure_relevance() {
        let config = MMRConfig {
            lambda: 1.0,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.85, 0.1, 0.05]),
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.0]),
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        // With lambda = 1 the diversity term vanishes, so the order follows
        // the scores alone.
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
    }

    #[test]
    fn test_mmr_pure_diversity() {
        let config = MMRConfig {
            lambda: 0.0,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]),
            create_search_result("doc3", 0.5, vec![0.0, 1.0, 0.0]),
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        assert_eq!(results.len(), 2);
        // The two near-duplicates must not both be selected.
        let has_both_similar =
            results.iter().any(|r| r.id == "doc1") && results.iter().any(|r| r.id == "doc2");
        assert!(!has_both_similar);
    }

    #[test]
    fn test_mmr_empty_candidates() {
        let config = MMRConfig::default();
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let results = mmr.rerank(&query, Vec::new(), 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_mmr_zero_k() {
        let mmr = MMRSearch::new(MMRConfig::default()).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0])];

        let results = mmr.rerank(&query, candidates, 0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_mmr_k_covers_all_candidates() {
        let mmr = MMRSearch::new(MMRConfig::default()).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]),
        ];

        // Asking for at least as many results as there are candidates returns
        // them untouched, in their original order.
        let results = mmr.rerank(&query, candidates, 5).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
    }

    #[test]
    fn test_mmr_missing_vector_is_error() {
        let mmr = MMRSearch::new(MMRConfig::default()).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let without_vector = SearchResult {
            id: "doc1".to_string(),
            score: 0.1,
            vector: None,
            metadata: None,
        };
        let candidates = vec![
            without_vector,
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]),
        ];

        // Scoring needs the candidate vectors, so a missing one is reported
        // as an error rather than skipped.
        assert!(mmr.rerank(&query, candidates, 1).is_err());
    }
}