1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Per-field boost factors keyed by field name, as produced by `parse_field_weights`.
///
/// `bm25f_score` multiplies each field's term values and length by its boost; fields
/// without an entry are given a boost of `1.0`.
pub type FieldWeights = HashMap<String, f32>;
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct BM25Config {
39 pub k1: f32,
46
47 pub b: f32,
54}
55
56impl Default for BM25Config {
57 fn default() -> Self {
58 Self { k1: 1.2, b: 0.75 }
59 }
60}
61
/// Corpus-level statistics required by BM25 scoring.
#[derive(Debug, Clone)]
pub struct BM25Stats {
    /// Mean document length, where a document's length is the sum of its values.
    /// `0.0` when the statistics were built from an empty corpus.
    pub avg_doc_length: f32,

    /// Inverse document frequency for every term index seen in the corpus.
    pub idf: HashMap<usize, f32>,

    /// Number of documents the statistics were computed from.
    pub num_docs: usize,
}

impl BM25Stats {
    /// Computes document count, average length and per-term IDF from a corpus.
    ///
    /// Each item of `documents` is a sparse document `(indices, values)`. A document's
    /// length is the sum of `values`. A term's document frequency (`df`) is the number
    /// of documents whose `indices` contain it; an index repeated inside one document
    /// is counted only once for that document.
    ///
    /// IDF uses the always-positive form `ln((N - df + 0.5) / (df + 0.5) + 1)`, where
    /// `N` is the number of documents.
    pub fn from_corpus<'a, I>(documents: I) -> Self
    where
        I: Iterator<Item = (&'a [usize], &'a [f32])>,
    {
        let mut doc_count: HashMap<usize, usize> = HashMap::new();
        let mut total_doc_length = 0.0_f32;
        let mut num_docs = 0_usize;
        // Scratch buffer reused across documents to de-duplicate term indices.
        let mut unique_terms: Vec<usize> = Vec::new();

        for (indices, values) in documents {
            num_docs += 1;
            total_doc_length += values.iter().sum::<f32>();

            // Document frequency counts documents, not occurrences: without this
            // de-duplication a repeated index would inflate df and deflate the IDF.
            unique_terms.clear();
            unique_terms.extend_from_slice(indices);
            unique_terms.sort_unstable();
            unique_terms.dedup();

            for &term_idx in &unique_terms {
                *doc_count.entry(term_idx).or_insert(0) += 1;
            }
        }

        let avg_doc_length = if num_docs > 0 {
            total_doc_length / num_docs as f32
        } else {
            0.0
        };

        let n = num_docs as f32;
        let idf = doc_count
            .into_iter()
            .map(|(term_idx, df)| {
                let df = df as f32;
                (term_idx, ((n - df + 0.5) / (df + 0.5) + 1.0).ln())
            })
            .collect();

        BM25Stats {
            avg_doc_length,
            idf,
            num_docs,
        }
    }

    /// Returns the IDF of `term_idx`, or `0.0` for a term never seen in the corpus.
    pub fn get_idf(&self, term_idx: usize) -> f32 {
        self.idf.get(&term_idx).copied().unwrap_or(0.0)
    }
}
133
134pub fn bm25_score(
183 query_indices: &[usize],
184 query_weights: &[f32],
185 doc_indices: &[usize],
186 doc_values: &[f32],
187 stats: &BM25Stats,
188 config: &BM25Config,
189) -> f32 {
190 let doc_terms: HashMap<usize, f32> = doc_indices
192 .iter()
193 .zip(doc_values.iter())
194 .map(|(&idx, &val)| (idx, val))
195 .collect();
196
197 let doc_length = doc_values.iter().sum::<f32>();
199
200 let mut score = 0.0;
201
202 for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights.iter()) {
204 let term_freq = match doc_terms.get(&term_idx) {
206 Some(&tf) => tf,
207 None => continue,
208 };
209
210 let idf = stats.get_idf(term_idx);
212
213 let numerator = term_freq * (config.k1 + 1.0);
224 let denominator =
225 term_freq + config.k1 * (1.0 - config.b + config.b * doc_length / stats.avg_doc_length);
226
227 score += idf * query_weight * (numerator / denominator);
229 }
230
231 score
232}
233
234pub fn bm25_score_simple(
248 query_indices: &[usize],
249 doc_indices: &[usize],
250 doc_values: &[f32],
251 config: &BM25Config,
252) -> f32 {
253 let doc_terms: HashMap<usize, f32> = doc_indices
254 .iter()
255 .zip(doc_values.iter())
256 .map(|(&idx, &val)| (idx, val))
257 .collect();
258
259 let doc_length = doc_values.iter().sum::<f32>();
260 let avg_doc_length = doc_length; let mut score = 0.0;
263
264 for &term_idx in query_indices {
265 let term_freq = match doc_terms.get(&term_idx) {
266 Some(&tf) => tf,
267 None => continue,
268 };
269
270 let numerator = term_freq * (config.k1 + 1.0);
272 let denominator =
273 term_freq + config.k1 * (1.0 - config.b + config.b * doc_length / avg_doc_length);
274
275 score += numerator / denominator;
276 }
277
278 score
279}
280
281pub fn bm25f_score(
341 query_indices: &[usize],
342 query_weights: &[f32],
343 doc_fields: &HashMap<String, (Vec<usize>, Vec<f32>)>,
344 field_weights: &FieldWeights,
345 stats: &BM25Stats,
346 config: &BM25Config,
347) -> f32 {
348 let mut combined_tf: HashMap<usize, f32> = HashMap::new();
355 let mut total_doc_length = 0.0;
356
357 for (field_name, (indices, values)) in doc_fields {
358 let boost = field_weights.get(field_name).copied().unwrap_or(1.0);
359 let field_length: f32 = values.iter().sum();
360 total_doc_length += field_length * boost;
361
362 for (&term_idx, &freq) in indices.iter().zip(values.iter()) {
364 *combined_tf.entry(term_idx).or_insert(0.0) += freq * boost;
365 }
366 }
367
368 let mut score = 0.0;
369
370 for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights.iter()) {
372 let term_freq = match combined_tf.get(&term_idx) {
374 Some(&tf) => tf,
375 None => continue,
376 };
377
378 let idf = stats.get_idf(term_idx);
380
381 let numerator = term_freq * (config.k1 + 1.0);
383 let denominator = term_freq
384 + config.k1 * (1.0 - config.b + config.b * total_doc_length / stats.avg_doc_length);
385
386 score += idf * query_weight * (numerator / denominator);
387 }
388
389 score
390}
391
/// Parses a `field^boost` specification into `(field, boost)`.
///
/// The text before the first `^` is the field name and the text after it is the boost.
/// A spec without `^` yields a boost of `1.0`. An unusable boost also falls back to
/// `1.0`: text that does not parse as `f32`, or a value that is `NaN`, infinite or
/// negative.
///
/// ```text
/// "title^3"   -> ("title", 3.0)
/// "content"   -> ("content", 1.0)
/// "title^NaN" -> ("title", 1.0)
/// ```
pub fn parse_field_weight(field_spec: &str) -> (&str, f32) {
    match field_spec.split_once('^') {
        Some((field, weight_str)) => {
            let weight = weight_str
                .parse::<f32>()
                .ok()
                // `f32::from_str` accepts "NaN", "inf" and negative numbers; such a boost
                // would turn BM25F scores into NaN/inf or invert term contributions.
                .filter(|w| w.is_finite() && *w >= 0.0)
                .unwrap_or(1.0);
            (field, weight)
        }
        None => (field_spec, 1.0),
    }
}
412
413pub fn parse_field_weights(field_specs: &[&str]) -> FieldWeights {
428 field_specs
429 .iter()
430 .map(|spec| {
431 let (field, weight) = parse_field_weight(spec);
432 (field.to_string(), weight)
433 })
434 .collect()
435}
436
// Unit tests for BM25 configuration, corpus statistics, the three scoring functions
// and field-boost parsing.
#[cfg(test)]
mod tests {
    use super::*;

    // The defaults must stay at the conventional k1 = 1.2, b = 0.75.
    #[test]
    fn test_bm25_config_default() {
        let config = BM25Config::default();
        assert_eq!(config.k1, 1.2);
        assert_eq!(config.b, 0.75);
    }

    // Term 1 occurs in all three documents and term 2 in two, so term 2 is rarer and
    // must receive the higher IDF; term 5 never occurs and gets 0.
    #[test]
    fn test_bm25_stats_from_corpus() {
        let corpus = vec![
            (vec![1, 2, 3], vec![1.0, 1.0, 1.0]),
            (vec![1, 2], vec![1.0, 1.0]),
            (vec![1, 4], vec![1.0, 1.0]),
        ];

        let docs: Vec<(&[usize], &[f32])> = corpus
            .iter()
            .map(|(indices, values)| (indices.as_slice(), values.as_slice()))
            .collect();

        let stats = BM25Stats::from_corpus(docs.into_iter());

        assert_eq!(stats.num_docs, 3);
        // Document length is the sum of values: 3, 2 and 2.
        assert_eq!(stats.avg_doc_length, (3.0 + 2.0 + 2.0) / 3.0);

        let idf_1 = stats.get_idf(1);
        assert!(idf_1 > 0.0);
        let idf_2 = stats.get_idf(2);
        assert!(idf_2 > idf_1);
        let idf_5 = stats.get_idf(5);
        assert_eq!(idf_5, 0.0);
    }

    // A document containing every query term scores positively.
    #[test]
    fn test_bm25_score_exact_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);
        idf.insert(2, 1.0);

        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 2];
        let doc_values = vec![1.0, 1.0];

        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );

        assert!(score > 0.0);
    }

    // A document sharing no terms with the query scores exactly zero.
    #[test]
    fn test_bm25_score_no_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);
        idf.insert(2, 1.0);
        idf.insert(3, 1.0);
        idf.insert(4, 1.0);

        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![3, 4];
        let doc_values = vec![1.0, 1.0];

        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );

        assert_eq!(score, 0.0);
    }

    // Only query term 1 is in the document; that single match still yields a positive
    // score.
    #[test]
    fn test_bm25_score_partial_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 2.0);
        idf.insert(3, 2.0);

        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 3];
        let doc_values = vec![1.0, 1.0];

        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );

        assert!(score > 0.0);
    }

    // A higher term frequency (5 vs 1) must raise the score. Note that doc2's length
    // also changes (to 5, equal to the average).
    #[test]
    fn test_bm25_score_frequency_matters() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);

        let stats = BM25Stats {
            avg_doc_length: 5.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let doc1_indices = vec![1];
        let doc1_values = vec![1.0];

        let score1 = bm25_score(
            &query_indices,
            &query_weights,
            &doc1_indices,
            &doc1_values,
            &stats,
            &BM25Config::default(),
        );

        let doc2_indices = vec![1];
        let doc2_values = vec![5.0];

        let score2 = bm25_score(
            &query_indices,
            &query_weights,
            &doc2_indices,
            &doc2_values,
            &stats,
            &BM25Config::default(),
        );

        assert!(score2 > score1);
    }

    // The statistics-free variant still rewards matching terms.
    #[test]
    fn test_bm25_score_simple() {
        let query_indices = vec![1, 2];
        let doc_indices = vec![1, 2, 3];
        let doc_values = vec![2.0, 1.0, 1.0];

        let score = bm25_score_simple(
            &query_indices,
            &doc_indices,
            &doc_values,
            &BM25Config::default(),
        );

        assert!(score > 0.0);
    }

    // With tf = 10 and document length equal to the average, a larger k1 saturates
    // more slowly, so the score increases.
    #[test]
    fn test_bm25_k1_parameter() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1];
        let query_weights = vec![1.0];
        let doc_indices = vec![1];
        let doc_values = vec![10.0];
        let config_low = BM25Config { k1: 0.5, b: 0.75 };
        let score_low = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &config_low,
        );

        let config_high = BM25Config { k1: 3.0, b: 0.75 };
        let score_high = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &config_high,
        );

        assert!(score_high > score_low);
    }

    // Integer boost after '^'.
    #[test]
    fn test_parse_field_weight_with_boost() {
        let (field, weight) = parse_field_weight("title^3");
        assert_eq!(field, "title");
        assert_eq!(weight, 3.0);
    }

    // Fractional boost after '^'.
    #[test]
    fn test_parse_field_weight_with_float_boost() {
        let (field, weight) = parse_field_weight("abstract^2.5");
        assert_eq!(field, "abstract");
        assert_eq!(weight, 2.5);
    }

    // No '^' means the default boost of 1.0.
    #[test]
    fn test_parse_field_weight_without_boost() {
        let (field, weight) = parse_field_weight("content");
        assert_eq!(field, "content");
        assert_eq!(weight, 1.0);
    }

    // An unparseable boost falls back to 1.0 but keeps the field name.
    #[test]
    fn test_parse_field_weight_invalid_boost() {
        let (field, weight) = parse_field_weight("title^invalid");
        assert_eq!(field, "title");
        assert_eq!(weight, 1.0);
    }

    // Each spec becomes one map entry.
    #[test]
    fn test_parse_field_weights_multiple() {
        let specs = vec!["title^3", "abstract^2", "content"];
        let weights = parse_field_weights(&specs);

        assert_eq!(weights.len(), 3);
        assert_eq!(weights.get("title"), Some(&3.0));
        assert_eq!(weights.get("abstract"), Some(&2.0));
        assert_eq!(weights.get("content"), Some(&1.0));
    }

    // No specs produce an empty map.
    #[test]
    fn test_parse_field_weights_empty() {
        let specs: Vec<&str> = vec![];
        let weights = parse_field_weights(&specs);
        assert_eq!(weights.len(), 0);
    }

    // With a single field at boost 1.0, the combined tf and length equal the plain
    // document's, so BM25F must agree with regular BM25.
    #[test]
    fn test_bm25f_single_field_matches_regular_bm25() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 2, 3];
        let doc_values = vec![2.0, 1.0, 1.0];

        let regular_score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );

        let mut doc_fields = HashMap::new();
        doc_fields.insert(
            "content".to_string(),
            (doc_indices.clone(), doc_values.clone()),
        );

        let mut field_weights = HashMap::new();
        field_weights.insert("content".to_string(), 1.0);

        let bm25f_score_result = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert!((regular_score - bm25f_score_result).abs() < 0.01);
    }

    // Query terms spread across several fields are combined into one positive score.
    #[test]
    fn test_bm25f_multiple_fields() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);
        idf.insert(3, 1.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];

        let mut doc_fields = HashMap::new();

        doc_fields.insert("title".to_string(), (vec![1, 2], vec![1.0, 1.0]));

        doc_fields.insert("abstract".to_string(), (vec![1, 3], vec![1.0, 1.0]));

        doc_fields.insert("content".to_string(), (vec![2, 3], vec![1.0, 1.0]));

        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 1.0);
        field_weights.insert("abstract".to_string(), 1.0);
        field_weights.insert("content".to_string(), 1.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert!(score > 0.0);
    }

    // Boosting the title (which contains the query term) must raise the score.
    #[test]
    fn test_bm25f_title_boost() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![1], vec![1.0]));
        doc_fields.insert("content".to_string(), (vec![1], vec![1.0]));

        let mut field_weights_no_boost = HashMap::new();
        field_weights_no_boost.insert("title".to_string(), 1.0);
        field_weights_no_boost.insert("content".to_string(), 1.0);

        let score_no_boost = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights_no_boost,
            &stats,
            &BM25Config::default(),
        );

        let mut field_weights_with_boost = HashMap::new();
        field_weights_with_boost.insert("title".to_string(), 3.0);
        field_weights_with_boost.insert("content".to_string(), 1.0);

        let score_with_boost = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights_with_boost,
            &stats,
            &BM25Config::default(),
        );

        assert!(score_with_boost > score_no_boost);
    }

    // "content" has no entry in field_weights and must fall back to a default boost
    // rather than being dropped or failing.
    #[test]
    fn test_bm25f_missing_field_weight() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![1], vec![1.0]));
        doc_fields.insert("content".to_string(), (vec![1], vec![1.0]));

        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 2.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert!(score > 0.0);
    }

    // Fields containing none of the query terms score exactly zero.
    #[test]
    fn test_bm25f_no_matching_terms() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![3, 4], vec![1.0, 1.0]));

        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 1.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert_eq!(score, 0.0);
    }

    // A document with no fields at all scores zero.
    #[test]
    fn test_bm25f_empty_fields() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);

        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };

        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let doc_fields = HashMap::new();
        let field_weights = HashMap::new();

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert_eq!(score, 0.0);
    }

    // End-to-end check with boosts parsed from "field^boost" specs. The threshold is
    // a loose lower bound, not an exact expected value.
    #[test]
    fn test_bm25f_realistic_document() {
        let mut idf = HashMap::new();
        idf.insert(100, 2.5);
        idf.insert(200, 2.0);
        idf.insert(300, 1.8);
        let stats = BM25Stats {
            avg_doc_length: 50.0,
            idf,
            num_docs: 1000,
        };

        let query_indices = vec![100, 200, 300];
        let query_weights = vec![1.0, 1.0, 1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![100, 200], vec![1.0, 1.0]));
        doc_fields.insert("abstract".to_string(), (vec![200, 300], vec![1.0, 1.0]));
        doc_fields.insert(
            "content".to_string(),
            (vec![100, 200, 300], vec![2.0, 3.0, 1.0]),
        );
        let field_weights = parse_field_weights(&["title^3", "abstract^2", "content"]);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert!(score > 5.0);
    }
}