scirs2_text/evaluation/
sts.rs1use crate::error::{Result, TextError};
33use scirs2_core::ndarray::Array1;
34use std::path::Path;
35
/// Identifies which family of semantic-textual-similarity (STS) dataset a
/// file belongs to.
///
/// NOTE(review): [`load_sts_from_tsv`] does not take this as a parameter; it
/// picks a column layout from the number of tab-separated fields instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StsDatasetFormat {
    /// STS Benchmark (STS-B).
    StsB,
    /// SICK (Sentences Involving Compositional Knowledge).
    Sick,
    /// SemEval STS shared-task data, 2012–2016.
    Sts12to16,
}
47
/// Outcome of one STS evaluation run, as returned by [`sts_evaluate`].
///
/// `predictions` and `gold_labels` are index-aligned: entry `i` of each refers
/// to the `i`-th input pair.
#[derive(Clone, Debug)]
pub struct StsReport {
    /// Pearson correlation between `predictions` and `gold_labels`.
    pub pearson: f32,
    /// Spearman rank correlation between `predictions` and `gold_labels`.
    pub spearman: f32,
    /// Mean squared error between `predictions` and `gold_labels`, computed on
    /// the raw values without rescaling either side.
    pub mse: f32,
    /// Predicted similarity for every pair (cosine of the two embeddings).
    pub predictions: Vec<f32>,
    /// Gold similarity score for every pair, copied from the input.
    pub gold_labels: Vec<f32>,
    /// How many sentence pairs were evaluated.
    pub n_pairs: usize,
}
64
/// Tokens of a single sentence, in their original order.
type StsSentenceTokens = Vec<String>;

/// Labelled sentence pairs: `(sentence1_tokens, sentence2_tokens, gold_score)`.
type StsPairs = Vec<(StsSentenceTokens, StsSentenceTokens, f32)>;
66
67pub fn load_sts_from_tsv(path: impl AsRef<Path>) -> Result<StsPairs> {
79 use std::fs::File;
80 use std::io::{BufRead, BufReader};
81
82 let file = File::open(path.as_ref()).map_err(|e| TextError::IoError(e.to_string()))?;
83 let reader = BufReader::new(file);
84 let mut pairs = Vec::new();
85
86 for line in reader.lines() {
87 let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
88 let line = line.trim();
89 if line.is_empty() {
90 continue;
91 }
92 let fields: Vec<&str> = line.split('\t').collect();
93
94 let (score_str, s1, s2) = if fields.len() >= 8 {
96 (fields[4], fields[5], fields[6])
97 } else if fields.len() >= 3 {
98 (fields[0], fields[1], fields[2])
99 } else {
100 continue;
101 };
102
103 let score: f32 = match score_str.trim().parse() {
104 Ok(v) => v,
105 Err(_) => continue, };
107
108 let tokens1: Vec<String> = s1.split_whitespace().map(str::to_owned).collect();
109 let tokens2: Vec<String> = s2.split_whitespace().map(str::to_owned).collect();
110 pairs.push((tokens1, tokens2, score));
111 }
112
113 Ok(pairs)
114}
115
/// Pearson product-moment correlation coefficient of two samples.
///
/// Returns `0.0` when there is no data or when either sample has zero variance,
/// where the coefficient is undefined. If the slices differ in length, only
/// their common prefix (`min(x.len(), y.len())` elements) is used, so means and
/// cross-products always cover the same elements.
///
/// Sums are accumulated in `f64`, for two reasons. It avoids `f32` rounding
/// drift over thousands of pairs. It also makes the mean of a constant sample
/// exactly equal to that constant, so the zero-variance guard fires reliably.
/// The result is clamped to `[-1, 1]` to absorb any residual rounding.
fn pearson_correlation(x: &[f32], y: &[f32]) -> f32 {
    let len = x.len().min(y.len());
    if len == 0 {
        return 0.0;
    }
    let (x, y) = (&x[..len], &y[..len]);
    let n = len as f64;
    let mean_x = x.iter().map(|&a| f64::from(a)).sum::<f64>() / n;
    let mean_y = y.iter().map(|&b| f64::from(b)).sum::<f64>() / n;

    let (mut sxy, mut sxx, mut syy) = (0.0f64, 0.0f64, 0.0f64);
    for (&a, &b) in x.iter().zip(y) {
        let dx = f64::from(a) - mean_x;
        let dy = f64::from(b) - mean_y;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }

    if sxx == 0.0 || syy == 0.0 {
        return 0.0;
    }
    (sxy / (sxx.sqrt() * syy.sqrt())).clamp(-1.0, 1.0) as f32
}
133
134fn spearman_correlation(x: &[f32], y: &[f32]) -> f32 {
136 fn rank(v: &[f32]) -> Vec<f32> {
138 let mut indexed: Vec<(usize, f32)> = v.iter().cloned().enumerate().collect();
139 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
140 let mut ranks = vec![0.0f32; v.len()];
141 let mut i = 0;
142 while i < indexed.len() {
143 let val = indexed[i].1;
145 let mut j = i + 1;
146 while j < indexed.len() && indexed[j].1 == val {
147 j += 1;
148 }
149 let avg_rank = (i + j + 1) as f32 / 2.0; for item in &indexed[i..j] {
152 ranks[item.0] = avg_rank;
153 }
154 i = j;
155 }
156 ranks
157 }
158
159 let rx = rank(x);
160 let ry = rank(y);
161 pearson_correlation(&rx, &ry)
162}
163
164pub fn sts_evaluate(
175 embed_fn: &dyn Fn(&[String]) -> Array1<f32>,
176 pairs: &[(Vec<String>, Vec<String>, f32)],
177) -> Result<StsReport> {
178 if pairs.is_empty() {
179 return Err(TextError::InvalidInput(
180 "STS dataset is empty; at least one pair is required".into(),
181 ));
182 }
183
184 let mut predictions = Vec::with_capacity(pairs.len());
185 let mut gold_labels = Vec::with_capacity(pairs.len());
186
187 for (s1_tokens, s2_tokens, gold) in pairs {
188 let e1 = embed_fn(s1_tokens);
189 let e2 = embed_fn(s2_tokens);
190
191 let dot = e1.dot(&e2);
192 let n1 = e1.dot(&e1).sqrt();
193 let n2 = e2.dot(&e2).sqrt();
194 let cosine = if n1 == 0.0 || n2 == 0.0 {
195 0.0f32
196 } else {
197 dot / (n1 * n2)
198 };
199
200 predictions.push(cosine);
201 gold_labels.push(*gold);
202 }
203
204 let pearson = pearson_correlation(&predictions, &gold_labels);
205 let spearman = spearman_correlation(&predictions, &gold_labels);
206 let mse = predictions
207 .iter()
208 .zip(&gold_labels)
209 .map(|(p, g)| (p - g).powi(2))
210 .sum::<f32>()
211 / predictions.len() as f32;
212
213 Ok(StsReport {
214 pearson,
215 spearman,
216 mse,
217 predictions,
218 gold_labels,
219 n_pairs: pairs.len(),
220 })
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Toy bag-of-words embedding. Each token is hashed into one of `dim`
    /// buckets, and the bucket counts form the vector. Identical token
    /// multisets therefore embed identically, while disjoint vocabularies
    /// almost always do not.
    ///
    /// `DefaultHasher::new()` uses fixed keys, so bucket assignment is
    /// deterministic. Panics if `dim == 0` and `tokens` is non-empty.
    fn bow_embed(tokens: &[String], dim: usize) -> Array1<f32> {
        let mut v = Array1::zeros(dim);
        for tok in tokens {
            let mut hasher = DefaultHasher::new();
            tok.hash(&mut hasher);
            let idx = (hasher.finish() % dim as u64) as usize;
            v[idx] += 1.0;
        }
        v
    }

    /// Converts string literals into an owned token sequence.
    fn words(ws: &[&str]) -> Vec<String> {
        ws.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn sts_empty_returns_error() {
        let result = sts_evaluate(&|t| bow_embed(t, 4), &[]);
        assert!(result.is_err());
    }

    #[test]
    fn sts_single_pair_identical_tokens() {
        let pairs = vec![(vec!["cat".to_string()], vec!["cat".to_string()], 5.0f32)];
        let report = sts_evaluate(&|t| bow_embed(t, 4), &pairs).expect("evaluate");
        assert_eq!(report.n_pairs, 1);
        assert!((report.predictions[0] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn sts_mse_is_non_negative() {
        let pairs = vec![
            (vec!["a".to_string()], vec!["b".to_string()], 2.5f32),
            (vec!["c".to_string()], vec!["c".to_string()], 4.0f32),
        ];
        let report = sts_evaluate(&|t| bow_embed(t, 4), &pairs).expect("evaluate");
        assert!(report.mse >= 0.0);
    }

    #[test]
    fn sts_ranks_identical_pair_above_disjoint_pair() {
        let pairs = vec![
            (
                words(&["the", "cat", "sat", "down"]),
                words(&["the", "cat", "sat", "down"]),
                5.0f32,
            ),
            (
                words(&["red", "green", "blue", "violet"]),
                words(&["one", "two", "three", "four"]),
                0.0f32,
            ),
        ];
        let report = sts_evaluate(&|t| bow_embed(t, 1024), &pairs).expect("evaluate");
        assert!(report.predictions[0] > report.predictions[1]);
        assert!((report.pearson - 1.0).abs() < 1e-5);
        assert!((report.spearman - 1.0).abs() < 1e-5);
    }

    #[test]
    fn correlations_on_known_values() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let y = [1.0f32, 3.0, 2.0, 4.0];
        assert!((pearson_correlation(&x, &y) - 0.8).abs() < 1e-5);
        assert!((spearman_correlation(&x, &y) - 0.8).abs() < 1e-5);
        assert_eq!(pearson_correlation(&[], &[]), 0.0);
    }
}