Skip to main content

scirs2_metrics/keypoint/
mod.rs

//! Keypoint Detection Metrics
//!
//! This module provides evaluation metrics for human pose estimation and
//! keypoint detection tasks, following COCO benchmark conventions:
//!
//! - **OKS** (Object Keypoint Similarity): COCO-standard per-instance metric
//! - **PCK** (Percentage of Correct Keypoints): threshold-based accuracy
//! - **PCKh**: PCK normalised by head size
//! - **Mean OKS**: dataset-level OKS averaging
//! - **Mean Keypoint Error**: average Euclidean distance over labelled keypoints
//!   (visibility flag > 0, i.e. occluded or visible)
11
12use crate::error::{MetricsError, Result};
13
14// ─────────────────────────────────────────────────────────────────────────────
15// Data Structures
16// ─────────────────────────────────────────────────────────────────────────────
17
18/// A single pose annotation with 2-D keypoint coordinates and visibility flags.
19#[derive(Debug, Clone)]
20pub struct KeypointAnnotation {
21    /// `(x, y)` coordinates for each keypoint.
22    pub keypoints: Vec<[f64; 2]>,
23    /// Visibility flag per keypoint: `0` = absent, `1` = occluded, `2` = visible.
24    pub visibility: Vec<u8>,
25    /// Square root of the object area (used for OKS scale normalisation).
26    pub scale: f64,
27}
28
29impl KeypointAnnotation {
30    /// Validate that `keypoints` and `visibility` have the same length.
31    pub fn validate(&self) -> Result<()> {
32        if self.keypoints.len() != self.visibility.len() {
33            return Err(MetricsError::DimensionMismatch(format!(
34                "keypoints len {} != visibility len {}",
35                self.keypoints.len(),
36                self.visibility.len()
37            )));
38        }
39        Ok(())
40    }
41}
42
43// ─────────────────────────────────────────────────────────────────────────────
44// COCO Constants
45// ─────────────────────────────────────────────────────────────────────────────
46
/// Returns the default COCO per-keypoint sigma constants for the 17 body keypoints.
///
/// The returned vector is indexed in COCO keypoint order:
/// nose, left_eye, right_eye, left_ear, right_ear,
/// left_shoulder, right_shoulder, left_elbow, right_elbow,
/// left_wrist, right_wrist, left_hip, right_hip,
/// left_knee, right_knee, left_ankle, right_ankle.
///
/// Every left/right pair shares the same value.
pub fn coco_sigmas() -> Vec<f64> {
    const SIGMAS: [f64; 17] = [
        0.026, // nose
        0.025, 0.025, // left_eye, right_eye
        0.035, 0.035, // left_ear, right_ear
        0.079, 0.079, // left_shoulder, right_shoulder
        0.072, 0.072, // left_elbow, right_elbow
        0.062, 0.062, // left_wrist, right_wrist
        0.107, 0.107, // left_hip, right_hip
        0.087, 0.087, // left_knee, right_knee
        0.089, 0.089, // left_ankle, right_ankle
    ];
    SIGMAS.to_vec()
}
74
75// ─────────────────────────────────────────────────────────────────────────────
76// OKS
77// ─────────────────────────────────────────────────────────────────────────────
78
79/// Object Keypoint Similarity (OKS) — COCO standard.
80///
81/// ```text
82/// OKS = Σ_i [exp(−d_i² / (2 s² k_i²)) * δ(v_i > 0)] / Σ_i [δ(v_i > 0)]
83/// ```
84///
85/// where `d_i` is the Euclidean distance between predicted and GT keypoint `i`,
86/// `s` is the object scale (square root of area), and `k_i` is the per-keypoint
87/// constant from `sigmas`.
88///
89/// # Arguments
90/// * `predicted`    — predicted annotation
91/// * `ground_truth` — ground-truth annotation
92/// * `sigmas`       — per-keypoint sigma constants (same length as keypoints)
93pub fn object_keypoint_similarity(
94    predicted: &KeypointAnnotation,
95    ground_truth: &KeypointAnnotation,
96    sigmas: &[f64],
97) -> Result<f64> {
98    predicted.validate()?;
99    ground_truth.validate()?;
100
101    let n = predicted.keypoints.len();
102    if n == 0 {
103        return Err(MetricsError::InvalidInput(
104            "keypoint annotations must have at least one keypoint".to_string(),
105        ));
106    }
107    if ground_truth.keypoints.len() != n {
108        return Err(MetricsError::DimensionMismatch(format!(
109            "predicted has {n} keypoints but GT has {}",
110            ground_truth.keypoints.len()
111        )));
112    }
113    if sigmas.len() != n {
114        return Err(MetricsError::DimensionMismatch(format!(
115            "sigmas len {} != keypoints len {n}",
116            sigmas.len()
117        )));
118    }
119
120    let s = ground_truth.scale;
121    if s <= 0.0 {
122        return Err(MetricsError::InvalidInput(
123            "object scale must be positive".to_string(),
124        ));
125    }
126
127    let mut numerator = 0.0_f64;
128    let mut denominator = 0.0_f64;
129
130    for i in 0..n {
131        let v_gt = ground_truth.visibility[i];
132        if v_gt == 0 {
133            // Keypoint absent in GT — skip.
134            continue;
135        }
136        denominator += 1.0;
137
138        let [px, py] = predicted.keypoints[i];
139        let [gx, gy] = ground_truth.keypoints[i];
140        let d_sq = (px - gx).powi(2) + (py - gy).powi(2);
141        let ki = sigmas[i];
142        let e = -d_sq / (2.0 * s * s * ki * ki);
143        numerator += e.exp();
144    }
145
146    if denominator == 0.0 {
147        return Ok(0.0);
148    }
149    Ok(numerator / denominator)
150}
151
152// ─────────────────────────────────────────────────────────────────────────────
153// PCK / PCKh
154// ─────────────────────────────────────────────────────────────────────────────
155
156/// Percentage of Correct Keypoints (PCK) at threshold `t`.
157///
158/// A keypoint is "correct" when the predicted distance to GT is
159/// `< threshold_fraction * reference_distance`.
160///
161/// Only visible keypoints (`visibility[i] > 0`) are evaluated.
162pub fn pck(
163    predicted: &[[f64; 2]],
164    ground_truth: &[[f64; 2]],
165    visibility: &[u8],
166    threshold_fraction: f64,
167    reference_distance: f64,
168) -> Result<f64> {
169    let n = predicted.len();
170    if n == 0 {
171        return Err(MetricsError::InvalidInput(
172            "predicted keypoints must not be empty".to_string(),
173        ));
174    }
175    if ground_truth.len() != n || visibility.len() != n {
176        return Err(MetricsError::DimensionMismatch(format!(
177            "predicted, ground_truth and visibility must all have length {n}"
178        )));
179    }
180    if threshold_fraction <= 0.0 || reference_distance <= 0.0 {
181        return Err(MetricsError::InvalidInput(
182            "threshold_fraction and reference_distance must be positive".to_string(),
183        ));
184    }
185
186    let threshold = threshold_fraction * reference_distance;
187    let mut correct = 0usize;
188    let mut total = 0usize;
189
190    for i in 0..n {
191        if visibility[i] == 0 {
192            continue;
193        }
194        total += 1;
195        let [px, py] = predicted[i];
196        let [gx, gy] = ground_truth[i];
197        let dist = ((px - gx).powi(2) + (py - gy).powi(2)).sqrt();
198        if dist < threshold {
199            correct += 1;
200        }
201    }
202
203    if total == 0 {
204        return Ok(0.0);
205    }
206    Ok(correct as f64 / total as f64)
207}
208
209/// PCKh: PCK with threshold relative to head size.
210///
211/// A keypoint is correct when predicted distance < `threshold_fraction * head_size`.
212pub fn pckh(
213    predicted: &[[f64; 2]],
214    ground_truth: &[[f64; 2]],
215    visibility: &[u8],
216    head_size: f64,
217    threshold_fraction: f64,
218) -> Result<f64> {
219    pck(
220        predicted,
221        ground_truth,
222        visibility,
223        threshold_fraction,
224        head_size,
225    )
226}
227
228// ─────────────────────────────────────────────────────────────────────────────
229// Dataset-level metrics
230// ─────────────────────────────────────────────────────────────────────────────
231
232/// Mean OKS over a dataset.
233///
234/// Computes OKS for each (prediction, GT) pair and returns the mean.
235pub fn mean_oks(
236    predictions: &[KeypointAnnotation],
237    ground_truths: &[KeypointAnnotation],
238    sigmas: &[f64],
239) -> Result<f64> {
240    if predictions.is_empty() {
241        return Err(MetricsError::InvalidInput(
242            "predictions must not be empty".to_string(),
243        ));
244    }
245    if predictions.len() != ground_truths.len() {
246        return Err(MetricsError::DimensionMismatch(format!(
247            "predictions len {} != ground_truths len {}",
248            predictions.len(),
249            ground_truths.len()
250        )));
251    }
252    let total: f64 = predictions
253        .iter()
254        .zip(ground_truths)
255        .map(|(pred, gt)| object_keypoint_similarity(pred, gt, sigmas))
256        .sum::<Result<f64>>()?;
257    Ok(total / predictions.len() as f64)
258}
259
260/// Mean Euclidean keypoint error for visible keypoints.
261///
262/// Returns the average distance between predicted and GT keypoints
263/// where `visibility[i] > 0`.
264pub fn mean_keypoint_error(
265    predicted: &[[f64; 2]],
266    ground_truth: &[[f64; 2]],
267    visibility: &[u8],
268) -> Result<f64> {
269    let n = predicted.len();
270    if n == 0 {
271        return Err(MetricsError::InvalidInput(
272            "predicted keypoints must not be empty".to_string(),
273        ));
274    }
275    if ground_truth.len() != n || visibility.len() != n {
276        return Err(MetricsError::DimensionMismatch(format!(
277            "predicted, ground_truth and visibility must all have length {n}"
278        )));
279    }
280
281    let mut total_dist = 0.0_f64;
282    let mut count = 0usize;
283
284    for i in 0..n {
285        if visibility[i] == 0 {
286            continue;
287        }
288        let [px, py] = predicted[i];
289        let [gx, gy] = ground_truth[i];
290        total_dist += ((px - gx).powi(2) + (py - gy).powi(2)).sqrt();
291        count += 1;
292    }
293
294    if count == 0 {
295        return Ok(0.0);
296    }
297    Ok(total_dist / count as f64)
298}
299
300// ─────────────────────────────────────────────────────────────────────────────
301// Tests
302// ─────────────────────────────────────────────────────────────────────────────
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an annotation from its three fields.
    fn make_annotation(kps: Vec<[f64; 2]>, vis: Vec<u8>, scale: f64) -> KeypointAnnotation {
        KeypointAnnotation {
            keypoints: kps,
            visibility: vis,
            scale,
        }
    }

    /// Panics with a descriptive message unless `actual` is within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() < tol,
            "{what}: expected {expected}, got {actual}"
        );
    }

    // ── Annotation validation ────────────────────────────────────────────────

    #[test]
    fn test_validate_rejects_length_mismatch() {
        let ann = make_annotation(vec![[0.0, 0.0], [1.0, 1.0]], vec![2], 1.0);
        assert!(
            ann.validate().is_err(),
            "2 keypoints vs 1 visibility flag must be rejected"
        );
    }

    // ── OKS ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_oks_perfect_prediction() {
        let kps = vec![[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]];
        let vis = vec![2, 2, 2];
        let sigmas = vec![0.05, 0.05, 0.05];
        let pred = make_annotation(kps.clone(), vis.clone(), 50.0);
        let gt = make_annotation(kps, vis, 50.0);
        let oks = object_keypoint_similarity(&pred, &gt, &sigmas).expect("should succeed");
        assert_close(oks, 1.0, 1e-10, "perfect OKS");
    }

    #[test]
    fn test_oks_large_distance_near_zero() {
        let vis = vec![2, 2];
        let pred = make_annotation(vec![[1000.0, 1000.0]; 2], vis.clone(), 1.0);
        let gt = make_annotation(vec![[0.0, 0.0]; 2], vis, 1.0);
        let oks =
            object_keypoint_similarity(&pred, &gt, &[0.05, 0.05]).expect("should succeed");
        assert!(
            oks < 1e-6,
            "OKS for very large error should be ~0, got {oks}"
        );
    }

    #[test]
    fn test_oks_invisible_keypoints_excluded() {
        // GT visibility 0 for the first keypoint: its large error must not count
        // towards either the numerator or the denominator.
        let pred = make_annotation(vec![[999.0, 999.0], [10.0, 10.0]], vec![2, 2], 50.0);
        let gt = make_annotation(vec![[0.0, 0.0], [10.0, 10.0]], vec![0, 2], 50.0);
        let oks =
            object_keypoint_similarity(&pred, &gt, &[0.05, 0.05]).expect("should succeed");
        assert_close(oks, 1.0, 1e-10, "OKS with invisible GT keypoint excluded");
    }

    #[test]
    fn test_oks_no_labelled_gt_is_zero() {
        let kps = vec![[0.0, 0.0]];
        let pred = make_annotation(kps.clone(), vec![2], 1.0);
        let gt = make_annotation(kps, vec![0], 1.0);
        let oks = object_keypoint_similarity(&pred, &gt, &[0.05]).expect("should succeed");
        assert_eq!(oks, 0.0, "no labelled GT keypoints → OKS = 0");
    }

    #[test]
    fn test_oks_rejects_non_positive_scale() {
        let kps = vec![[0.0, 0.0]];
        let pred = make_annotation(kps.clone(), vec![2], 1.0);
        for bad_scale in [0.0, -5.0] {
            let gt = make_annotation(kps.clone(), vec![2], bad_scale);
            assert!(
                object_keypoint_similarity(&pred, &gt, &[0.05]).is_err(),
                "scale {bad_scale} must be rejected"
            );
        }
    }

    #[test]
    fn test_oks_rejects_length_mismatches() {
        let two = make_annotation(vec![[0.0, 0.0], [1.0, 1.0]], vec![2, 2], 1.0);
        let one = make_annotation(vec![[0.0, 0.0]], vec![2], 1.0);
        assert!(object_keypoint_similarity(&two, &two, &[0.05]).is_err());
        assert!(object_keypoint_similarity(&two, &one, &[0.05, 0.05]).is_err());
    }

    #[test]
    fn test_coco_sigmas_returns_17() {
        let s = coco_sigmas();
        assert_eq!(s.len(), 17, "COCO has 17 body keypoints, got {}", s.len());
        for (i, &sigma) in s.iter().enumerate() {
            assert!(sigma > 0.0, "sigma[{i}] must be positive, got {sigma}");
        }
    }

    // ── PCK / PCKh ───────────────────────────────────────────────────────────

    #[test]
    fn test_pck_all_correct() {
        let kps: Vec<[f64; 2]> = vec![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let score = pck(&kps, &kps, &[2, 2, 2], 0.1, 100.0).expect("should succeed");
        assert_close(score, 1.0, 1e-12, "perfect PCK");
    }

    #[test]
    fn test_pck_none_correct() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[100.0, 100.0], [200.0, 200.0]];
        // threshold = 0.01 * 1.0 = 0.01; distances are >> 0.01
        let score = pck(&pred, &gt, &[2, 2], 0.01, 1.0).expect("should succeed");
        assert_close(score, 0.0, 1e-12, "PCK with all keypoints far off");
    }

    #[test]
    fn test_pck_partial_with_strict_threshold() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[0.0, 0.0], [10.0, 0.0], [5.0, 0.0]];
        // threshold = 0.5 * 10 = 5: distance 0 is correct, 10 is not, and a
        // distance of exactly 5 is not either (the comparison is strict).
        let score = pck(&pred, &gt, &[2, 2, 2], 0.5, 10.0).expect("should succeed");
        assert_close(score, 1.0 / 3.0, 1e-12, "partial PCK");
    }

    #[test]
    fn test_pck_ignores_absent_keypoints() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[0.0, 0.0], [1000.0, 0.0]];
        let score = pck(&pred, &gt, &[1, 0], 0.1, 10.0).expect("should succeed");
        assert_close(score, 1.0, 1e-12, "PCK ignoring absent keypoint");
    }

    #[test]
    fn test_pck_rejects_invalid_arguments() {
        let kps = vec![[0.0, 0.0]];
        assert!(pck(&kps, &kps, &[2], 0.0, 1.0).is_err());
        assert!(pck(&kps, &kps, &[2], 0.5, -1.0).is_err());
        assert!(pck(&kps, &kps, &[2, 2], 0.5, 1.0).is_err());
        assert!(pck(&[], &[], &[], 0.5, 1.0).is_err());
    }

    #[test]
    fn test_pckh_head_size_reference() {
        let pred = vec![[10.0, 10.0], [20.0, 20.0]];
        let gt = vec![[10.0, 10.0], [20.0, 20.0]];
        let score = pckh(&pred, &gt, &[2, 2], 200.0, 0.5).expect("should succeed");
        assert_close(score, 1.0, 1e-12, "PCKh");
    }

    // ── Dataset-level metrics ────────────────────────────────────────────────

    #[test]
    fn test_mean_oks_batch() {
        let sigmas = vec![0.05, 0.05];
        let ann = make_annotation(vec![[5.0, 5.0], [10.0, 10.0]], vec![2, 2], 20.0);
        let predictions = vec![ann.clone(), ann.clone()];
        let ground_truths = vec![ann.clone(), ann];
        let moks = mean_oks(&predictions, &ground_truths, &sigmas).expect("should succeed");
        assert_close(moks, 1.0, 1e-10, "mean OKS for perfect predictions");
    }

    #[test]
    fn test_mean_oks_rejects_bad_lengths() {
        let ann = make_annotation(vec![[0.0, 0.0]], vec![2], 1.0);
        assert!(mean_oks(&[], &[], &[0.05]).is_err(), "empty input must fail");
        assert!(
            mean_oks(&[ann], &[], &[0.05]).is_err(),
            "length mismatch must fail"
        );
    }

    #[test]
    fn test_mean_keypoint_error_perfect() {
        let kps = vec![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let err = mean_keypoint_error(&kps, &kps, &[2, 2, 2]).expect("should succeed");
        assert!(
            err.abs() < 1e-12,
            "perfect predictions → error = 0, got {err}"
        );
    }

    #[test]
    fn test_mean_keypoint_error_known_distance() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[3.0, 4.0], [100.0, 100.0]];
        // Only the first keypoint is labelled; its error is the 3-4-5 hypotenuse.
        let err = mean_keypoint_error(&pred, &gt, &[2, 0]).expect("should succeed");
        assert_close(err, 5.0, 1e-12, "mean keypoint error");
    }
}
437}