Skip to main content

scirs2_text/evaluation/
sts.rs

1//! STS (Semantic Textual Similarity) benchmark evaluation.
2//!
3//! Provides `sts_evaluate`, `load_sts_from_tsv`, `StsReport`, `StsDatasetFormat`.
4//!
5//! Protocol-only: this module evaluates pairs already loaded into memory.
6//! No dataset is downloaded automatically.
7//!
8//! ## Example
9//!
10//! ```rust
11//! use scirs2_text::evaluation::sts::{sts_evaluate, StsReport};
12//! use scirs2_core::ndarray::Array1;
13//!
14//! let pairs: Vec<(Vec<String>, Vec<String>, f32)> = vec![
15//!     (
16//!         vec!["hello".into()],
17//!         vec!["hello".into()],
18//!         5.0,
19//!     ),
20//! ];
21//! let embed = |tokens: &[String]| {
22//!     let mut v = Array1::zeros(tokens.len().max(1));
23//!     for (i, _) in tokens.iter().enumerate() {
24//!         v[i] = 1.0f32;
25//!     }
26//!     v
27//! };
28//! let report = sts_evaluate(&embed, &pairs).unwrap();
29//! assert!(report.n_pairs == 1);
30//! ```
31
32use crate::error::{Result, TextError};
33use scirs2_core::ndarray::Array1;
34use std::path::Path;
35
/// Known STS dataset formats.
///
/// Informational only (documentation / future parsing); see [`load_sts_from_tsv`] for the
/// column layouts the loader actually recognizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StsDatasetFormat {
    /// STS benchmark (STS-B) format: gold score 0–5, tab-separated with columns
    /// `genre\tfile\tyear\tsid\tscore\tsentence1\tsentence2`, optionally followed by
    /// extra trailing columns (i.e. score at index 4, sentences at indices 5 and 6).
    StsB,
    /// SICK format: relatedness score 1–5.
    Sick,
    /// STS 2012–2016 shared-task format: score 0–5.
    Sts12to16,
}
47
/// Outcome of an STS evaluation run, as produced by [`sts_evaluate`].
///
/// The per-pair vectors (`predictions`, `gold_labels`) are index-aligned with the
/// input pairs.
#[derive(Clone, Debug)]
pub struct StsReport {
    /// Pearson's r between the predicted cosine similarities and the gold scores.
    pub pearson: f32,
    /// Spearman's rho (Pearson's r computed on tie-averaged ranks) between the
    /// predictions and the gold scores.
    pub spearman: f32,
    /// Mean of the squared differences between each prediction and its gold score.
    ///
    /// Predictions are raw cosine similarities, so this value is only meaningful when
    /// the gold scores are expressed on a comparable scale.
    pub mse: f32,
    /// Predicted similarity (cosine of the two embeddings) for every pair, in input order.
    pub predictions: Vec<f32>,
    /// Gold similarity score for every pair, in input order.
    pub gold_labels: Vec<f32>,
    /// How many sentence pairs were scored.
    pub n_pairs: usize,
}
64
/// One STS example: `(tokens of sentence 1, tokens of sentence 2, gold score)`.
type StsPair = (Vec<String>, Vec<String>, f32);

/// An in-memory STS dataset, as returned by [`load_sts_from_tsv`].
type StsPairs = Vec<StsPair>;
66
67/// Load STS sentence pairs from a TSV file.
68///
69/// The parser tries multiple common column layouts:
70/// - 3 columns → `score`, `sentence1`, `sentence2`
71/// - 8+ columns (STS-B style) → score at index 4, sentence1 at 5, sentence2 at 6
72///
73/// Each sentence is tokenized by splitting on whitespace.
74///
75/// # Errors
76///
77/// Returns [`TextError::IoError`] if the file cannot be opened or read.
78pub fn load_sts_from_tsv(path: impl AsRef<Path>) -> Result<StsPairs> {
79    use std::fs::File;
80    use std::io::{BufRead, BufReader};
81
82    let file = File::open(path.as_ref()).map_err(|e| TextError::IoError(e.to_string()))?;
83    let reader = BufReader::new(file);
84    let mut pairs = Vec::new();
85
86    for line in reader.lines() {
87        let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
88        let line = line.trim();
89        if line.is_empty() {
90            continue;
91        }
92        let fields: Vec<&str> = line.split('\t').collect();
93
94        // Determine column layout
95        let (score_str, s1, s2) = if fields.len() >= 8 {
96            (fields[4], fields[5], fields[6])
97        } else if fields.len() >= 3 {
98            (fields[0], fields[1], fields[2])
99        } else {
100            continue;
101        };
102
103        let score: f32 = match score_str.trim().parse() {
104            Ok(v) => v,
105            Err(_) => continue, // skip header or malformed lines
106        };
107
108        let tokens1: Vec<String> = s1.split_whitespace().map(str::to_owned).collect();
109        let tokens2: Vec<String> = s2.split_whitespace().map(str::to_owned).collect();
110        pairs.push((tokens1, tokens2, score));
111    }
112
113    Ok(pairs)
114}
115
/// Compute the Pearson correlation coefficient between two slices.
///
/// Only the first `min(x.len(), y.len())` elements are used (callers pass equal lengths).
/// Means and sums of squares are accumulated in `f64` to limit cancellation error on
/// large inputs, and the result is clamped to `[-1, 1]` to absorb residual rounding.
///
/// Returns `0.0` for empty input or when either slice has zero variance (the
/// coefficient is undefined in both cases). NaN inputs propagate to a NaN result.
fn pearson_correlation(x: &[f32], y: &[f32]) -> f32 {
    let n = x.len().min(y.len());
    if n == 0 {
        return 0.0;
    }
    let (x, y) = (&x[..n], &y[..n]);

    let mean = |v: &[f32]| v.iter().map(|&a| f64::from(a)).sum::<f64>() / n as f64;
    let (mx, my) = (mean(x), mean(y));

    let (mut sxy, mut sxx, mut syy) = (0.0f64, 0.0f64, 0.0f64);
    for (&a, &b) in x.iter().zip(y) {
        let dx = f64::from(a) - mx;
        let dy = f64::from(b) - my;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }

    if sxx == 0.0 || syy == 0.0 {
        return 0.0;
    }
    (sxy / (sxx.sqrt() * syy.sqrt())).clamp(-1.0, 1.0) as f32
}
133
134/// Compute Spearman rank correlation via rank transform followed by Pearson.
135fn spearman_correlation(x: &[f32], y: &[f32]) -> f32 {
136    /// Rank a slice (1-based, ties receive fractional average ranks).
137    fn rank(v: &[f32]) -> Vec<f32> {
138        let mut indexed: Vec<(usize, f32)> = v.iter().cloned().enumerate().collect();
139        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
140        let mut ranks = vec![0.0f32; v.len()];
141        let mut i = 0;
142        while i < indexed.len() {
143            // Find all ties
144            let val = indexed[i].1;
145            let mut j = i + 1;
146            while j < indexed.len() && indexed[j].1 == val {
147                j += 1;
148            }
149            // Average rank for the tie group
150            let avg_rank = (i + j + 1) as f32 / 2.0; // 1-based midpoint
151            for item in &indexed[i..j] {
152                ranks[item.0] = avg_rank;
153            }
154            i = j;
155        }
156        ranks
157    }
158
159    let rx = rank(x);
160    let ry = rank(y);
161    pearson_correlation(&rx, &ry)
162}
163
164/// Evaluate semantic textual similarity using cosine similarity of embeddings vs gold labels.
165///
166/// `embed_fn` maps a token list to a dense embedding vector.
167/// `pairs` is a slice of `(tokens1, tokens2, gold_score)`.
168///
169/// Returns [`StsReport`] with Pearson/Spearman correlations, MSE, and raw predictions.
170///
171/// # Errors
172///
173/// Returns [`TextError::InvalidInput`] if `pairs` is empty.
174pub fn sts_evaluate(
175    embed_fn: &dyn Fn(&[String]) -> Array1<f32>,
176    pairs: &[(Vec<String>, Vec<String>, f32)],
177) -> Result<StsReport> {
178    if pairs.is_empty() {
179        return Err(TextError::InvalidInput(
180            "STS dataset is empty; at least one pair is required".into(),
181        ));
182    }
183
184    let mut predictions = Vec::with_capacity(pairs.len());
185    let mut gold_labels = Vec::with_capacity(pairs.len());
186
187    for (s1_tokens, s2_tokens, gold) in pairs {
188        let e1 = embed_fn(s1_tokens);
189        let e2 = embed_fn(s2_tokens);
190
191        let dot = e1.dot(&e2);
192        let n1 = e1.dot(&e1).sqrt();
193        let n2 = e2.dot(&e2).sqrt();
194        let cosine = if n1 == 0.0 || n2 == 0.0 {
195            0.0f32
196        } else {
197            dot / (n1 * n2)
198        };
199
200        predictions.push(cosine);
201        gold_labels.push(*gold);
202    }
203
204    let pearson = pearson_correlation(&predictions, &gold_labels);
205    let spearman = spearman_correlation(&predictions, &gold_labels);
206    let mse = predictions
207        .iter()
208        .zip(&gold_labels)
209        .map(|(p, g)| (p - g).powi(2))
210        .sum::<f32>()
211        / predictions.len() as f32;
212
213    Ok(StsReport {
214        pearson,
215        spearman,
216        mse,
217        predictions,
218        gold_labels,
219        n_pairs: pairs.len(),
220    })
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Position-count embedding: the `i`-th token adds 1.0 to component `i % dim`.
    ///
    /// Token *identity* is ignored, so any two token lists of equal length map to the
    /// same vector. That is enough to exercise the evaluation plumbing, but it is not a
    /// meaningful semantic embedding.
    fn bow_embed(tokens: &[String], dim: usize) -> Array1<f32> {
        let mut v = Array1::zeros(dim);
        for (i, _tok) in tokens.iter().enumerate() {
            let idx = i % dim;
            v[idx] += 1.0;
        }
        v
    }

    #[test]
    fn sts_empty_returns_error() {
        let result = sts_evaluate(&|t| bow_embed(t, 4), &[]);
        assert!(result.is_err());
    }

    #[test]
    fn sts_single_pair_identical_tokens() {
        let pairs = vec![(vec!["cat".to_string()], vec!["cat".to_string()], 5.0f32)];
        let report = sts_evaluate(&|t| bow_embed(t, 4), &pairs).expect("evaluate");
        assert_eq!(report.n_pairs, 1);
        // cosine of identical vectors is 1.0
        assert!((report.predictions[0] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn sts_mse_is_non_negative() {
        let pairs = vec![
            (vec!["a".to_string()], vec!["b".to_string()], 2.5f32),
            (vec!["c".to_string()], vec!["c".to_string()], 4.0f32),
        ];
        let report = sts_evaluate(&|t| bow_embed(t, 4), &pairs).expect("evaluate");
        assert!(report.mse >= 0.0);
    }

    #[test]
    fn sts_zero_vector_gives_zero_cosine() {
        // An empty token list embeds to the zero vector, whose cosine is defined as 0.
        let pairs = vec![(Vec::<String>::new(), vec!["x".to_string()], 1.0f32)];
        let report = sts_evaluate(&|t| bow_embed(t, 4), &pairs).expect("evaluate");
        assert_eq!(report.predictions[0], 0.0);
        assert_eq!(report.gold_labels, vec![1.0f32]);
    }

    #[test]
    fn pearson_perfect_linear_relationship() {
        let r = pearson_correlation(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);
        assert!((r - 1.0).abs() < 1e-5);
        assert_eq!(pearson_correlation(&[], &[]), 0.0);
    }

    #[test]
    fn spearman_averages_tied_ranks() {
        let r = spearman_correlation(&[1.0, 1.0, 2.0], &[3.0, 3.0, 4.0]);
        assert!((r - 1.0).abs() < 1e-5);
    }

    #[test]
    fn load_three_column_tsv_skips_header_and_blank_lines() {
        let path = std::env::temp_dir().join(format!(
            "scirs2_sts_three_col_{}.tsv",
            std::process::id()
        ));
        std::fs::write(&path, "score\ts1\ts2\n4.5\tA man runs\tA person runs\n\n1.0\tcat\tdog\n")
            .expect("write temp file");
        let pairs = load_sts_from_tsv(&path).expect("load");
        let _ = std::fs::remove_file(&path);

        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].0, vec!["A", "man", "runs"]);
        assert_eq!(pairs[0].1, vec!["A", "person", "runs"]);
        assert_eq!(pairs[0].2, 4.5);
        assert_eq!(pairs[1].2, 1.0);
    }

    #[test]
    fn load_missing_file_is_error() {
        assert!(load_sts_from_tsv("/definitely/not/a/real/dir/sts.tsv").is_err());
    }
}