// triplets_core/metrics.rs

1use std::collections::HashMap;
2
3use crate::data::{ChunkView, RecordChunk};
4use crate::types::SourceId;
5
/// Aggregate skew metrics for per-source sample counts.
///
/// Produced by [`source_skew`]; summarizes how evenly a batch is drawn
/// across sources and carries the full per-source breakdown.
#[derive(Clone, Debug, PartialEq)]
pub struct SourceSkew {
    /// Total sample count across all sources.
    pub total: usize,
    /// Number of sources represented in the metric.
    pub sources: usize,
    /// Minimum per-source sample count.
    pub min: usize,
    /// Maximum per-source sample count.
    pub max: usize,
    /// Mean per-source sample count (`total / sources`).
    pub mean: f64,
    /// Largest source share (`max / total`; `0.0` when `total == 0`).
    pub max_share: f64,
    /// Smallest source share (`min / total`; `0.0` when `total == 0`).
    pub min_share: f64,
    /// Imbalance ratio (`max / min`, or `inf` when `min == 0`).
    pub ratio: f64,
    /// Per-source counts and shares, sorted by descending count and then
    /// ascending source id for deterministic ordering.
    pub per_source: Vec<SourceShare>,
}
28
/// Per-source share of a batch for skew inspection.
///
/// One entry per source in [`SourceSkew::per_source`].
#[derive(Clone, Debug, PartialEq)]
pub struct SourceShare {
    /// Source identifier.
    pub source: SourceId,
    /// Number of samples drawn from this source.
    pub count: usize,
    /// Fraction of total samples contributed by this source
    /// (`count / total`; `0.0` when the batch total is zero).
    pub share: f64,
}
39
40/// Compute skew metrics from per-source counts.
41/// The map keys are source IDs (the `RecordId` prefix before `::`).
42pub fn source_skew(counts: &HashMap<SourceId, usize>) -> Option<SourceSkew> {
43    if counts.is_empty() {
44        return None;
45    }
46    let total: usize = counts.values().sum();
47    let sources = counts.len();
48    let min = *counts.values().min().expect("counts non-empty");
49    let max = *counts.values().max().expect("counts non-empty");
50    let mean = total as f64 / sources as f64;
51    let max_share = if total == 0 {
52        0.0
53    } else {
54        max as f64 / total as f64
55    };
56    let min_share = if total == 0 {
57        0.0
58    } else {
59        min as f64 / total as f64
60    };
61    let ratio = if min == 0 {
62        f64::INFINITY
63    } else {
64        max as f64 / min as f64
65    };
66    let mut per_source: Vec<SourceShare> = counts
67        .iter()
68        .map(|(source, count)| SourceShare {
69            source: source.clone(),
70            count: *count,
71            share: if total == 0 {
72                0.0
73            } else {
74                *count as f64 / total as f64
75            },
76        })
77        .collect();
78    per_source.sort_by(|a, b| b.count.cmp(&a.count).then_with(|| a.source.cmp(&b.source)));
79    Some(SourceSkew {
80        total,
81        sources,
82        min,
83        max,
84        mean,
85        max_share,
86        min_share,
87        ratio,
88        per_source,
89    })
90}
91
92/// Compute normalized distance between two chunk windows from the same section.
93///
94/// Returns `Some(distance)` in `[0.0, 1.0]` when both chunks are `Window` views
95/// from the same `(record_id, section_idx)`. Returns `None` when distance is not
96/// comparable (different records/sections or non-window views).
97pub fn window_chunk_distance(anchor: &RecordChunk, positive: &RecordChunk) -> Option<f32> {
98    if anchor.record_id != positive.record_id || anchor.section_idx != positive.section_idx {
99        return None;
100    }
101    match (&anchor.view, &positive.view) {
102        (ChunkView::Window { index: left, .. }, ChunkView::Window { index: right, .. }) => {
103            let delta = left.abs_diff(*right) as f32;
104            Some(delta / (delta + 1.0))
105        }
106        _ => None,
107    }
108}
109
110/// Convert chunk distance into a chunk proximity score in `[0.0, 1.0]`.
111///
112/// A higher score means anchor/positive chunks are closer in the document.
113/// When distance cannot be computed, returns `1.0` (neutral multiplier).
114pub fn chunk_proximity_score(anchor: &RecordChunk, positive: &RecordChunk) -> f32 {
115    window_chunk_distance(anchor, positive)
116        .map(|distance| 1.0 - distance)
117        .unwrap_or(1.0)
118}
119
/// Backward-compatible alias for `chunk_proximity_score`.
///
/// Delegates unchanged; kept so existing call sites using the old name
/// continue to compile. Prefer `chunk_proximity_score` in new code.
pub fn chunk_distance_relevance_score(anchor: &RecordChunk, positive: &RecordChunk) -> f32 {
    chunk_proximity_score(anchor, positive)
}
124
/// Proximity score of a window chunk to the section head (index 0).
///
/// Returns a value in `(0.0, 1.0]` using `1 / (index + 1)`.
/// - index `0` -> `1.0`
/// - index `1` -> `0.5`
/// - index `3` -> `0.25`
pub fn window_index_proximity(index: usize) -> f32 {
    // Add 1.0 after the cast so index == usize::MAX cannot overflow.
    (index as f32 + 1.0).recip()
}
134
/// Compute byte-level Jaccard and cosine similarity scores between two strings.
///
/// Uses raw UTF-8 single-byte occurrence frequencies (no tokenisation, no
/// n-grams), so it is fast and dependency-free. Returns `(jaccard, cosine)`
/// each in `[0.0, 1.0]`; both are `0.0` when either input is empty.
///
/// Used by BM25 ranking tests to verify top-ranked candidates beat the
/// uniform-pool baseline, and by the `extended-metrics` demo output.
#[cfg(any(feature = "extended-metrics", all(test, feature = "bm25-mining")))]
pub fn lexical_similarity_scores(left: &str, right: &str) -> (f32, f32) {
    if left.is_empty() || right.is_empty() {
        return (0.0, 0.0);
    }

    // Byte-occurrence histograms; presence (count > 0.0) doubles as the
    // membership set for the Jaccard computation below, so no separate
    // bitset is needed.
    let mut left_freq = [0.0_f32; 256];
    let mut right_freq = [0.0_f32; 256];
    for &byte in left.as_bytes() {
        left_freq[byte as usize] += 1.0;
    }
    for &byte in right.as_bytes() {
        right_freq[byte as usize] += 1.0;
    }

    // Cosine similarity over the 256-dimensional frequency vectors.
    let dot: f32 = left_freq
        .iter()
        .zip(right_freq.iter())
        .map(|(a, b)| a * b)
        .sum();
    let left_norm_sq: f32 = left_freq.iter().map(|v| v * v).sum();
    let right_norm_sq: f32 = right_freq.iter().map(|v| v * v).sum();
    let cosine = if left_norm_sq > 0.0 && right_norm_sq > 0.0 {
        dot / (left_norm_sq.sqrt() * right_norm_sq.sqrt())
    } else {
        0.0
    };

    // Jaccard over the sets of distinct bytes present in each string.
    let mut intersection = 0_u32;
    let mut union = 0_u32;
    for (l, r) in left_freq.iter().zip(right_freq.iter()) {
        let (in_left, in_right) = (*l > 0.0, *r > 0.0);
        intersection += u32::from(in_left && in_right);
        union += u32::from(in_left || in_right);
    }
    let jaccard = if union > 0 {
        intersection as f32 / union as f32
    } else {
        0.0
    };

    (jaccard, cosine)
}
193
//! Unit tests for the skew and proximity metrics in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `Window`-view chunk fixture. Only `record_id`,
    /// `section_idx`, and the window `index` matter to these tests; the
    /// remaining fields are filled with placeholder values.
    fn window_chunk(record_id: &str, section_idx: usize, index: usize) -> RecordChunk {
        RecordChunk {
            record_id: record_id.to_string(),
            section_idx,
            view: ChunkView::Window {
                index,
                overlap: 0,
                span: 16,
            },
            text: "x".to_string(),
            tokens_estimate: 1,
            quality: crate::data::QualityScore::default(),
            kvp_meta: Default::default(),
        }
    }

    #[test]
    fn source_skew_returns_none_for_empty_counts() {
        let counts = HashMap::new();
        assert!(source_skew(&counts).is_none());
    }

    #[test]
    fn source_skew_reports_balance() {
        // Two sources with equal counts: ratio 1.0, shares 0.5 each.
        let mut counts = HashMap::new();
        counts.insert("A".to_string(), 2);
        counts.insert("B".to_string(), 2);
        let skew = source_skew(&counts).expect("skew");
        assert_eq!(skew.total, 4);
        assert_eq!(skew.sources, 2);
        assert_eq!(skew.min, 2);
        assert_eq!(skew.max, 2);
        assert!((skew.max_share - 0.5).abs() < 1e-6);
        assert!((skew.ratio - 1.0).abs() < 1e-6);
        assert_eq!(skew.per_source.len(), 2);
        assert!(
            skew.per_source
                .iter()
                .all(|entry| (entry.share - 0.5).abs() < 1e-6)
        );
    }

    #[test]
    fn source_skew_reports_imbalance() {
        // One dominant source (4 of 8 samples): max_share 0.5, ratio 2.0,
        // and the dominant source sorts first in per_source.
        let mut counts = HashMap::new();
        counts.insert("A".to_string(), 4);
        counts.insert("B".to_string(), 2);
        counts.insert("C".to_string(), 2);
        let skew = source_skew(&counts).expect("skew");
        assert_eq!(skew.total, 8);
        assert_eq!(skew.sources, 3);
        assert_eq!(skew.min, 2);
        assert_eq!(skew.max, 4);
        assert!((skew.max_share - 0.5).abs() < 1e-6);
        assert!((skew.ratio - 2.0).abs() < 1e-6);
        assert_eq!(skew.per_source[0].source, "A");
        assert_eq!(skew.per_source[0].count, 4);
    }

    #[test]
    fn source_skew_zero_totals_report_zero_shares_and_infinite_ratio() {
        // All-zero counts: shares must be 0.0 (not NaN) and ratio +inf.
        let mut counts = HashMap::new();
        counts.insert("B".to_string(), 0);
        counts.insert("A".to_string(), 0);

        let skew = source_skew(&counts).expect("skew");
        assert_eq!(skew.total, 0);
        assert_eq!(skew.min, 0);
        assert_eq!(skew.max, 0);
        assert_eq!(skew.max_share, 0.0);
        assert_eq!(skew.min_share, 0.0);
        assert!(skew.ratio.is_infinite());
        // Equal counts: ordering falls back to ascending source id.
        assert_eq!(skew.per_source[0].source, "A");
        assert_eq!(skew.per_source[1].source, "B");
        assert!(skew.per_source.iter().all(|entry| entry.share == 0.0));
    }

    #[test]
    fn window_chunk_distance_uses_index_delta() {
        // delta = |1 - 4| = 3, distance = 3 / (3 + 1) = 0.75.
        let a = window_chunk("record", 0, 1);
        let b = window_chunk("record", 0, 4);
        let distance = window_chunk_distance(&a, &b).expect("distance");
        assert!((distance - 0.75).abs() < 1e-6, "distance={distance}");
    }

    #[test]
    fn chunk_proximity_score_inverts_distance() {
        // proximity = 1.0 - distance = 1.0 - 0.75 = 0.25.
        let a = window_chunk("record", 0, 1);
        let b = window_chunk("record", 0, 4);
        let proximity = chunk_proximity_score(&a, &b);
        assert!((proximity - 0.25).abs() < 1e-6, "proximity={proximity}");
    }

    #[test]
    fn chunk_proximity_score_is_neutral_when_not_comparable() {
        // Different record ids: distance is None, proximity is neutral 1.0.
        let a = window_chunk("record_a", 0, 1);
        let b = window_chunk("record_b", 0, 4);
        assert_eq!(window_chunk_distance(&a, &b), None);
        assert_eq!(chunk_proximity_score(&a, &b), 1.0);
    }

    #[test]
    fn chunk_distance_relevance_score_alias_matches_proximity() {
        let a = window_chunk("record", 0, 1);
        let b = window_chunk("record", 0, 4);
        assert_eq!(
            chunk_distance_relevance_score(&a, &b),
            chunk_proximity_score(&a, &b)
        );
    }

    #[test]
    fn window_index_proximity_scores_drop_with_index() {
        assert!((window_index_proximity(0) - 1.0).abs() < 1e-6);
        assert!((window_index_proximity(1) - 0.5).abs() < 1e-6);
        assert!((window_index_proximity(3) - 0.25).abs() < 1e-6);
    }

    // Lexical-similarity tests mirror the cfg gate on the function itself.
    #[cfg(any(feature = "bm25-mining", feature = "extended-metrics"))]
    #[test]
    fn lexical_similarity_identical_strings_score_one() {
        let (j, c) = lexical_similarity_scores("hello world", "hello world");
        assert!((j - 1.0).abs() < 1e-6, "jaccard={j}");
        assert!((c - 1.0).abs() < 1e-6, "cosine={c}");
    }

    #[cfg(any(feature = "bm25-mining", feature = "extended-metrics"))]
    #[test]
    fn lexical_similarity_empty_inputs_score_zero() {
        assert_eq!(lexical_similarity_scores("", "hello"), (0.0, 0.0));
        assert_eq!(lexical_similarity_scores("hello", ""), (0.0, 0.0));
        assert_eq!(lexical_similarity_scores("", ""), (0.0, 0.0));
    }

    #[cfg(any(feature = "bm25-mining", feature = "extended-metrics"))]
    #[test]
    fn lexical_similarity_scores_are_in_unit_range() {
        let cases = [
            ("foo bar baz", "qux quux"),
            ("abc", "abc def"),
            ("the quick brown fox", "jumped over the lazy dog"),
        ];
        for (a, b) in cases {
            let (j, c) = lexical_similarity_scores(a, b);
            assert!(
                (0.0..=1.0).contains(&j),
                "jaccard={j} out of range for ({a:?}, {b:?})"
            );
            assert!(
                (0.0..=1.0).contains(&c),
                "cosine={c} out of range for ({a:?}, {b:?})"
            );
        }
    }
}