vecstore/vectors/
bm25.rs

1//! BM25 scoring algorithm for sparse vector keyword search
2//!
3//! BM25 (Best Matching 25) is a probabilistic ranking function used for
4//! information retrieval. It's widely used in search engines and is the
5//! foundation of many modern keyword search systems.
6//!
7//! This implementation is optimized for sparse vectors, making it efficient
8//! for large vocabulary spaces where most documents contain only a small
9//! subset of terms.
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Per-field boost factors for BM25F (multi-field) scoring, keyed by field name.
///
/// Each value is a multiplicative boost: the larger it is, the more matches inside that
/// field count towards a document's score.
///
/// # Example
/// ```
/// use std::collections::HashMap;
///
/// let mut field_weights: HashMap<String, f32> = HashMap::new();
/// field_weights.insert("title".to_string(), 3.0); // counts three times as much as content
/// field_weights.insert("abstract".to_string(), 2.0); // counts twice as much as content
/// field_weights.insert("content".to_string(), 1.0); // baseline
/// assert_eq!(field_weights.len(), 3);
/// ```
pub type FieldWeights = HashMap<String, f32>;
28
29/// BM25 configuration parameters
30///
31/// These parameters control how BM25 scores documents:
32/// - k1: Controls term frequency saturation (typical: 1.2-2.0)
33/// - b: Controls document length normalization (typical: 0.75)
34///
35/// # References
36/// Robertson, S. E., & Zaragoza, H. (2009). The probabilistic relevance framework: BM25 and beyond.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct BM25Config {
39    /// k1 parameter: Controls term frequency saturation
40    ///
41    /// Higher values give more weight to term frequency.
42    /// - k1 = 0: Binary (term present/absent)
43    /// - k1 = 1.2: Default, balanced
44    /// - k1 = 2.0: More emphasis on frequency
45    pub k1: f32,
46
47    /// b parameter: Controls document length normalization
48    ///
49    /// Controls how much document length affects the score.
50    /// - b = 0: No length normalization
51    /// - b = 0.75: Default, balanced
52    /// - b = 1.0: Full length normalization
53    pub b: f32,
54}
55
56impl Default for BM25Config {
57    fn default() -> Self {
58        Self { k1: 1.2, b: 0.75 }
59    }
60}
61
/// Statistics needed for BM25 scoring across a corpus
#[derive(Debug, Clone)]
pub struct BM25Stats {
    /// Average document length in the corpus. When built by [`BM25Stats::from_corpus`],
    /// a document's length is the sum of its term values.
    pub avg_doc_length: f32,

    /// Inverse document frequency (IDF) for each term
    /// Map: term_index -> IDF score
    pub idf: HashMap<usize, f32>,

    /// Total number of documents in corpus
    pub num_docs: usize,
}

impl BM25Stats {
    /// Create BM25 statistics from a corpus of sparse vectors
    ///
    /// Every index listed in a document counts as an occurrence of that term, whatever
    /// its value. An index repeated within the same document still counts only once
    /// towards that term's document frequency.
    ///
    /// IDF uses the non-negative variant
    /// `IDF(t) = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5))`.
    /// Because `df(t) <= N` always holds here, every stored IDF is strictly positive.
    ///
    /// # Arguments
    /// * `documents` - Iterator of (indices, values) pairs representing sparse documents
    ///
    /// # Returns
    /// BM25Stats with computed IDF scores and average document length.
    /// An empty corpus yields `num_docs == 0`, `avg_doc_length == 0.0` and no IDF entries.
    pub fn from_corpus<'a, I>(documents: I) -> Self
    where
        I: Iterator<Item = (&'a [usize], &'a [f32])>,
    {
        let mut doc_count: HashMap<usize, usize> = HashMap::new();
        let mut total_doc_length = 0.0_f32;
        let mut num_docs = 0_usize;
        // Reused per document. Counting a repeated index more than once would let df exceed N
        // and turn the IDF negative.
        let mut seen_in_doc = std::collections::HashSet::new();

        for (indices, values) in documents {
            num_docs += 1;
            total_doc_length += values.iter().sum::<f32>();

            seen_in_doc.clear();
            for &term_idx in indices {
                if seen_in_doc.insert(term_idx) {
                    *doc_count.entry(term_idx).or_insert(0) += 1;
                }
            }
        }

        let avg_doc_length = if num_docs > 0 {
            total_doc_length / num_docs as f32
        } else {
            0.0
        };

        // IDF(t) = ln((N - df(t) + 0.5) / (df(t) + 0.5) + 1)
        // where N = total docs, df(t) = docs containing term t
        let n = num_docs as f32;
        let idf = doc_count
            .into_iter()
            .map(|(term_idx, df)| {
                let df = df as f32;
                (term_idx, ((n - df + 0.5) / (df + 0.5) + 1.0).ln())
            })
            .collect();

        BM25Stats {
            avg_doc_length,
            idf,
            num_docs,
        }
    }

    /// Get IDF for a term, returning 0.0 if term not in corpus
    pub fn get_idf(&self, term_idx: usize) -> f32 {
        self.idf.get(&term_idx).copied().unwrap_or(0.0)
    }
}
133
134/// Calculate BM25 score between a query and a document
135///
136/// # Arguments
137/// * `query_indices` - Query term indices
138/// * `query_weights` - Query term weights (typically 1.0 for each query term)
139/// * `doc_indices` - Document term indices
140/// * `doc_values` - Document term frequencies (raw counts or TF-IDF)
141/// * `stats` - BM25 statistics from the corpus
142/// * `config` - BM25 configuration parameters
143///
144/// # Returns
145/// BM25 score (higher is better match)
146///
147/// # Example
148/// ```
149/// use vecstore::vectors::{bm25_score, BM25Config, BM25Stats};
150/// use std::collections::HashMap;
151///
152/// // Simple corpus statistics
153/// let mut idf = HashMap::new();
154/// idf.insert(10, 2.0);  // Term 10 has IDF of 2.0
155/// idf.insert(25, 1.5);  // Term 25 has IDF of 1.5
156///
157/// let stats = BM25Stats {
158///     avg_doc_length: 100.0,
159///     idf,
160///     num_docs: 1000,
161/// };
162///
163/// // Query: terms [10, 25]
164/// let query_indices = vec![10, 25];
165/// let query_weights = vec![1.0, 1.0];
166///
167/// // Document: terms [10, 25, 30] with frequencies [3.0, 2.0, 1.0]
168/// let doc_indices = vec![10, 25, 30];
169/// let doc_values = vec![3.0, 2.0, 1.0];
170///
171/// let score = bm25_score(
172///     &query_indices,
173///     &query_weights,
174///     &doc_indices,
175///     &doc_values,
176///     &stats,
177///     &BM25Config::default()
178/// );
179///
180/// assert!(score > 0.0);
181/// ```
182pub fn bm25_score(
183    query_indices: &[usize],
184    query_weights: &[f32],
185    doc_indices: &[usize],
186    doc_values: &[f32],
187    stats: &BM25Stats,
188    config: &BM25Config,
189) -> f32 {
190    // Build document term map for O(1) lookup
191    let doc_terms: HashMap<usize, f32> = doc_indices
192        .iter()
193        .zip(doc_values.iter())
194        .map(|(&idx, &val)| (idx, val))
195        .collect();
196
197    // Document length (sum of all term frequencies)
198    let doc_length = doc_values.iter().sum::<f32>();
199
200    let mut score = 0.0;
201
202    // For each query term, compute BM25 component
203    for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights.iter()) {
204        // Skip if document doesn't contain this term
205        let term_freq = match doc_terms.get(&term_idx) {
206            Some(&tf) => tf,
207            None => continue,
208        };
209
210        // Get IDF for this term
211        let idf = stats.get_idf(term_idx);
212
213        // BM25 formula:
214        // score = IDF(t) * (f(t,d) * (k1 + 1)) / (f(t,d) + k1 * (1 - b + b * |d| / avgdl))
215        //
216        // where:
217        // - IDF(t) = inverse document frequency of term t
218        // - f(t,d) = frequency of term t in document d
219        // - |d| = document length
220        // - avgdl = average document length in corpus
221        // - k1, b = tuning parameters
222
223        let numerator = term_freq * (config.k1 + 1.0);
224        let denominator =
225            term_freq + config.k1 * (1.0 - config.b + config.b * doc_length / stats.avg_doc_length);
226
227        // Multiply by query weight (typically 1.0, but can be used for query boosting)
228        score += idf * query_weight * (numerator / denominator);
229    }
230
231    score
232}
233
234/// Calculate BM25 score with simplified interface (no pre-computed stats)
235///
236/// This is a convenience function for one-off scoring without building corpus statistics.
237/// For batch scoring, use `bm25_score` with pre-computed `BM25Stats` for better performance.
238///
239/// # Arguments
240/// * `query_indices` - Query term indices
241/// * `doc_indices` - Document term indices
242/// * `doc_values` - Document term frequencies
243/// * `config` - BM25 configuration (or use `BM25Config::default()`)
244///
245/// # Returns
246/// Simple frequency-based score (no IDF)
247pub fn bm25_score_simple(
248    query_indices: &[usize],
249    doc_indices: &[usize],
250    doc_values: &[f32],
251    config: &BM25Config,
252) -> f32 {
253    let doc_terms: HashMap<usize, f32> = doc_indices
254        .iter()
255        .zip(doc_values.iter())
256        .map(|(&idx, &val)| (idx, val))
257        .collect();
258
259    let doc_length = doc_values.iter().sum::<f32>();
260    let avg_doc_length = doc_length; // Assume query doc is average
261
262    let mut score = 0.0;
263
264    for &term_idx in query_indices {
265        let term_freq = match doc_terms.get(&term_idx) {
266            Some(&tf) => tf,
267            None => continue,
268        };
269
270        // Simplified BM25 without IDF (assumes IDF = 1.0)
271        let numerator = term_freq * (config.k1 + 1.0);
272        let denominator =
273            term_freq + config.k1 * (1.0 - config.b + config.b * doc_length / avg_doc_length);
274
275        score += numerator / denominator;
276    }
277
278    score
279}
280
281/// Calculate BM25F score with field boosting
282///
283/// BM25F extends BM25 to support multi-field documents where different fields
284/// can have different importance weights (e.g., title more important than body).
285///
286/// This is the algorithm used by Weaviate, Elasticsearch, and other production systems.
287///
288/// # Arguments
289/// * `query_indices` - Query term indices
290/// * `query_weights` - Query term weights (typically 1.0 for each)
291/// * `doc_fields` - Map of field_name -> (term_indices, term_values)
292/// * `field_weights` - Map of field_name -> boost_factor (e.g., "title" -> 3.0)
293/// * `stats` - BM25 statistics from the corpus
294/// * `config` - BM25 configuration parameters
295///
296/// # Returns
297/// BM25F score (higher is better match)
298///
299/// # Example
300/// ```
301/// use vecstore::vectors::{bm25f_score, BM25Config, BM25Stats, FieldWeights};
302/// use std::collections::HashMap;
303///
304/// // Field weights: title is 3x more important than content
305/// let mut field_weights: FieldWeights = HashMap::new();
306/// field_weights.insert("title".to_string(), 3.0);
307/// field_weights.insert("content".to_string(), 1.0);
308///
309/// // Document with multiple fields
310/// let mut doc_fields = HashMap::new();
311/// doc_fields.insert("title".to_string(), (vec![10, 25], vec![1.0, 1.0]));
312/// doc_fields.insert("content".to_string(), (vec![10, 30, 40], vec![2.0, 1.0, 1.0]));
313///
314/// // Stats (simplified for example)
315/// let mut idf = HashMap::new();
316/// idf.insert(10, 2.0);
317/// idf.insert(25, 1.5);
318/// idf.insert(30, 1.0);
319///
320/// let stats = BM25Stats {
321///     avg_doc_length: 10.0,
322///     idf,
323///     num_docs: 1000,
324/// };
325///
326/// let query_indices = vec![10, 25];
327/// let query_weights = vec![1.0, 1.0];
328///
329/// let score = bm25f_score(
330///     &query_indices,
331///     &query_weights,
332///     &doc_fields,
333///     &field_weights,
334///     &stats,
335///     &BM25Config::default()
336/// );
337///
338/// assert!(score > 0.0);
339/// ```
340pub fn bm25f_score(
341    query_indices: &[usize],
342    query_weights: &[f32],
343    doc_fields: &HashMap<String, (Vec<usize>, Vec<f32>)>,
344    field_weights: &FieldWeights,
345    stats: &BM25Stats,
346    config: &BM25Config,
347) -> f32 {
348    // BM25F algorithm:
349    // 1. For each field, compute weighted term frequencies
350    // 2. Combine weighted frequencies across fields
351    // 3. Apply BM25 formula with combined frequencies
352
353    // Build combined term frequency map across all fields
354    let mut combined_tf: HashMap<usize, f32> = HashMap::new();
355    let mut total_doc_length = 0.0;
356
357    for (field_name, (indices, values)) in doc_fields {
358        let boost = field_weights.get(field_name).copied().unwrap_or(1.0);
359        let field_length: f32 = values.iter().sum();
360        total_doc_length += field_length * boost;
361
362        // Add weighted term frequencies from this field
363        for (&term_idx, &freq) in indices.iter().zip(values.iter()) {
364            *combined_tf.entry(term_idx).or_insert(0.0) += freq * boost;
365        }
366    }
367
368    let mut score = 0.0;
369
370    // For each query term, compute BM25F component
371    for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights.iter()) {
372        // Get combined term frequency across all fields
373        let term_freq = match combined_tf.get(&term_idx) {
374            Some(&tf) => tf,
375            None => continue,
376        };
377
378        // Get IDF for this term
379        let idf = stats.get_idf(term_idx);
380
381        // BM25F formula (same as BM25, but with field-weighted frequencies)
382        let numerator = term_freq * (config.k1 + 1.0);
383        let denominator = term_freq
384            + config.k1 * (1.0 - config.b + config.b * total_doc_length / stats.avg_doc_length);
385
386        score += idf * query_weight * (numerator / denominator);
387    }
388
389    score
390}
391
/// Parse field weights from a string like "title^3" or "content^1.5"
///
/// The spec is split at the first `^`, and whitespace around the field name and the boost
/// is trimmed. A missing `^` gives a boost of `1.0`. A boost that fails to parse, is not
/// finite (`NaN`, `inf`), or is negative also falls back to `1.0`: such values would poison
/// or destabilise BM25F scores. A boost of `0.0` is accepted.
///
/// # Example
/// ```
/// use vecstore::vectors::parse_field_weight;
///
/// assert_eq!(parse_field_weight("title^3"), ("title", 3.0));
/// assert_eq!(parse_field_weight("content^1.5"), ("content", 1.5));
/// assert_eq!(parse_field_weight("body"), ("body", 1.0)); // Default weight
/// ```
pub fn parse_field_weight(field_spec: &str) -> (&str, f32) {
    match field_spec.split_once('^') {
        Some((field, weight_str)) => {
            let weight = weight_str
                .trim()
                .parse::<f32>()
                .ok()
                // f32 parsing accepts "NaN"/"inf"; reject those and negative boosts
                .filter(|w| w.is_finite() && *w >= 0.0)
                .unwrap_or(1.0);
            (field.trim(), weight)
        }
        None => (field_spec.trim(), 1.0),
    }
}
412
413/// Parse multiple field weight specifications
414///
415/// # Example
416/// ```
417/// use vecstore::vectors::parse_field_weights;
418/// use std::collections::HashMap;
419///
420/// let fields = vec!["title^3", "abstract^2", "content"];
421/// let weights = parse_field_weights(&fields);
422///
423/// assert_eq!(weights.get("title"), Some(&3.0));
424/// assert_eq!(weights.get("abstract"), Some(&2.0));
425/// assert_eq!(weights.get("content"), Some(&1.0));
426/// ```
427pub fn parse_field_weights(field_specs: &[&str]) -> FieldWeights {
428    field_specs
429        .iter()
430        .map(|spec| {
431            let (field, weight) = parse_field_weight(spec);
432            (field.to_string(), weight)
433        })
434        .collect()
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `BM25Stats` directly from `(term, idf)` pairs.
    fn make_stats(idf_pairs: &[(usize, f32)], avg_doc_length: f32, num_docs: usize) -> BM25Stats {
        BM25Stats {
            avg_doc_length,
            idf: idf_pairs.iter().copied().collect(),
            num_docs,
        }
    }

    /// Calls `bm25_score` with its arguments grouped as config, stats, query, document.
    fn score_doc(
        config: &BM25Config,
        stats: &BM25Stats,
        query_indices: &[usize],
        query_weights: &[f32],
        doc_indices: &[usize],
        doc_values: &[f32],
    ) -> f32 {
        bm25_score(query_indices, query_weights, doc_indices, doc_values, stats, config)
    }

    /// Builds a multi-field document from `(field_name, term_indices, term_values)` triples.
    fn make_fields(
        spec: Vec<(&str, Vec<usize>, Vec<f32>)>,
    ) -> HashMap<String, (Vec<usize>, Vec<f32>)> {
        spec.into_iter()
            .map(|(name, indices, values)| (name.to_string(), (indices, values)))
            .collect()
    }

    /// Builds field boosts from `(field_name, boost)` pairs.
    fn make_weights(spec: &[(&str, f32)]) -> FieldWeights {
        spec.iter()
            .map(|&(name, boost)| (name.to_string(), boost))
            .collect()
    }

    /// Calls `bm25f_score` with the default configuration.
    fn score_fields(
        stats: &BM25Stats,
        query_indices: &[usize],
        query_weights: &[f32],
        doc_fields: &HashMap<String, (Vec<usize>, Vec<f32>)>,
        field_weights: &FieldWeights,
    ) -> f32 {
        bm25f_score(
            query_indices,
            query_weights,
            doc_fields,
            field_weights,
            stats,
            &BM25Config::default(),
        )
    }

    #[test]
    fn test_bm25_config_default() {
        let BM25Config { k1, b } = BM25Config::default();
        assert_eq!(k1, 1.2);
        assert_eq!(b, 0.75);
    }

    #[test]
    fn test_bm25_stats_from_corpus() {
        // Three documents: [1, 2, 3], [1, 2] and [1, 4], every term with value 1.0
        let corpus: Vec<(Vec<usize>, Vec<f32>)> = vec![
            (vec![1, 2, 3], vec![1.0; 3]),
            (vec![1, 2], vec![1.0; 2]),
            (vec![1, 4], vec![1.0; 2]),
        ];

        let stats = BM25Stats::from_corpus(
            corpus
                .iter()
                .map(|(indices, values)| (indices.as_slice(), values.as_slice())),
        );

        assert_eq!(stats.num_docs, 3);
        assert_eq!(stats.avg_doc_length, (3.0 + 2.0 + 2.0) / 3.0);

        // Term 1 occurs in every document, term 2 in only two of them
        let idf_everywhere = stats.get_idf(1);
        let idf_two_docs = stats.get_idf(2);
        assert!(idf_everywhere > 0.0);
        assert!(idf_two_docs > idf_everywhere);

        // Term 5 never occurs
        assert_eq!(stats.get_idf(5), 0.0);
    }

    #[test]
    fn test_bm25_score_exact_match() {
        let stats = make_stats(&[(1, 1.0), (2, 1.0)], 2.0, 100);
        let score = score_doc(
            &BM25Config::default(),
            &stats,
            &[1, 2],
            &[1.0, 1.0],
            &[1, 2],
            &[1.0, 1.0],
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25_score_no_match() {
        let stats = make_stats(&[(1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)], 2.0, 100);
        let score = score_doc(
            &BM25Config::default(),
            &stats,
            &[1, 2],
            &[1.0, 1.0],
            &[3, 4],
            &[1.0, 1.0],
        );
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25_score_partial_match() {
        // Query [1, 2] against document [1, 3]: only term 1 overlaps
        let stats = make_stats(&[(1, 2.0), (2, 2.0), (3, 2.0)], 2.0, 100);
        let score = score_doc(
            &BM25Config::default(),
            &stats,
            &[1, 2],
            &[1.0, 1.0],
            &[1, 3],
            &[1.0, 1.0],
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25_score_frequency_matters() {
        let stats = make_stats(&[(1, 2.0)], 5.0, 100);
        let config = BM25Config::default();

        let once = score_doc(&config, &stats, &[1], &[1.0], &[1], &[1.0]);
        let five_times = score_doc(&config, &stats, &[1], &[1.0], &[1], &[5.0]);

        assert!(five_times > once);
    }

    #[test]
    fn test_bm25_score_simple() {
        let score = bm25_score_simple(
            &[1, 2],
            &[1, 2, 3],
            &[2.0, 1.0, 1.0],
            &BM25Config::default(),
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25_k1_parameter() {
        // A high-frequency term: smaller k1 saturates sooner, so it should score lower
        let stats = make_stats(&[(1, 1.0)], 10.0, 100);
        let score_for = |k1: f32| {
            score_doc(
                &BM25Config { k1, b: 0.75 },
                &stats,
                &[1],
                &[1.0],
                &[1],
                &[10.0],
            )
        };

        let score_low = score_for(0.5);
        let score_high = score_for(3.0);

        assert!(score_high > score_low);
    }

    // ============================================================================
    // BM25F Tests (Field Boosting)
    // ============================================================================

    #[test]
    fn test_parse_field_weight_with_boost() {
        let (field, weight) = parse_field_weight("title^3");
        assert_eq!((field, weight), ("title", 3.0));
    }

    #[test]
    fn test_parse_field_weight_with_float_boost() {
        let (field, weight) = parse_field_weight("abstract^2.5");
        assert_eq!((field, weight), ("abstract", 2.5));
    }

    #[test]
    fn test_parse_field_weight_without_boost() {
        let (field, weight) = parse_field_weight("content");
        assert_eq!((field, weight), ("content", 1.0));
    }

    #[test]
    fn test_parse_field_weight_invalid_boost() {
        // An unparsable boost falls back to 1.0
        let (field, weight) = parse_field_weight("title^invalid");
        assert_eq!((field, weight), ("title", 1.0));
    }

    #[test]
    fn test_parse_field_weights_multiple() {
        let weights = parse_field_weights(&["title^3", "abstract^2", "content"]);

        assert_eq!(weights.len(), 3);
        assert_eq!(weights.get("title"), Some(&3.0));
        assert_eq!(weights.get("abstract"), Some(&2.0));
        assert_eq!(weights.get("content"), Some(&1.0));
    }

    #[test]
    fn test_parse_field_weights_empty() {
        let weights = parse_field_weights(&[]);
        assert!(weights.is_empty());
    }

    #[test]
    fn test_bm25f_single_field_matches_regular_bm25() {
        // A single field with boost 1.0 should reproduce plain BM25
        let stats = make_stats(&[(1, 2.0), (2, 1.5)], 10.0, 100);
        let query: [usize; 2] = [1, 2];
        let query_weights: [f32; 2] = [1.0, 1.0];
        let doc_indices: Vec<usize> = vec![1, 2, 3];
        let doc_values: Vec<f32> = vec![2.0, 1.0, 1.0];

        let plain = score_doc(
            &BM25Config::default(),
            &stats,
            &query,
            &query_weights,
            &doc_indices,
            &doc_values,
        );

        let doc_fields = make_fields(vec![("content", doc_indices.clone(), doc_values.clone())]);
        let field_weights = make_weights(&[("content", 1.0)]);
        let fielded = score_fields(&stats, &query, &query_weights, &doc_fields, &field_weights);

        // Allow for floating point differences
        assert!((plain - fielded).abs() < 0.01);
    }

    #[test]
    fn test_bm25f_multiple_fields() {
        // Terms: 1 = "rust", 2 = "database", 3 = "vector"
        let stats = make_stats(&[(1, 2.0), (2, 1.5), (3, 1.0)], 10.0, 100);

        let doc_fields = make_fields(vec![
            ("title", vec![1, 2], vec![1.0, 1.0]),    // "rust database"
            ("abstract", vec![1, 3], vec![1.0, 1.0]), // "rust vector"
            ("content", vec![2, 3], vec![1.0, 1.0]),  // "database vector"
        ]);
        let field_weights = make_weights(&[("title", 1.0), ("abstract", 1.0), ("content", 1.0)]);

        // Query: "rust database"
        let score = score_fields(&stats, &[1, 2], &[1.0, 1.0], &doc_fields, &field_weights);
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25f_title_boost() {
        let stats = make_stats(&[(1, 2.0)], 10.0, 100);
        let doc_fields = make_fields(vec![
            ("title", vec![1], vec![1.0]),
            ("content", vec![1], vec![1.0]),
        ]);

        let unboosted = make_weights(&[("title", 1.0), ("content", 1.0)]);
        let title_boosted = make_weights(&[("title", 3.0), ("content", 1.0)]);

        let score_no_boost = score_fields(&stats, &[1], &[1.0], &doc_fields, &unboosted);
        let score_with_boost = score_fields(&stats, &[1], &[1.0], &doc_fields, &title_boosted);

        assert!(score_with_boost > score_no_boost);
    }

    #[test]
    fn test_bm25f_missing_field_weight() {
        // "content" has no explicit weight and should default to 1.0
        let stats = make_stats(&[(1, 2.0)], 10.0, 100);
        let doc_fields = make_fields(vec![
            ("title", vec![1], vec![1.0]),
            ("content", vec![1], vec![1.0]),
        ]);
        let field_weights = make_weights(&[("title", 2.0)]);

        let score = score_fields(&stats, &[1], &[1.0], &doc_fields, &field_weights);
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25f_no_matching_terms() {
        let stats = make_stats(&[(1, 2.0), (2, 1.5)], 10.0, 100);
        let doc_fields = make_fields(vec![("title", vec![3, 4], vec![1.0, 1.0])]);
        let field_weights = make_weights(&[("title", 1.0)]);

        let score = score_fields(&stats, &[1, 2], &[1.0, 1.0], &doc_fields, &field_weights);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25f_empty_fields() {
        let stats = make_stats(&[(1, 2.0)], 10.0, 100);
        let doc_fields = make_fields(vec![]);
        let field_weights = make_weights(&[]);

        let score = score_fields(&stats, &[1], &[1.0], &doc_fields, &field_weights);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25f_realistic_document() {
        // Query "rust vector database": 100 = "rust", 200 = "vector", 300 = "database"
        let stats = make_stats(&[(100, 2.5), (200, 2.0), (300, 1.8)], 50.0, 1000);

        let doc_fields = make_fields(vec![
            ("title", vec![100, 200], vec![1.0, 1.0]),    // "rust vector"
            ("abstract", vec![200, 300], vec![1.0, 1.0]), // "vector database"
            ("content", vec![100, 200, 300], vec![2.0, 3.0, 1.0]), // all terms
        ]);
        let field_weights = parse_field_weights(&["title^3", "abstract^2", "content"]);

        let score = score_fields(
            &stats,
            &[100, 200, 300],
            &[1.0, 1.0, 1.0],
            &doc_fields,
            &field_weights,
        );

        // Every term matches and the title is boosted
        assert!(score > 5.0);
    }
}