Skip to main content

viser_quality/
lib.rs

1//! Video quality measurement for the `viser` video-encoding-optimizer workspace.
2//!
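//! Computes full-reference quality scores (VMAF, PSNR, SSIM, MS-SSIM, VIF, CAMBI,
//! XPSNR, SSIMULACRA2, and butteraugli) between a reference and a distorted video.
//! VMAF, MS-SSIM, VIF, and CAMBI come from FFmpeg's libvmaf filter. PSNR and SSIM use
//! FFmpeg's native `psnr`/`ssim` filters when they are the only metrics requested, and
//! libvmaf otherwise. XPSNR uses FFmpeg's `xpsnr` filter, while SSIMULACRA2 and
//! butteraugli shell out to their CLI tools on extracted PNG frames. See `measure` for
//! the entry point.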
7
8use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11use tokio::process::Command;
12use tracing::warn;
13use viser_ffmpeg::{ProbeCache, ffmpeg_path};
14
15pub mod noref;
16pub mod pool;
17pub use noref::{NoRefOpts, NoRefResult, measure_noref};
18pub use pool::{PoolStrategy, PooledStats};
19
20/// Quality metric type.
21#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum Metric {
24    /// Netflix VMAF perceptual score (0-100, higher is better).
25    #[default]
26    Vmaf,
27    /// Peak signal-to-noise ratio in dB (higher is better).
28    Psnr,
29    /// Structural similarity index (0-1, higher is better).
30    Ssim,
31    /// SSIMULACRA2 perceptual score (higher is better), via the `ssimulacra2` CLI.
32    Ssimulacra2,
33    /// Butteraugli perceptual distance (lower is better), via the `butteraugli` CLI.
34    Butteraugli,
35    /// Multi-scale SSIM (0-1, higher is better), via libvmaf's `float_ms_ssim`.
36    MsSsim,
37    /// Visual information fidelity (higher is better), the mean of libvmaf's VIF scales.
38    Vif,
39    /// CAMBI banding score (lower is better), via libvmaf's `cambi` feature.
40    Cambi,
41    /// Perceptually-weighted PSNR in dB (higher is better), via FFmpeg's `xpsnr` filter.
42    Xpsnr,
43}
44
45/// Aggregate (pooled) quality scores, with optional per-frame breakdown.
46///
47/// Each score is `0.0` when its metric was not requested or is unavailable.
48#[derive(Debug, Clone, Default, Serialize, Deserialize)]
49#[serde(default)]
50pub struct Result {
51    /// Mean VMAF score.
52    pub vmaf: f64,
53    /// Mean luma (Y) PSNR (dB).
54    pub psnr: f64,
55    /// Mean Cb/U-plane PSNR (dB); `0.0` when per-component PSNR is unavailable.
56    pub psnr_u: f64,
57    /// Mean Cr/V-plane PSNR (dB); `0.0` when per-component PSNR is unavailable.
58    pub psnr_v: f64,
59    /// Weighted PSNR `(6·Y + U + V) / 8` (dB); falls back to luma when chroma is absent.
60    pub psnr_avg: f64,
61    /// Mean SSIM.
62    pub ssim: f64,
63    /// SSIMULACRA2 score (mean over sampled frames).
64    pub ssimulacra2: f64,
65    /// Butteraugli distance (mean over sampled frames).
66    pub butteraugli: f64,
67    /// Mean multi-scale SSIM; `0.0` when not requested.
68    pub ms_ssim: f64,
69    /// Mean VIF (visual information fidelity); computed alongside VMAF.
70    pub vif: f64,
71    /// Mean CAMBI banding score (lower is better); `0.0` when not requested.
72    pub cambi: f64,
73    /// Mean weighted XPSNR `(6·Y + U + V) / 8` (dB); `0.0` when not requested.
74    pub xpsnr: f64,
75    /// Distribution statistics (mean, harmonic mean, percentiles, …) per metric.
76    pub pooled: Pooled,
77    /// Per-frame scores; populated only when `MeasureOpts::per_frame` is set.
78    #[serde(skip_serializing_if = "Vec::is_empty")]
79    pub frames: Vec<FrameResult>,
80}
81
82/// Pooled distribution statistics for each metric, computed from per-frame scores.
83#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
84#[serde(default)]
85pub struct Pooled {
86    /// VMAF distribution.
87    pub vmaf: PooledStats,
88    /// Luma (Y) PSNR distribution.
89    pub psnr: PooledStats,
90    /// SSIM distribution.
91    pub ssim: PooledStats,
92    /// SSIMULACRA2 distribution (populated when more than one frame is sampled).
93    pub ssimulacra2: PooledStats,
94    /// Butteraugli distribution (populated when more than one frame is sampled).
95    pub butteraugli: PooledStats,
96    /// Multi-scale SSIM distribution.
97    pub ms_ssim: PooledStats,
98    /// VIF distribution.
99    pub vif: PooledStats,
100    /// CAMBI banding distribution (lower is better).
101    pub cambi: PooledStats,
102    /// Weighted XPSNR distribution (dB).
103    pub xpsnr: PooledStats,
104}
105
106/// Quality scores for a single frame.
107#[derive(Debug, Clone, Default, Serialize, Deserialize)]
108pub struct FrameResult {
109    /// Frame index within the video.
110    pub frame_num: i32,
111    /// VMAF score for this frame.
112    pub vmaf: f64,
113    /// Luma (Y) PSNR (dB) for this frame.
114    pub psnr: f64,
115    /// Cb/U-plane PSNR (dB) for this frame.
116    #[serde(default)]
117    pub psnr_u: f64,
118    /// Cr/V-plane PSNR (dB) for this frame.
119    #[serde(default)]
120    pub psnr_v: f64,
121    /// SSIM for this frame.
122    pub ssim: f64,
123    /// SSIMULACRA2 score for this frame.
124    pub ssimulacra2: f64,
125    /// Butteraugli distance for this frame.
126    pub butteraugli: f64,
127    /// Multi-scale SSIM for this frame.
128    #[serde(default)]
129    pub ms_ssim: f64,
130    /// VIF for this frame.
131    #[serde(default)]
132    pub vif: f64,
133    /// CAMBI banding score for this frame (lower is better).
134    #[serde(default)]
135    pub cambi: f64,
136    /// Weighted XPSNR (dB) for this frame.
137    #[serde(default)]
138    pub xpsnr: f64,
139}
140
141/// Options controlling a `measure` call.
142#[derive(Debug, Clone)]
143pub struct MeasureOpts {
144    /// Metrics to compute; an empty list defaults to VMAF, PSNR, and SSIM.
145    pub metrics: Vec<Metric>,
146    /// Subsample factor for libvmaf (every Nth frame); `0` means no subsampling.
147    pub subsample: i32,
148    /// VMAF model version name (e.g. `"vmaf_v0.6.1"`).
149    pub model: String,
150    /// When `true`, also collect per-frame scores into `Result::frames`.
151    pub per_frame: bool,
152    /// How many frames to measure for SSIMULACRA2/butteraugli. `0` (the default)
153    /// measures the whole clip; `1` a single frame (frame 0, fastest); higher
154    /// values measure that many evenly-spaced frames. Results pool into
155    /// `Result::pooled`.
156    pub frame_samples: usize,
157    /// Optional probe cache reused across measurements to avoid redundant probes.
158    pub probe_cache: Option<ProbeCache>,
159}
160
161impl Default for MeasureOpts {
162    fn default() -> Self {
163        Self {
164            metrics: vec![
165                Metric::Vmaf,
166                Metric::Psnr,
167                Metric::Ssim,
168                Metric::Ssimulacra2,
169                Metric::Butteraugli,
170            ],
171            subsample: 0,
172            model: "vmaf_v0.6.1".into(),
173            per_frame: false,
174            frame_samples: 0,
175            probe_cache: None,
176        }
177    }
178}
179
180/// Computes quality metrics between a reference and distorted video.
181pub async fn measure(
182    reference: &str,
183    distorted: &str,
184    opts: MeasureOpts,
185) -> anyhow::Result<Result> {
186    let model_name = if opts.model.is_empty() { "vmaf_v0.6.1" } else { &opts.model };
187    let metrics = if opts.metrics.is_empty() {
188        vec![Metric::Vmaf, Metric::Psnr, Metric::Ssim]
189    } else {
190        opts.metrics.clone()
191    };
192
193    // Fast path: when every requested metric is PSNR and/or SSIM, measure with
194    // FFmpeg's native `psnr`/`ssim` filters instead of libvmaf. libvmaf always
195    // runs the expensive VMAF feature extraction (ADM/VIF/motion) regardless of
196    // which `feature=`s ride along, so the native filters are ~10-20x cheaper.
197    if metrics.iter().all(|m| matches!(m, Metric::Psnr | Metric::Ssim)) {
198        return measure_native(reference, distorted, &metrics, &opts).await;
199    }
200
201    let tmp = tempfile::Builder::new().prefix("viser-vmaf-").suffix(".json").tempfile()?;
202    let log_path = tmp.path().to_string_lossy().to_string();
203
204    // Build libvmaf filter string
205    let mut vmaf_opts = format!("log_fmt=json:log_path={log_path}:model=version={model_name}");
206
207    // libvmaf accepts the `feature` option only once; repeating `:feature=...`
208    // makes later entries silently override earlier ones (dropping metrics).
209    // Collect all requested features into a single `|`-separated option.
210    let mut features: Vec<&str> = Vec::new();
211    for m in &metrics {
212        match m {
213            Metric::Psnr => features.push("name=psnr"),
214            Metric::Ssim => features.push("name=float_ssim"),
215            Metric::MsSsim => features.push("name=float_ms_ssim"),
216            Metric::Cambi => features.push("name=cambi"),
217            // VIF rides along with VMAF (vif_scale features are always emitted).
218            Metric::Vmaf | Metric::Vif => {}
219            // Measured outside libvmaf.
220            Metric::Xpsnr | Metric::Ssimulacra2 | Metric::Butteraugli => {}
221        }
222    }
223    if !features.is_empty() {
224        vmaf_opts.push_str(&format!(":feature={}", features.join("|")));
225    }
226
227    if opts.subsample > 0 {
228        vmaf_opts.push_str(&format!(":n_subsample={}", opts.subsample));
229    }
230
231    // Probe reference to get resolution for scaling
232    let ref_info = if let Some(ref cache) = opts.probe_cache {
233        cache.probe(reference).await?
234    } else {
235        viser_ffmpeg::probe(reference).await?
236    };
237
238    let ref_video =
239        ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
240
241    if ref_video.bits_per_raw_sample > 8 {
242        warn!(
243            bits_per_sample = ref_video.bits_per_raw_sample,
244            reference = reference,
245            "10-bit content detected; VMAF scores calibrated for 8-bit may differ"
246        );
247    }
248
249    let filtergraph = format!(
250        "[0:v]scale={}:{}:flags=bicubic[dist];[dist][1:v]libvmaf={}",
251        ref_video.width, ref_video.height, vmaf_opts
252    );
253
254    let args = ["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"];
255
256    let output = Command::new(ffmpeg_path())
257        .args(args)
258        .stderr(std::process::Stdio::piped())
259        .output()
260        .await?;
261
262    if !output.status.success() {
263        let stderr = String::from_utf8_lossy(&output.stderr);
264        anyhow::bail!("ffmpeg quality measurement failed: {stderr}");
265    }
266
267    let data = std::fs::read(&log_path)?;
268    let mut result = parse_vmaf_log(&data, opts.per_frame)?;
269
270    // SSIMULACRA2: run CLI on extracted PNG frames (one frame, or full-clip sample).
271    if metrics.contains(&Metric::Ssimulacra2) {
272        let scores = measure_ssimulacra2(reference, distorted, &opts).await?;
273        result.ssimulacra2 = pool::PoolStrategy::Mean.apply(&scores);
274        result.pooled.ssimulacra2 = PooledStats::from_values(&scores);
275    }
276
277    // Butteraugli: run CLI on extracted PNG frames (one frame, or full-clip sample).
278    if metrics.contains(&Metric::Butteraugli) {
279        let scores = measure_butteraugli(reference, distorted, &opts).await?;
280        result.butteraugli = pool::PoolStrategy::Mean.apply(&scores);
281        result.pooled.butteraugli = PooledStats::from_values(&scores);
282    }
283
284    // XPSNR: a separate FFmpeg pass with the `xpsnr` filter (full clip).
285    if metrics.contains(&Metric::Xpsnr) {
286        let scores = measure_xpsnr(reference, distorted, &opts).await?;
287        result.xpsnr = pool::PoolStrategy::Mean.apply(&scores);
288        result.pooled.xpsnr = PooledStats::from_values(&scores);
289        if opts.per_frame && scores.len() == result.frames.len() {
290            for (fr, s) in result.frames.iter_mut().zip(scores) {
291                fr.xpsnr = s;
292            }
293        }
294    }
295
296    Ok(result)
297}
298
299/// Measures PSNR and/or SSIM using FFmpeg's native filters, bypassing libvmaf.
300///
301/// Far cheaper than the libvmaf path because it skips VMAF feature extraction.
302/// `metrics` must be a non-empty subset of `{Psnr, Ssim}`; each metric runs in its
303/// own FFmpeg pass. Honors `opts.subsample` by decimating frames symmetrically with
304/// a `select` filter on both inputs before measuring. Populates only the requested
305/// scalar fields of `Result` (per-frame and pooled stats are left at their defaults).
306async fn measure_native(
307    reference: &str,
308    distorted: &str,
309    metrics: &[Metric],
310    opts: &MeasureOpts,
311) -> anyhow::Result<Result> {
312    let ref_info = if let Some(ref cache) = opts.probe_cache {
313        cache.probe(reference).await?
314    } else {
315        viser_ffmpeg::probe(reference).await?
316    };
317    let ref_video =
318        ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
319
320    // When subsampling, decimate both inputs identically so the compared frames
321    // stay aligned; otherwise pass both through unchanged.
322    let sel = if opts.subsample > 1 {
323        format!("select=not(mod(n\\,{}))", opts.subsample)
324    } else {
325        "null".to_string()
326    };
327
328    let mut result = Result::default();
329    for m in metrics {
330        let filter_name = match m {
331            Metric::Psnr => "psnr",
332            Metric::Ssim => "ssim",
333            _ => continue,
334        };
335
336        let filtergraph = format!(
337            "[0:v]scale={}:{}:flags=bicubic,{sel}[dist];[1:v]{sel}[ref];[dist][ref]{filter_name}",
338            ref_video.width, ref_video.height
339        );
340        let args = ["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"];
341
342        let output = Command::new(ffmpeg_path())
343            .args(args)
344            .stderr(std::process::Stdio::piped())
345            .output()
346            .await?;
347        if !output.status.success() {
348            let stderr = String::from_utf8_lossy(&output.stderr);
349            anyhow::bail!("ffmpeg {filter_name} measurement failed: {stderr}");
350        }
351
352        let stderr = String::from_utf8_lossy(&output.stderr);
353        match m {
354            Metric::Psnr => {
355                let line = stderr
356                    .lines()
357                    .rev()
358                    .find(|l| l.contains("PSNR") && l.contains("average:"))
359                    .ok_or_else(|| anyhow::anyhow!("could not parse PSNR from ffmpeg output"))?;
360                result.psnr = parse_metric_kv(line, "y:").unwrap_or(0.0);
361                result.psnr_u = parse_metric_kv(line, "u:").unwrap_or(0.0);
362                result.psnr_v = parse_metric_kv(line, "v:").unwrap_or(0.0);
363                result.psnr_avg = parse_metric_kv(line, "average:").unwrap_or(result.psnr);
364            }
365            Metric::Ssim => {
366                let line = stderr
367                    .lines()
368                    .rev()
369                    .find(|l| l.contains("SSIM") && l.contains("All:"))
370                    .ok_or_else(|| anyhow::anyhow!("could not parse SSIM from ffmpeg output"))?;
371                result.ssim = parse_metric_kv(line, "All:").unwrap_or(0.0);
372            }
373            _ => {}
374        }
375    }
376
377    Ok(result)
378}
379
/// Reads the number that immediately follows the first occurrence of `key` in an
/// FFmpeg filter summary line.
///
/// Example: `parse_metric_kv("PSNR y:40.12 average:41.5", "y:") == Some(40.12)`.
/// Only digits, `.`, signs, and exponent markers are consumed, so trailing text such
/// as the `(19.0)` after an SSIM value is ignored. Yields `None` when `key` is absent
/// or the value is not a plain number (including `inf`/`nan`).
fn parse_metric_kv(line: &str, key: &str) -> Option<f64> {
    let (_, tail) = line.split_once(key)?;
    let numeric_len = tail
        .char_indices()
        .find(|&(_, ch)| !(ch.is_ascii_digit() || matches!(ch, '.' | '-' | '+' | 'e' | 'E')))
        .map_or(tail.len(), |(idx, _)| idx);
    tail[..numeric_len].parse::<f64>().ok()
}
392
393// libvmaf JSON output structures
394#[derive(Deserialize)]
395struct VmafLog {
396    frames: Vec<VmafFrame>,
397    #[serde(default)]
398    pooled_metrics: std::collections::HashMap<String, PooledMetric>,
399}
400
401#[derive(Deserialize)]
402struct VmafFrame {
403    #[serde(rename = "frameNum")]
404    frame_num: i32,
405    metrics: std::collections::HashMap<String, f64>,
406}
407
408#[derive(Deserialize)]
409struct PooledMetric {
410    mean: f64,
411}
412
413fn parse_vmaf_log(data: &[u8], per_frame: bool) -> anyhow::Result<Result> {
414    let log: VmafLog = serde_json::from_slice(data)?;
415
416    let mut result = Result::default();
417
418    // Scalar (pooled-mean) values, with naming fallbacks across libvmaf versions.
419    result.vmaf = pooled_mean(&log, &["vmaf"]);
420    result.psnr = pooled_mean(&log, &["psnr_y", "psnr"]);
421    result.psnr_u = pooled_mean(&log, &["psnr_cb", "psnr_u"]);
422    result.psnr_v = pooled_mean(&log, &["psnr_cr", "psnr_v"]);
423    result.psnr_avg = if result.psnr_u > 0.0 && result.psnr_v > 0.0 {
424        // Standard 4:2:0 luma-weighted PSNR. Requires both chroma planes;
425        // with only one present the (6Y+U+V)/8 weighting would divide by a
426        // spurious zero term and under-report, so fall back to luma.
427        (6.0 * result.psnr + result.psnr_u + result.psnr_v) / 8.0
428    } else {
429        result.psnr
430    };
431    result.ssim = pooled_mean(&log, &["float_ssim", "ssim"]);
432
433    // Per-frame series for distribution pooling (computed regardless of `per_frame`).
434    let mut vmaf_series = Vec::with_capacity(log.frames.len());
435    let mut psnr_series = Vec::with_capacity(log.frames.len());
436    let mut ssim_series = Vec::with_capacity(log.frames.len());
437    let mut ms_ssim_series = Vec::with_capacity(log.frames.len());
438    let mut vif_series = Vec::with_capacity(log.frames.len());
439    let mut cambi_series = Vec::with_capacity(log.frames.len());
440    for f in &log.frames {
441        if let Some(v) = f.metrics.get("vmaf") {
442            vmaf_series.push(*v);
443        }
444        if let Some(v) = frame_metric(&f.metrics, &["psnr_y", "psnr"]) {
445            psnr_series.push(v);
446        }
447        if let Some(v) = frame_metric(&f.metrics, &["float_ssim", "ssim"]) {
448            ssim_series.push(v);
449        }
450        if let Some(v) = frame_metric(&f.metrics, &["float_ms_ssim", "ms_ssim"]) {
451            ms_ssim_series.push(v);
452        }
453        if let Some(v) = vif_mean(&f.metrics) {
454            vif_series.push(v);
455        }
456        if let Some(v) = f.metrics.get("cambi") {
457            cambi_series.push(*v);
458        }
459    }
460    result.pooled.vmaf = PooledStats::from_values(&vmaf_series);
461    result.pooled.psnr = PooledStats::from_values(&psnr_series);
462    result.pooled.ssim = PooledStats::from_values(&ssim_series);
463    result.pooled.ms_ssim = PooledStats::from_values(&ms_ssim_series);
464    result.pooled.vif = PooledStats::from_values(&vif_series);
465    result.pooled.cambi = PooledStats::from_values(&cambi_series);
466    result.ms_ssim = result.pooled.ms_ssim.mean;
467    result.vif = result.pooled.vif.mean;
468    result.cambi = result.pooled.cambi.mean;
469
470    // When libvmaf omits pooled_metrics but emits per-frame data, fall back to the mean.
471    if result.vmaf == 0.0 {
472        result.vmaf = result.pooled.vmaf.mean;
473    }
474    if result.psnr == 0.0 {
475        result.psnr = result.pooled.psnr.mean;
476        if result.psnr_avg == 0.0 {
477            result.psnr_avg = result.psnr;
478        }
479    }
480    if result.ssim == 0.0 {
481        result.ssim = result.pooled.ssim.mean;
482    }
483
484    if per_frame {
485        for f in &log.frames {
486            result.frames.push(FrameResult {
487                frame_num: f.frame_num,
488                vmaf: f.metrics.get("vmaf").copied().unwrap_or(0.0),
489                psnr: frame_metric(&f.metrics, &["psnr_y", "psnr"]).unwrap_or(0.0),
490                psnr_u: frame_metric(&f.metrics, &["psnr_cb", "psnr_u"]).unwrap_or(0.0),
491                psnr_v: frame_metric(&f.metrics, &["psnr_cr", "psnr_v"]).unwrap_or(0.0),
492                ssim: frame_metric(&f.metrics, &["float_ssim", "ssim"]).unwrap_or(0.0),
493                ssimulacra2: f.metrics.get("ssimulacra2").copied().unwrap_or(0.0),
494                butteraugli: f.metrics.get("butteraugli").copied().unwrap_or(0.0),
495                ms_ssim: frame_metric(&f.metrics, &["float_ms_ssim", "ms_ssim"]).unwrap_or(0.0),
496                vif: vif_mean(&f.metrics).unwrap_or(0.0),
497                cambi: f.metrics.get("cambi").copied().unwrap_or(0.0),
498                xpsnr: 0.0,
499            });
500        }
501    }
502
503    Ok(result)
504}
505
506/// First matching pooled-metric mean across naming variants, or `0.0`.
507fn pooled_mean(log: &VmafLog, keys: &[&str]) -> f64 {
508    for k in keys {
509        if let Some(m) = log.pooled_metrics.get(*k) {
510            return m.mean;
511        }
512    }
513    0.0
514}
515
/// Looks up a per-frame metric under several possible names and returns the value
/// for the first name present, or `None` when no variant exists.
fn frame_metric(metrics: &std::collections::HashMap<String, f64>, keys: &[&str]) -> Option<f64> {
    keys.iter().find_map(|&key| metrics.get(key).copied())
}
525
/// Averages libvmaf's VIF scales 0..=3 for one frame.
///
/// Each scale is looked up as `integer_vif_scaleN`, then `float_vif_scaleN`, then
/// `vif_scaleN`, and the first name found wins. Only the scales actually present are
/// averaged; `None` means no scale was found at all.
fn vif_mean(metrics: &std::collections::HashMap<String, f64>) -> Option<f64> {
    const NAME_PREFIXES: [&str; 3] = ["integer_vif_scale", "float_vif_scale", "vif_scale"];
    let per_scale: Vec<f64> = (0..4)
        .filter_map(|scale| {
            NAME_PREFIXES
                .iter()
                .find_map(|prefix| metrics.get(format!("{prefix}{scale}").as_str()).copied())
        })
        .collect();
    if per_scale.is_empty() {
        return None;
    }
    let total = per_scale.iter().fold(0.0, |acc, v| acc + v);
    Some(total / per_scale.len() as f64)
}
546
/// Picks up to `samples` frame indices spread evenly over `0..nb_frames`, always
/// including the first and last frame.
///
/// Returns `[0]` when `samples < 2` or the clip has fewer than two frames. The sample
/// count is clamped to `nb_frames`, which keeps the rounded indices distinct. Full-clip
/// measurement (`frame_samples == 0`) does not use this: the caller extracts every
/// frame in one pass instead.
fn sample_indices(nb_frames: i32, samples: usize) -> Vec<i32> {
    if samples < 2 || nb_frames < 2 {
        return vec![0];
    }
    let count = samples.min(nb_frames as usize);
    let last_index = f64::from(nb_frames - 1);
    let gaps = (count - 1) as f64;
    (0..count).map(|k| (k as f64 * last_index / gaps).round() as i32).collect()
}
563
564/// Resolve the reference video stream's dimensions and frame count.
565async fn reference_dims(reference: &str, opts: &MeasureOpts) -> anyhow::Result<(i32, i32, i32)> {
566    let ref_info = if let Some(ref cache) = opts.probe_cache {
567        cache.probe(reference).await?
568    } else {
569        viser_ffmpeg::probe(reference).await?
570    };
571    let ref_video =
572        ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
573    Ok((ref_video.width, ref_video.height, ref_video.nb_frames))
574}
575
576/// Extract frames from `input` as PNGs into `dir` in a single decode pass.
577///
578/// `selection == None` extracts every frame (full clip); otherwise just the
579/// given indices. Frames are written as zero-padded sequential PNGs and returned
580/// in extraction (ascending-index) order. One pass per video avoids the
581/// quadratic cost of re-decoding from the start for each frame.
582async fn extract_frames_png(
583    input: &str,
584    selection: Option<&[i32]>,
585    width: i32,
586    height: i32,
587    dir: &Path,
588) -> anyhow::Result<Vec<PathBuf>> {
589    let scale = format!("scale={width}:{height}:flags=bicubic");
590    let vf = match selection {
591        None => scale,
592        Some(indices) => {
593            let sel = indices.iter().map(|i| format!("eq(n\\,{i})")).collect::<Vec<_>>().join("+");
594            format!("select='{sel}',{scale}")
595        }
596    };
597    let pattern = dir.join("%06d.png");
598    let output = Command::new(ffmpeg_path())
599        .args(["-i", input, "-vf", &vf, "-fps_mode", "passthrough", "-c:v", "png"])
600        .arg(&pattern)
601        .stderr(std::process::Stdio::piped())
602        .output()
603        .await?;
604
605    if !output.status.success() {
606        let stderr = String::from_utf8_lossy(&output.stderr);
607        anyhow::bail!("failed to extract frames from {input}: {stderr}");
608    }
609
610    let mut paths: Vec<PathBuf> = std::fs::read_dir(dir)?
611        .filter_map(|e| e.ok().map(|e| e.path()))
612        .filter(|p| p.extension().is_some_and(|x| x == "png"))
613        .collect();
614    paths.sort();
615    Ok(paths)
616}
617
618/// Aligned reference/distorted PNG frame pairs for the perceptual metrics, kept
619/// alive by their temp dirs. `frame_samples == 0` measures the whole clip;
620/// otherwise the evenly-spaced [`sample_indices`].
621struct FramePairs {
622    _ref_dir: tempfile::TempDir,
623    _dist_dir: tempfile::TempDir,
624    pairs: Vec<(PathBuf, PathBuf)>,
625}
626
627async fn extract_frame_pairs(
628    reference: &str,
629    distorted: &str,
630    opts: &MeasureOpts,
631) -> anyhow::Result<FramePairs> {
632    let (width, height, nb_frames) = reference_dims(reference, opts).await?;
633    let (_, _, dist_nb_frames) = reference_dims(distorted, opts).await?;
634    if dist_nb_frames != nb_frames {
635        warn!(
636            reference_frames = nb_frames,
637            distorted_frames = dist_nb_frames,
638            "reference and distorted frame counts differ; sampled perceptual metrics may be misaligned"
639        );
640    }
641
642    let selection: Option<Vec<i32>> = if opts.frame_samples == 0 {
643        None
644    } else {
645        Some(sample_indices(nb_frames, opts.frame_samples))
646    };
647    let sel = selection.as_deref();
648
649    let ref_dir = tempfile::Builder::new().prefix("viser-q-ref-").tempdir()?;
650    let dist_dir = tempfile::Builder::new().prefix("viser-q-dist-").tempdir()?;
651    let ref_paths = extract_frames_png(reference, sel, width, height, ref_dir.path()).await?;
652    let dist_paths = extract_frames_png(distorted, sel, width, height, dist_dir.path()).await?;
653
654    let n = ref_paths.len().min(dist_paths.len());
655    let pairs =
656        ref_paths.into_iter().take(n).zip(dist_paths.into_iter().take(n)).collect::<Vec<_>>();
657    Ok(FramePairs { _ref_dir: ref_dir, _dist_dir: dist_dir, pairs })
658}
659
660/// Run the `ssimulacra2` CLI over the measured frames; one score per frame.
661async fn measure_ssimulacra2(
662    reference: &str,
663    distorted: &str,
664    opts: &MeasureOpts,
665) -> anyhow::Result<Vec<f64>> {
666    let frames = extract_frame_pairs(reference, distorted, opts).await?;
667    let mut scores = Vec::with_capacity(frames.pairs.len());
668    for (ref_png, dist_png) in &frames.pairs {
669        let s2_output = Command::new("ssimulacra2")
670            .arg(ref_png)
671            .arg(dist_png)
672            .stdout(std::process::Stdio::piped())
673            .stderr(std::process::Stdio::null())
674            .output()
675            .await?;
676
677        if !s2_output.status.success() {
678            anyhow::bail!("ssimulacra2 failed: {}", String::from_utf8_lossy(&s2_output.stderr));
679        }
680
681        let stdout_str = String::from_utf8_lossy(&s2_output.stdout);
682        let score: f64 = stdout_str
683            .trim()
684            .parse()
685            .map_err(|_| anyhow::anyhow!("ssimulacra2: could not parse score: {stdout_str}"))?;
686        scores.push(score);
687    }
688
689    Ok(scores)
690}
691
692/// Run the `butteraugli` CLI over the measured frames; one score per frame.
693///
694/// Butteraugli may be absent or silent on success; missing or unparseable output
695/// yields a `0.0` sentinel for that frame rather than failing the measurement.
696async fn measure_butteraugli(
697    reference: &str,
698    distorted: &str,
699    opts: &MeasureOpts,
700) -> anyhow::Result<Vec<f64>> {
701    let frames = extract_frame_pairs(reference, distorted, opts).await?;
702    let mut scores = Vec::with_capacity(frames.pairs.len());
703    for (i, (ref_png, dist_png)) in frames.pairs.iter().enumerate() {
704        let ba_output = Command::new("butteraugli")
705            .arg(ref_png)
706            .arg(dist_png)
707            .stdout(std::process::Stdio::piped())
708            .stderr(std::process::Stdio::null())
709            .output()
710            .await;
711
712        let mut score = 0.0;
713        let mut parsed = false;
714        if let Ok(out) = ba_output
715            && out.status.success()
716        {
717            let stdout_str = String::from_utf8_lossy(&out.stdout);
718            if let Ok(s) = stdout_str.trim().parse::<f64>() {
719                score = s;
720                parsed = true;
721            } else if let Some(last_line) = stdout_str.lines().last() {
722                // butteraugli may emit extra lines; the score is usually the last.
723                if let Ok(s) = last_line.trim().parse::<f64>() {
724                    score = s;
725                    parsed = true;
726                }
727            }
728        }
729        if !parsed {
730            warn!(frame = i, "butteraugli not available or failed; recording 0.0");
731        }
732        scores.push(score);
733    }
734
735    Ok(scores)
736}
737
/// Reads the whitespace-delimited token after the first `tag` (e.g. `"y:"`) on an
/// xpsnr stats line and parses it as dB.
///
/// Non-finite values (identical frames report `inf`) are mapped to a 100 dB cap.
/// Returns `None` when the tag is missing or no parseable token follows it.
fn parse_xpsnr_component(line: &str, tag: &str) -> Option<f64> {
    const INF_CAP_DB: f64 = 100.0;
    let (_, after_tag) = line.split_once(tag)?;
    let token = after_tag.split_whitespace().next()?;
    if token == "inf" || token == "-inf" {
        return Some(INF_CAP_DB);
    }
    let value: f64 = token.parse().ok()?;
    Some(if value.is_finite() { value } else { INF_CAP_DB })
}
748
749/// Run FFmpeg's `xpsnr` filter over the whole clip; one weighted XPSNR
750/// `(6·Y + U + V) / 8` (dB) per frame, parsed from the per-frame stats file.
751async fn measure_xpsnr(
752    reference: &str,
753    distorted: &str,
754    opts: &MeasureOpts,
755) -> anyhow::Result<Vec<f64>> {
756    let (width, height, _nb) = reference_dims(reference, opts).await?;
757    let stats = tempfile::Builder::new().prefix("viser-xpsnr-").suffix(".log").tempfile()?;
758    let stats_path = stats.path().to_string_lossy().to_string();
759
760    // Match the libvmaf path: scale the distorted input to reference dimensions.
761    let filtergraph = format!(
762        "[0:v]scale={width}:{height}:flags=bicubic[dist];[dist][1:v]xpsnr=stats_file={stats_path}"
763    );
764    let output = Command::new(ffmpeg_path())
765        .args(["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"])
766        .stderr(std::process::Stdio::piped())
767        .output()
768        .await?;
769
770    if !output.status.success() {
771        let stderr = String::from_utf8_lossy(&output.stderr);
772        anyhow::bail!("xpsnr measurement failed: {stderr}");
773    }
774
775    let log = std::fs::read_to_string(stats.path())?;
776    let mut scores = Vec::new();
777    for line in log.lines() {
778        // e.g. "n:    1  XPSNR y: 46.9714  XPSNR u: 45.1188  XPSNR v: 45.0873"
779        if let Some(y) = parse_xpsnr_component(line, "y:") {
780            let u = parse_xpsnr_component(line, "u:").unwrap_or(y);
781            let v = parse_xpsnr_component(line, "v:").unwrap_or(y);
782            scores.push((6.0 * y + u + v) / 8.0);
783        }
784    }
785    Ok(scores)
786}
787
#[cfg(test)]
mod tests {
    // Unit tests for metric serde, libvmaf log parsing, XPSNR stats parsing,
    // frame sampling, and the pooled/per-frame metric lookup helpers.
    use super::*;

    #[test]
    fn test_metric_serde_roundtrip() {
        // Every `Metric` variant must survive a JSON roundtrip.
        for m in [
            Metric::Vmaf,
            Metric::Psnr,
            Metric::Ssim,
            Metric::Ssimulacra2,
            Metric::Butteraugli,
            Metric::MsSsim,
            Metric::Vif,
            Metric::Cambi,
            Metric::Xpsnr,
        ] {
            let json = serde_json::to_string(&m).unwrap();
            let back: Metric = serde_json::from_str(&json).unwrap();
            assert_eq!(m, back);
        }
    }

    #[test]
    fn test_metric_serde_names() {
        // `rename_all = "lowercase"` lowercases the variant name verbatim.
        assert_eq!(serde_json::to_string(&Metric::Vmaf).unwrap(), "\"vmaf\"");
        assert_eq!(serde_json::to_string(&Metric::Psnr).unwrap(), "\"psnr\"");
        assert_eq!(serde_json::to_string(&Metric::Ssim).unwrap(), "\"ssim\"");
        assert_eq!(serde_json::to_string(&Metric::Ssimulacra2).unwrap(), "\"ssimulacra2\"");
        assert_eq!(serde_json::to_string(&Metric::Butteraugli).unwrap(), "\"butteraugli\"");
        assert_eq!(serde_json::to_string(&Metric::MsSsim).unwrap(), "\"msssim\"");
        assert_eq!(serde_json::to_string(&Metric::Vif).unwrap(), "\"vif\"");
        assert_eq!(serde_json::to_string(&Metric::Cambi).unwrap(), "\"cambi\"");
        assert_eq!(serde_json::to_string(&Metric::Xpsnr).unwrap(), "\"xpsnr\"");
    }

    #[test]
    fn test_metric_eq() {
        assert_eq!(Metric::Vmaf, Metric::Vmaf);
        assert_ne!(Metric::Vmaf, Metric::Psnr);
        assert_eq!(Metric::Ssimulacra2, Metric::Ssimulacra2);
        assert_ne!(Metric::Ssimulacra2, Metric::Butteraugli);
    }

    #[test]
    fn test_result_default() {
        let r = Result::default();
        assert!((r.vmaf - 0.0).abs() < 1e-9);
        assert!((r.psnr - 0.0).abs() < 1e-9);
        assert!((r.ssim - 0.0).abs() < 1e-9);
        assert!((r.ssimulacra2 - 0.0).abs() < 1e-9);
        assert!((r.butteraugli - 0.0).abs() < 1e-9);
        assert!(r.frames.is_empty());
    }

    #[test]
    fn test_parse_vmaf_log_basic() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 85.0, "psnr_y": 38.5, "float_ssim": 0.95}}
            ],
            "pooled_metrics": {
                "vmaf": {"mean": 86.5},
                "psnr_y": {"mean": 39.2},
                "float_ssim": {"mean": 0.96}
            }
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.vmaf - 86.5).abs() < 1e-9);
        assert!((result.psnr - 39.2).abs() < 1e-9);
        assert!((result.ssim - 0.96).abs() < 1e-9);
        assert!(result.frames.is_empty());
    }

    #[test]
    fn test_parse_vmaf_log_per_frame() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 37.0, "float_ssim": 0.93}},
                {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 40.0, "float_ssim": 0.97}}
            ],
            "pooled_metrics": {
                "vmaf": {"mean": 85.0},
                "psnr_y": {"mean": 38.5},
                "float_ssim": {"mean": 0.95}
            }
        }"#;
        let result = parse_vmaf_log(json, true).unwrap();
        assert_eq!(result.frames.len(), 2);
        assert_eq!(result.frames[0].frame_num, 0);
        assert!((result.frames[0].vmaf - 80.0).abs() < 1e-9);
        assert_eq!(result.frames[1].frame_num, 1);
        assert!((result.frames[1].vmaf - 90.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_fallback_psnr() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {
                "vmaf": {"mean": 85.0},
                "psnr": {"mean": 39.0},
                "ssim": {"mean": 0.94}
            }
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.psnr - 39.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_missing_metrics() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.vmaf - 0.0).abs() < 1e-9);
        assert!((result.psnr - 0.0).abs() < 1e-9);
        assert!((result.ssim - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_invalid_json() {
        assert!(parse_vmaf_log(b"not json", false).is_err());
    }

    #[test]
    fn test_result_serde_roundtrip() {
        let r = Result {
            vmaf: 85.0,
            psnr: 38.5,
            ssim: 0.95,
            ssimulacra2: 70.0,
            butteraugli: 0.5,
            ..Default::default()
        };
        let json = serde_json::to_string(&r).unwrap();
        let back: Result = serde_json::from_str(&json).unwrap();
        assert!((back.vmaf - 85.0).abs() < 1e-9);
        assert!((back.ssimulacra2 - 70.0).abs() < 1e-9);
        assert!((back.butteraugli - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_per_component_psnr() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {
                "vmaf": {"mean": 85.0},
                "psnr_y": {"mean": 40.0},
                "psnr_cb": {"mean": 44.0},
                "psnr_cr": {"mean": 46.0},
                "float_ssim": {"mean": 0.95}
            }
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.psnr - 40.0).abs() < 1e-9, "luma");
        assert!((result.psnr_u - 44.0).abs() < 1e-9, "Cb");
        assert!((result.psnr_v - 46.0).abs() < 1e-9, "Cr");
        // weighted (6*40 + 44 + 46) / 8 = 41.25
        assert!((result.psnr_avg - 41.25).abs() < 1e-9, "weighted avg");
    }

    #[test]
    fn test_parse_vmaf_log_psnr_avg_falls_back_to_luma() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {"psnr_y": {"mean": 39.0}}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.psnr_avg - 39.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_pooled_distribution() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 37.0, "float_ssim": 0.93}},
                {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 41.0, "float_ssim": 0.97}}
            ],
            "pooled_metrics": {"vmaf": {"mean": 85.0}}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert_eq!(result.pooled.vmaf.count, 2);
        assert!((result.pooled.vmaf.min - 80.0).abs() < 1e-9);
        assert!((result.pooled.vmaf.max - 90.0).abs() < 1e-9);
        assert!((result.pooled.vmaf.mean - 85.0).abs() < 1e-9);
        // psnr/ssim distributions are pooled even without a pooled_metrics entry
        assert!((result.pooled.psnr.min - 37.0).abs() < 1e-9);
        assert!((result.psnr - 39.0).abs() < 1e-9, "psnr falls back to frame mean");
    }

    #[test]
    fn test_sample_indices() {
        assert_eq!(sample_indices(100, 0), vec![0]);
        assert_eq!(sample_indices(100, 1), vec![0]);
        assert_eq!(sample_indices(0, 5), vec![0]);
        assert_eq!(sample_indices(1, 5), vec![0]);
        assert_eq!(sample_indices(101, 3), vec![0, 50, 100]);
        // never asks for more frames than exist
        assert_eq!(sample_indices(2, 10), vec![0, 1]);
    }

    #[test]
    fn test_result_serde_omits_zero_frames() {
        let r = Result::default();
        let json = serde_json::to_string(&r).unwrap();
        assert!(!json.contains("frames"));
    }

    #[test]
    fn test_measure_opts_default() {
        let opts = MeasureOpts::default();
        assert_eq!(opts.metrics.len(), 5);
        assert_eq!(opts.subsample, 0);
        assert_eq!(opts.model, "vmaf_v0.6.1");
        assert!(!opts.per_frame);
        assert_eq!(opts.frame_samples, 0);
        assert!(opts.probe_cache.is_none());
    }

    #[test]
    fn test_vif_mean() {
        let mut m = std::collections::HashMap::new();
        m.insert("integer_vif_scale0".to_string(), 0.2);
        m.insert("integer_vif_scale1".to_string(), 0.4);
        m.insert("integer_vif_scale2".to_string(), 0.6);
        m.insert("integer_vif_scale3".to_string(), 0.8);
        assert!((vif_mean(&m).unwrap() - 0.5).abs() < 1e-9);

        // Naming-variant fallback and partial presence.
        let mut m2 = std::collections::HashMap::new();
        m2.insert("vif_scale0".to_string(), 1.0);
        m2.insert("float_vif_scale1".to_string(), 0.0);
        assert!((vif_mean(&m2).unwrap() - 0.5).abs() < 1e-9);

        assert!(vif_mean(&std::collections::HashMap::new()).is_none());
    }

    #[test]
    fn test_parse_xpsnr_component() {
        let line = "n:    1  XPSNR y: 46.9714  XPSNR u: 45.1188  XPSNR v: 45.0873";
        assert!((parse_xpsnr_component(line, "y:").unwrap() - 46.9714).abs() < 1e-9);
        assert!((parse_xpsnr_component(line, "u:").unwrap() - 45.1188).abs() < 1e-9);
        assert!((parse_xpsnr_component(line, "v:").unwrap() - 45.0873).abs() < 1e-9);
        // Identical frames report inf → clamped to the 100 dB cap.
        assert_eq!(parse_xpsnr_component("XPSNR y: inf", "y:"), Some(100.0));
        assert_eq!(parse_xpsnr_component("nothing here", "y:"), None);
    }

    #[test]
    fn test_parse_xpsnr_component_infinity_spelling() {
        // FFmpeg's stats file may spell an error-free component as `infinity`;
        // Rust's float parser accepts that spelling, so it must also hit the cap.
        let line = "n:    1  XPSNR y: infinity  XPSNR u: infinity  XPSNR v: infinity";
        for tag in ["y:", "u:", "v:"] {
            assert_eq!(parse_xpsnr_component(line, tag), Some(100.0), "{tag}");
        }
    }

    #[test]
    fn test_parse_vmaf_log_extended_metrics() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 80.0, "float_ms_ssim": 0.90, "cambi": 2.0,
                    "integer_vif_scale0": 0.2, "integer_vif_scale1": 0.4,
                    "integer_vif_scale2": 0.6, "integer_vif_scale3": 0.8}},
                {"frameNum": 1, "metrics": {"vmaf": 90.0, "float_ms_ssim": 1.00, "cambi": 0.0,
                    "integer_vif_scale0": 0.4, "integer_vif_scale1": 0.6,
                    "integer_vif_scale2": 0.8, "integer_vif_scale3": 1.0}}
            ],
            "pooled_metrics": {"vmaf": {"mean": 85.0}}
        }"#;
        let result = parse_vmaf_log(json, true).unwrap();
        // MS-SSIM mean of 0.90 and 1.00.
        assert!((result.ms_ssim - 0.95).abs() < 1e-9);
        // CAMBI mean of 2.0 and 0.0.
        assert!((result.cambi - 1.0).abs() < 1e-9);
        // VIF: frame means 0.5 and 0.7 → overall 0.6.
        assert!((result.vif - 0.6).abs() < 1e-9);
        // Per-frame propagation.
        assert!((result.frames[0].ms_ssim - 0.90).abs() < 1e-9);
        assert!((result.frames[0].vif - 0.5).abs() < 1e-9);
        assert!((result.frames[1].cambi - 0.0).abs() < 1e-9);
    }

    // ── Extended VMAF log parsing corner cases ──
    #[test]
    fn test_parse_vmaf_log_ssim_no_float_prefix() {
        let json = br#"{
            "frames": [{"frameNum": 0, "metrics": {"ssim": 0.92}}],
            "pooled_metrics": {"ssim": {"mean": 0.92}}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.ssim - 0.92).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_ms_ssim_fallback_name() {
        let json = br#"{
            "frames": [{"frameNum": 0, "metrics": {"ms_ssim": 0.88}}],
            "pooled_metrics": {}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.ms_ssim - 0.88).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_psnr_cb_cr_fallback_names() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {
                "psnr_y": {"mean": 40.0},
                "psnr_cb": {"mean": 44.0},
                "psnr_cr": {"mean": 46.0}
            }
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.psnr_u - 44.0).abs() < 1e-9, "Cb via psnr_cb");
        assert!((result.psnr_v - 46.0).abs() < 1e-9, "Cr via psnr_cr");
    }

    #[test]
    fn test_parse_vmaf_log_psnr_u_v_fallback_names() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {
                "psnr_y": {"mean": 40.0},
                "psnr_u": {"mean": 43.0},
                "psnr_v": {"mean": 45.0}
            }
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.psnr_u - 43.0).abs() < 1e-9, "Cb via psnr_u");
        assert!((result.psnr_v - 45.0).abs() < 1e-9, "Cr via psnr_v");
    }

    #[test]
    fn test_parse_vmaf_log_pooled_missing_fallback_to_frame_mean() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 36.0, "float_ssim": 0.90}},
                {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 42.0, "float_ssim": 0.96}}
            ],
            "pooled_metrics": {}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.vmaf - 85.0).abs() < 1e-9);
        assert!((result.psnr - 39.0).abs() < 1e-9);
        assert!((result.ssim - 0.93).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_empty_frames_and_pooled() {
        let json = br#"{
            "frames": [],
            "pooled_metrics": {}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.vmaf - 0.0).abs() < 1e-9);
        assert!((result.psnr - 0.0).abs() < 1e-9);
        assert!((result.ssim - 0.0).abs() < 1e-9);
        assert!((result.ms_ssim - 0.0).abs() < 1e-9);
        assert!((result.vif - 0.0).abs() < 1e-9);
        assert!((result.cambi - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_single_frame_with_pooled() {
        let json = br#"{
            "frames": [{"frameNum": 0, "metrics": {"vmaf": 95.0}}],
            "pooled_metrics": {"vmaf": {"mean": 95.0}}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert!((result.vmaf - 95.0).abs() < 1e-9);
        assert_eq!(result.pooled.vmaf.count, 1);
    }

    #[test]
    fn test_parse_vmaf_log_vif_mixed_naming() {
        let json = br#"{
            "frames": [{"frameNum": 0, "metrics": {
                "integer_vif_scale0": 0.5,
                "float_vif_scale0": 0.4,
                "vif_scale1": 0.6,
                "integer_vif_scale1": 0.6
            }}],
            "pooled_metrics": {}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        // scale0: integer_vif_scale0=0.5 must win over float_vif_scale0=0.4;
        // scale1: both naming variants carry 0.6.
        // mean = (0.5 + 0.6) / 2 = 0.55
        assert!((result.vif - 0.55).abs() < 1e-9, "mean of 2 scales with naming variants");
    }

    #[test]
    fn test_parse_vmaf_log_xpsnr_per_frame_propagation() {
        // The libvmaf log carries no XPSNR data (XPSNR comes from a separate FFmpeg
        // `xpsnr` pass), so the parsed result should report the 0.0 "unavailable"
        // default, leaving the per-frame slot for the caller to fill in.
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 85.0}}
            ],
            "pooled_metrics": {"vmaf": {"mean": 85.0}}
        }"#;
        let mut result = parse_vmaf_log(json, true).unwrap();
        assert!((result.xpsnr - 0.0).abs() < 1e-9, "pooled xpsnr unset by vmaf parser");
        assert!((result.frames[0].xpsnr - 0.0).abs() < 1e-9, "frame xpsnr unset by vmaf parser");
        result.frames[0].xpsnr = 45.5;
        assert!((result.frames[0].xpsnr - 45.5).abs() < 1e-9);
        assert!((result.frames[0].vmaf - 85.0).abs() < 1e-9, "other per-frame scores untouched");
    }

    #[test]
    fn test_parse_vmaf_log_pooled_distribution_single_frame() {
        let json = br#"{
            "frames": [{"frameNum": 0, "metrics": {"vmaf": 88.0}}],
            "pooled_metrics": {"vmaf": {"mean": 88.0}}
        }"#;
        let result = parse_vmaf_log(json, false).unwrap();
        assert_eq!(result.pooled.vmaf.count, 1);
        assert!((result.pooled.vmaf.min - 88.0).abs() < 1e-9);
        assert!((result.pooled.vmaf.max - 88.0).abs() < 1e-9);
        assert!((result.pooled.vmaf.mean - 88.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_vmaf_log_per_frame_with_missing_metrics() {
        let json = br#"{
            "frames": [
                {"frameNum": 0, "metrics": {"vmaf": 85.0}},
                {"frameNum": 1, "metrics": {}}
            ],
            "pooled_metrics": {"vmaf": {"mean": 85.0}}
        }"#;
        let result = parse_vmaf_log(json, true).unwrap();
        assert_eq!(result.frames.len(), 2);
        assert!((result.frames[0].vmaf - 85.0).abs() < 1e-9);
        assert!((result.frames[1].vmaf - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_xpsnr_component_negative_inf() {
        assert_eq!(parse_xpsnr_component("XPSNR y: -inf", "y:"), Some(100.0));
    }

    #[test]
    fn test_parse_xpsnr_component_nan() {
        assert_eq!(parse_xpsnr_component("XPSNR y: NaN", "y:"), Some(100.0));
    }

    #[test]
    fn test_parse_xpsnr_component_regular() {
        assert!((parse_xpsnr_component("XPSNR u: 44.5678", "u:").unwrap() - 44.5678).abs() < 1e-4);
    }

    #[test]
    fn test_parse_xpsnr_component_bad_format() {
        assert_eq!(parse_xpsnr_component("n: 1 XPSNR", "y:"), None);
    }

    #[test]
    fn test_sample_indices_uneven() {
        assert_eq!(sample_indices(5, 3), vec![0, 2, 4]);
    }

    #[test]
    fn test_sample_indices_more_samples_than_frames() {
        assert_eq!(sample_indices(2, 10), vec![0, 1]);
    }

    #[test]
    fn test_sample_indices_single_frame_input() {
        assert_eq!(sample_indices(1, 5), vec![0]);
    }

    #[test]
    fn test_sample_indices_large_values() {
        let indices = sample_indices(1000, 5);
        assert_eq!(indices.len(), 5);
        assert_eq!(indices[0], 0);
        assert_eq!(indices[4], 999);
    }

    // ── pooled_mean and frame_metric ──
    #[test]
    fn test_pooled_mean_first_match_wins() {
        let mut map = std::collections::HashMap::new();
        map.insert("psnr_y".to_string(), PooledMetric { mean: 40.0 });
        map.insert("psnr".to_string(), PooledMetric { mean: 39.0 });
        assert_eq!(
            pooled_mean(&VmafLog { frames: vec![], pooled_metrics: map }, &["psnr_y", "psnr"]),
            40.0
        );
    }

    #[test]
    fn test_pooled_mean_fallback() {
        let mut map = std::collections::HashMap::new();
        map.insert("psnr".to_string(), PooledMetric { mean: 39.0 });
        assert_eq!(
            pooled_mean(&VmafLog { frames: vec![], pooled_metrics: map }, &["psnr_y", "psnr"]),
            39.0
        );
    }

    #[test]
    fn test_pooled_mean_missing_all() {
        assert_eq!(
            pooled_mean(
                &VmafLog { frames: vec![], pooled_metrics: std::collections::HashMap::new() },
                &["psnr_y", "psnr"]
            ),
            0.0
        );
    }

    #[test]
    fn test_frame_metric_first_match() {
        let mut map = std::collections::HashMap::new();
        map.insert("psnr_y".to_string(), 40.0);
        map.insert("psnr".to_string(), 39.0);
        assert_eq!(frame_metric(&map, &["psnr_y", "psnr"]), Some(40.0));
    }

    #[test]
    fn test_frame_metric_fallback() {
        let mut map = std::collections::HashMap::new();
        map.insert("psnr".to_string(), 39.0);
        assert_eq!(frame_metric(&map, &["psnr_y", "psnr"]), Some(39.0));
    }

    #[test]
    fn test_frame_metric_missing() {
        assert_eq!(frame_metric(&std::collections::HashMap::new(), &["psnr_y"]), None);
    }
}