1use std::cmp::Ordering;
26
27use serde::{Deserialize, Serialize};
28
/// A named sequence of metric samples together with the metric's polarity.
#[derive(Debug, Clone)]
pub struct MetricSeries {
    /// Display name of the metric (e.g. `"vmaf"`).
    pub label: String,
    /// Sample values in order (one per frame when built by [`series_from_frames`]).
    pub values: Vec<f64>,
    /// `true` when larger values mean better quality, `false` for
    /// lower-is-better metrics; [`divergences`] uses this to orient scores.
    pub higher_is_better: bool,
}

impl MetricSeries {
    /// Builds a series from its label, its samples and its polarity flag.
    pub fn new(label: impl Into<String>, values: Vec<f64>, higher_is_better: bool) -> Self {
        let label = label.into();
        Self {
            label,
            values,
            higher_is_better,
        }
    }
}
46
/// Pearson product-moment correlation coefficient between `x` and `y`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero variance (the coefficient is undefined there). If any input is
/// NaN the result is NaN.
///
/// The coefficient is clamped to `[-1.0, 1.0]`. Rounding in the accumulated
/// sums can otherwise push a (near-)perfect correlation a few ULPs outside
/// the mathematically valid range, which breaks consumers such as
/// `atanh(r)` (Fisher's z). `f64::clamp` leaves NaN unchanged.
pub fn pearson(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len();
    if n == 0 || n != y.len() {
        return 0.0;
    }
    let nf = n as f64;
    let mx = x.iter().sum::<f64>() / nf;
    let my = y.iter().sum::<f64>() / nf;

    // Two-pass form: centre on the means first for better numerical stability.
    let (cov, vx, vy) = x.iter().zip(y).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, vx, vy), (&a, &b)| {
            let dx = a - mx;
            let dy = b - my;
            (cov + dx * dy, vx + dx * dx, vy + dy * dy)
        },
    );
    if vx <= 0.0 || vy <= 0.0 {
        return 0.0;
    }
    (cov / (vx.sqrt() * vy.sqrt())).clamp(-1.0, 1.0)
}
73
74pub fn spearman(x: &[f64], y: &[f64]) -> f64 {
76 if x.len() != y.len() || x.is_empty() {
77 return 0.0;
78 }
79 pearson(&ranks(x), &ranks(y))
80}
81
/// Kendall's tau-b (tie-corrected) rank correlation between `x` and `y`.
///
/// Every pair `i < j` is classified, so the cost is O(n²):
/// * tied in `x` (difference `== 0.0`) counts toward `n1`;
/// * tied in `y` counts toward `n2` (a pair tied in both counts toward both);
/// * otherwise it is concordant when both differences have the same sign,
///   else discordant.
///
/// `tau_b = (C - D) / sqrt((n0 - n1) * (n0 - n2))` with `n0 = n(n-1)/2`.
/// Returns `0.0` for fewer than two samples, mismatched lengths, or a zero
/// denominator (e.g. a constant input).
///
/// NOTE(review): a NaN difference is neither a tie nor `> 0.0`, so pairs
/// involving NaN are still counted as concordant or discordant.
pub fn kendall_tau(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len();
    if n != y.len() || n < 2 {
        return 0.0;
    }

    let (mut concordant, mut discordant) = (0i64, 0i64);
    let (mut ties_x, mut ties_y) = (0i64, 0i64);
    for (i, (&xi, &yi)) in x.iter().zip(y).enumerate() {
        for (&xj, &yj) in x[i + 1..].iter().zip(&y[i + 1..]) {
            let dx = xi - xj;
            let dy = yi - yj;
            match (dx == 0.0, dy == 0.0) {
                (true, true) => {
                    ties_x += 1;
                    ties_y += 1;
                }
                (true, false) => ties_x += 1,
                (false, true) => ties_y += 1,
                (false, false) if (dx > 0.0) == (dy > 0.0) => concordant += 1,
                (false, false) => discordant += 1,
            }
        }
    }

    let pairs = (n * (n - 1) / 2) as f64;
    let denom = ((pairs - ties_x as f64) * (pairs - ties_y as f64)).sqrt();
    if denom <= 0.0 {
        0.0
    } else {
        (concordant - discordant) as f64 / denom
    }
}
120
/// Returns the 1-based fractional ("average") rank of every element of `values`.
///
/// Output index `i` holds the rank of `values[i]`. Tied values share the mean
/// of the positions they span, e.g. `[10.0, 20.0, 20.0, 30.0]` gives
/// `[1.0, 2.5, 2.5, 4.0]`.
///
/// The sort uses [`f64::total_cmp`], which is a genuine total order even when
/// NaN is present. A `partial_cmp(..).unwrap_or(Equal)` comparator is not
/// transitive once a NaN appears (NaN "equals" both 1.0 and 2.0 while
/// 1.0 < 2.0), and the standard-library sorts may panic on such a comparator.
///
/// Ties are still detected with `==`. As a result, `-0.0` and `0.0` (adjacent
/// under `total_cmp`) share a rank, while every NaN (never `==` to anything)
/// gets a rank of its own.
fn ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut order: Vec<usize> = (0..n).collect();
    // Stability is irrelevant: tied entries all receive the same averaged rank.
    order.sort_unstable_by(|&a, &b| values[a].total_cmp(&values[b]));

    let mut out = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let value = values[order[start]];
        // Extend `end` over the run of sorted entries equal to `value`.
        let mut end = start;
        while end + 1 < n && values[order[end + 1]] == value {
            end += 1;
        }
        // Sorted positions start..=end are 0-based; their mean, made 1-based.
        let avg = (start + end) as f64 / 2.0 + 1.0;
        for &pos in &order[start..=end] {
            out[pos] = avg;
        }
        start = end + 1;
    }
    out
}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct CorrelationMatrix {
149 pub labels: Vec<String>,
151 pub pearson: Vec<Vec<f64>>,
153 pub spearman: Vec<Vec<f64>>,
155 pub kendall: Vec<Vec<f64>>,
157}
158
159pub fn correlation_matrix(series: &[MetricSeries]) -> CorrelationMatrix {
161 let k = series.len();
162 let labels = series.iter().map(|s| s.label.clone()).collect();
163 let mut pearson_m = vec![vec![0.0; k]; k];
164 let mut spearman_m = vec![vec![0.0; k]; k];
165 let mut kendall_m = vec![vec![0.0; k]; k];
166 for i in 0..k {
167 for j in 0..k {
168 pearson_m[i][j] = pearson(&series[i].values, &series[j].values);
169 spearman_m[i][j] = spearman(&series[i].values, &series[j].values);
170 kendall_m[i][j] = kendall_tau(&series[i].values, &series[j].values);
171 }
172 }
173 CorrelationMatrix { labels, pearson: pearson_m, spearman: spearman_m, kendall: kendall_m }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct Divergence {
179 pub index: usize,
181 pub spread: f64,
183 pub normalized: Vec<f64>,
185}
186
187pub fn divergences(series: &[MetricSeries]) -> Vec<Divergence> {
194 if series.len() < 2 {
195 return Vec::new();
196 }
197 let n = series[0].values.len();
198 if n == 0 || series.iter().any(|s| s.values.len() != n) {
199 return Vec::new();
200 }
201 let normalized: Vec<Vec<f64>> = series.iter().map(minmax_normalized).collect();
202 let mut out = Vec::with_capacity(n);
203 for idx in 0..n {
204 let vals: Vec<f64> = normalized.iter().map(|s| s[idx]).collect();
205 let min = vals.iter().copied().fold(f64::INFINITY, f64::min);
206 let max = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
207 out.push(Divergence { index: idx, spread: max - min, normalized: vals });
208 }
209 out.sort_by(|a, b| b.spread.partial_cmp(&a.spread).unwrap_or(Ordering::Equal));
210 out
211}
212
213fn minmax_normalized(series: &MetricSeries) -> Vec<f64> {
216 let min = series.values.iter().copied().fold(f64::INFINITY, f64::min);
217 let max = series.values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
218 let range = max - min;
219 series
220 .values
221 .iter()
222 .map(|&v| {
223 let q = if range > 0.0 { (v - min) / range } else { 0.5 };
224 if series.higher_is_better { q } else { 1.0 - q }
225 })
226 .collect()
227}
228
229pub fn series_from_frames(frames: &[viser_quality::FrameResult]) -> Vec<MetricSeries> {
235 let mut series = vec![
236 MetricSeries::new("vmaf", frames.iter().map(|f| f.vmaf).collect(), true),
237 MetricSeries::new("psnr", frames.iter().map(|f| f.psnr).collect(), true),
238 MetricSeries::new("ssim", frames.iter().map(|f| f.ssim).collect(), true),
239 ];
240 let mut push_if_present = |label: &str, vals: Vec<f64>, higher: bool| {
242 if vals.iter().any(|v| *v != 0.0) {
243 series.push(MetricSeries::new(label, vals, higher));
244 }
245 };
246 push_if_present("ms_ssim", frames.iter().map(|f| f.ms_ssim).collect(), true);
247 push_if_present("vif", frames.iter().map(|f| f.vif).collect(), true);
248 push_if_present("cambi", frames.iter().map(|f| f.cambi).collect(), false);
249 push_if_present("xpsnr", frames.iter().map(|f| f.xpsnr).collect(), true);
250 push_if_present("ssimulacra2", frames.iter().map(|f| f.ssimulacra2).collect(), true);
251 push_if_present("butteraugli", frames.iter().map(|f| f.butteraugli).collect(), false);
252 series
253}
254
255impl CorrelationMatrix {
256 pub fn to_markdown(&self) -> String {
258 let mut out = String::from("| metric |");
259 for label in &self.labels {
260 out.push_str(&format!(" {label} |"));
261 }
262 out.push_str("\n|---|");
263 for _ in &self.labels {
264 out.push_str("---|");
265 }
266 out.push('\n');
267 for (i, label) in self.labels.iter().enumerate() {
268 out.push_str(&format!("| {label} |"));
269 for j in 0..self.labels.len() {
270 out.push_str(&format!(" {:.3} |", self.spearman[i][j]));
271 }
272 out.push('\n');
273 }
274 out
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    #[test]
    fn pearson_perfect_positive_and_negative() {
        let base = [1.0, 2.0, 3.0, 4.0];
        let rising = [2.0, 4.0, 6.0, 8.0];
        let falling = [8.0, 6.0, 4.0, 2.0];
        assert!(approx(pearson(&base, &rising), 1.0));
        assert!(approx(pearson(&base, &falling), -1.0));
    }

    #[test]
    fn pearson_guards() {
        assert_eq!(pearson(&[], &[]), 0.0);
        assert_eq!(pearson(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(pearson(&[5.0, 5.0, 5.0], &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn spearman_monotonic_nonlinear() {
        let linear = [1.0, 2.0, 3.0, 4.0, 5.0];
        let squares = [1.0, 4.0, 9.0, 16.0, 25.0];
        assert!(approx(spearman(&linear, &squares), 1.0));
        assert!(pearson(&linear, &squares) < 1.0);
    }

    #[test]
    fn spearman_handles_ties() {
        let a = [1.0, 2.0, 2.0, 3.0];
        let b = [10.0, 20.0, 20.0, 30.0];
        assert!(approx(spearman(&a, &b), 1.0));
    }

    #[test]
    fn kendall_perfect_and_reversed() {
        let base = [1.0, 2.0, 3.0, 4.0];
        let same = [1.0, 2.0, 3.0, 4.0];
        let reversed = [4.0, 3.0, 2.0, 1.0];
        assert!(approx(kendall_tau(&base, &same), 1.0));
        assert!(approx(kendall_tau(&base, &reversed), -1.0));
    }

    #[test]
    fn correlation_matrix_diagonal_is_one() {
        let input = vec![
            MetricSeries::new("a", vec![1.0, 2.0, 3.0], true),
            MetricSeries::new("b", vec![3.0, 1.0, 2.0], true),
        ];
        let matrix = correlation_matrix(&input);
        assert!(approx(matrix.pearson[0][0], 1.0));
        assert!(approx(matrix.spearman[1][1], 1.0));
        assert!(approx(matrix.spearman[0][1], matrix.spearman[1][0]));
    }

    #[test]
    fn divergence_flags_disagreement() {
        let input = vec![
            MetricSeries::new("a", vec![0.0, 100.0, 50.0], true),
            MetricSeries::new("b", vec![0.0, 0.0, 50.0], true),
        ];
        let found = divergences(&input);
        assert_eq!(found.len(), 3);
        assert_eq!(found[0].index, 1);
        assert!(approx(found[0].spread, 1.0));
    }

    #[test]
    fn divergence_respects_polarity() {
        // Opposite polarities describing the same quality trend must agree.
        let input = vec![
            MetricSeries::new("vmaf", vec![100.0, 50.0, 0.0], true),
            MetricSeries::new("butteraugli", vec![0.0, 1.0, 2.0], false),
        ];
        let found = divergences(&input);
        assert!(found.iter().all(|d| d.spread < 1e-9));
    }

    #[test]
    fn divergence_guards() {
        assert!(divergences(&[]).is_empty());
        assert!(divergences(&[MetricSeries::new("a", vec![1.0], true)]).is_empty());
        let misaligned = vec![
            MetricSeries::new("a", vec![1.0, 2.0], true),
            MetricSeries::new("b", vec![1.0], true),
        ];
        assert!(divergences(&misaligned).is_empty());
    }

    #[test]
    fn series_from_frames_skips_empty_metrics() {
        use viser_quality::FrameResult;
        let frames = vec![
            FrameResult { frame_num: 0, vmaf: 80.0, psnr: 37.0, ssim: 0.9, ..Default::default() },
            FrameResult { frame_num: 1, vmaf: 90.0, psnr: 40.0, ssim: 0.95, ..Default::default() },
        ];
        let built = series_from_frames(&frames);
        assert_eq!(built.len(), 3);
        assert_eq!(built[0].label, "vmaf");
    }

    #[test]
    fn markdown_render() {
        let input = vec![
            MetricSeries::new("vmaf", vec![1.0, 2.0, 3.0], true),
            MetricSeries::new("psnr", vec![1.0, 2.0, 3.0], true),
        ];
        let table = correlation_matrix(&input).to_markdown();
        assert!(table.contains("| vmaf |"));
        assert!(table.contains("1.000"));
    }
}