1use crate::error::{MetricsError, Result};
13
/// A single pose annotation: one 2-D point per keypoint, a visibility flag per
/// keypoint, and an object scale.
///
/// `keypoints` and `visibility` are expected to have the same length; call
/// [`KeypointAnnotation::validate`] to check this before use.
#[derive(Debug, Clone, PartialEq)]
pub struct KeypointAnnotation {
    /// Keypoint positions as `[x, y]`, one entry per keypoint.
    pub keypoints: Vec<[f64; 2]>,
    /// Per-keypoint visibility flag. The metrics in this module only
    /// distinguish `0` (keypoint ignored) from any non-zero value (keypoint
    /// scored). This is compatible with the COCO convention
    /// (0 = not labelled, 1 = labelled but not visible, 2 = labelled and visible).
    pub visibility: Vec<u8>,
    /// Object scale `s`. The OKS Gaussian falloff uses `s * s`, so for
    /// COCO-style data this is presumably `sqrt(area)` — confirm with callers.
    /// Only the ground-truth annotation's scale is read when computing OKS.
    pub scale: f64,
}
28
29impl KeypointAnnotation {
30 pub fn validate(&self) -> Result<()> {
32 if self.keypoints.len() != self.visibility.len() {
33 return Err(MetricsError::DimensionMismatch(format!(
34 "keypoints len {} != visibility len {}",
35 self.keypoints.len(),
36 self.visibility.len()
37 )));
38 }
39 Ok(())
40 }
41}
42
/// Returns the 17 per-keypoint OKS sigmas of the COCO person-keypoint benchmark.
///
/// The values are pycocotools' `kpt_oks_sigmas` (`[.26, .25, …] / 10`), listed
/// in the standard COCO keypoint order (see the per-entry comments below).
///
/// Note: inside its OKS formula, pycocotools uses `2 * sigma` as the
/// per-keypoint constant. Check which convention the consuming function expects.
pub fn coco_sigmas() -> Vec<f64> {
    // One entry per COCO keypoint, in dataset order.
    const SIGMAS: [f64; 17] = [
        0.026, // nose
        0.025, // left eye
        0.025, // right eye
        0.035, // left ear
        0.035, // right ear
        0.079, // left shoulder
        0.079, // right shoulder
        0.072, // left elbow
        0.072, // right elbow
        0.062, // left wrist
        0.062, // right wrist
        0.107, // left hip
        0.107, // right hip
        0.087, // left knee
        0.087, // right knee
        0.089, // left ankle
        0.089, // right ankle
    ];
    SIGMAS.to_vec()
}
74
75pub fn object_keypoint_similarity(
94 predicted: &KeypointAnnotation,
95 ground_truth: &KeypointAnnotation,
96 sigmas: &[f64],
97) -> Result<f64> {
98 predicted.validate()?;
99 ground_truth.validate()?;
100
101 let n = predicted.keypoints.len();
102 if n == 0 {
103 return Err(MetricsError::InvalidInput(
104 "keypoint annotations must have at least one keypoint".to_string(),
105 ));
106 }
107 if ground_truth.keypoints.len() != n {
108 return Err(MetricsError::DimensionMismatch(format!(
109 "predicted has {n} keypoints but GT has {}",
110 ground_truth.keypoints.len()
111 )));
112 }
113 if sigmas.len() != n {
114 return Err(MetricsError::DimensionMismatch(format!(
115 "sigmas len {} != keypoints len {n}",
116 sigmas.len()
117 )));
118 }
119
120 let s = ground_truth.scale;
121 if s <= 0.0 {
122 return Err(MetricsError::InvalidInput(
123 "object scale must be positive".to_string(),
124 ));
125 }
126
127 let mut numerator = 0.0_f64;
128 let mut denominator = 0.0_f64;
129
130 for i in 0..n {
131 let v_gt = ground_truth.visibility[i];
132 if v_gt == 0 {
133 continue;
135 }
136 denominator += 1.0;
137
138 let [px, py] = predicted.keypoints[i];
139 let [gx, gy] = ground_truth.keypoints[i];
140 let d_sq = (px - gx).powi(2) + (py - gy).powi(2);
141 let ki = sigmas[i];
142 let e = -d_sq / (2.0 * s * s * ki * ki);
143 numerator += e.exp();
144 }
145
146 if denominator == 0.0 {
147 return Ok(0.0);
148 }
149 Ok(numerator / denominator)
150}
151
152pub fn pck(
163 predicted: &[[f64; 2]],
164 ground_truth: &[[f64; 2]],
165 visibility: &[u8],
166 threshold_fraction: f64,
167 reference_distance: f64,
168) -> Result<f64> {
169 let n = predicted.len();
170 if n == 0 {
171 return Err(MetricsError::InvalidInput(
172 "predicted keypoints must not be empty".to_string(),
173 ));
174 }
175 if ground_truth.len() != n || visibility.len() != n {
176 return Err(MetricsError::DimensionMismatch(format!(
177 "predicted, ground_truth and visibility must all have length {n}"
178 )));
179 }
180 if threshold_fraction <= 0.0 || reference_distance <= 0.0 {
181 return Err(MetricsError::InvalidInput(
182 "threshold_fraction and reference_distance must be positive".to_string(),
183 ));
184 }
185
186 let threshold = threshold_fraction * reference_distance;
187 let mut correct = 0usize;
188 let mut total = 0usize;
189
190 for i in 0..n {
191 if visibility[i] == 0 {
192 continue;
193 }
194 total += 1;
195 let [px, py] = predicted[i];
196 let [gx, gy] = ground_truth[i];
197 let dist = ((px - gx).powi(2) + (py - gy).powi(2)).sqrt();
198 if dist < threshold {
199 correct += 1;
200 }
201 }
202
203 if total == 0 {
204 return Ok(0.0);
205 }
206 Ok(correct as f64 / total as f64)
207}
208
209pub fn pckh(
213 predicted: &[[f64; 2]],
214 ground_truth: &[[f64; 2]],
215 visibility: &[u8],
216 head_size: f64,
217 threshold_fraction: f64,
218) -> Result<f64> {
219 pck(
220 predicted,
221 ground_truth,
222 visibility,
223 threshold_fraction,
224 head_size,
225 )
226}
227
228pub fn mean_oks(
236 predictions: &[KeypointAnnotation],
237 ground_truths: &[KeypointAnnotation],
238 sigmas: &[f64],
239) -> Result<f64> {
240 if predictions.is_empty() {
241 return Err(MetricsError::InvalidInput(
242 "predictions must not be empty".to_string(),
243 ));
244 }
245 if predictions.len() != ground_truths.len() {
246 return Err(MetricsError::DimensionMismatch(format!(
247 "predictions len {} != ground_truths len {}",
248 predictions.len(),
249 ground_truths.len()
250 )));
251 }
252 let total: f64 = predictions
253 .iter()
254 .zip(ground_truths)
255 .map(|(pred, gt)| object_keypoint_similarity(pred, gt, sigmas))
256 .sum::<Result<f64>>()?;
257 Ok(total / predictions.len() as f64)
258}
259
260pub fn mean_keypoint_error(
265 predicted: &[[f64; 2]],
266 ground_truth: &[[f64; 2]],
267 visibility: &[u8],
268) -> Result<f64> {
269 let n = predicted.len();
270 if n == 0 {
271 return Err(MetricsError::InvalidInput(
272 "predicted keypoints must not be empty".to_string(),
273 ));
274 }
275 if ground_truth.len() != n || visibility.len() != n {
276 return Err(MetricsError::DimensionMismatch(format!(
277 "predicted, ground_truth and visibility must all have length {n}"
278 )));
279 }
280
281 let mut total_dist = 0.0_f64;
282 let mut count = 0usize;
283
284 for i in 0..n {
285 if visibility[i] == 0 {
286 continue;
287 }
288 let [px, py] = predicted[i];
289 let [gx, gy] = ground_truth[i];
290 total_dist += ((px - gx).powi(2) + (py - gy).powi(2)).sqrt();
291 count += 1;
292 }
293
294 if count == 0 {
295 return Ok(0.0);
296 }
297 Ok(total_dist / count as f64)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an annotation from its parts to keep the individual tests short.
    fn make_annotation(kps: Vec<[f64; 2]>, vis: Vec<u8>, scale: f64) -> KeypointAnnotation {
        KeypointAnnotation {
            keypoints: kps,
            visibility: vis,
            scale,
        }
    }

    #[test]
    fn test_oks_perfect_prediction() {
        let kps = vec![[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]];
        let vis = vec![2, 2, 2];
        let sigmas = vec![0.05, 0.05, 0.05];
        let pred = make_annotation(kps.clone(), vis.clone(), 50.0);
        let gt = make_annotation(kps, vis, 50.0);
        let oks = object_keypoint_similarity(&pred, &gt, &sigmas).expect("should succeed");
        assert!(
            (oks - 1.0).abs() < 1e-10,
            "perfect OKS should be 1.0, got {oks}"
        );
    }

    #[test]
    fn test_oks_known_value() {
        // d² = 3² + 4² = 25, s = 1, k = 5  =>  exp(-25 / (2 * 1 * 25)) = exp(-0.5)
        let pred = make_annotation(vec![[3.0, 4.0]], vec![2], 1.0);
        let gt = make_annotation(vec![[0.0, 0.0]], vec![2], 1.0);
        let oks = object_keypoint_similarity(&pred, &gt, &[5.0]).expect("should succeed");
        let expected = (-0.5_f64).exp();
        assert!(
            (oks - expected).abs() < 1e-12,
            "expected {expected}, got {oks}"
        );
    }

    #[test]
    fn test_oks_large_distance_near_zero() {
        let gt_kps = vec![[0.0, 0.0], [0.0, 0.0]];
        let pred_kps = vec![[1000.0, 1000.0], [1000.0, 1000.0]];
        let vis = vec![2, 2];
        let sigmas = vec![0.05, 0.05];
        let pred = make_annotation(pred_kps, vis.clone(), 1.0);
        let gt = make_annotation(gt_kps, vis, 1.0);
        let oks = object_keypoint_similarity(&pred, &gt, &sigmas).expect("should succeed");
        assert!(
            oks < 1e-6,
            "OKS for very large error should be ~0, got {oks}"
        );
    }

    #[test]
    fn test_oks_invisible_keypoints_excluded() {
        // The first keypoint is far off, but it is unlabelled in the GT and must be ignored.
        let pred_kps = vec![[999.0, 999.0], [10.0, 10.0]];
        let gt_kps = vec![[0.0, 0.0], [10.0, 10.0]];
        let vis_gt = vec![0, 2];
        let sigmas = vec![0.05, 0.05];
        let pred = make_annotation(pred_kps, vec![2, 2], 50.0);
        let gt = make_annotation(gt_kps, vis_gt, 50.0);
        let oks = object_keypoint_similarity(&pred, &gt, &sigmas).expect("should succeed");
        assert!(
            (oks - 1.0).abs() < 1e-10,
            "invisible GT keypoints should be excluded, OKS={oks}"
        );
    }

    #[test]
    fn test_oks_all_invisible_is_zero() {
        let kps = vec![[1.0, 1.0], [2.0, 2.0]];
        let pred = make_annotation(kps.clone(), vec![2, 2], 10.0);
        let gt = make_annotation(kps, vec![0, 0], 10.0);
        let oks = object_keypoint_similarity(&pred, &gt, &[0.05, 0.05]).expect("should succeed");
        assert!(oks.abs() < 1e-12, "expected OKS=0 with no labelled GT, got {oks}");
    }

    #[test]
    fn test_oks_rejects_invalid_input() {
        let kps = vec![[1.0, 1.0], [2.0, 2.0]];
        let vis = vec![2, 2];
        let ann = make_annotation(kps.clone(), vis.clone(), 10.0);
        // sigma count does not match keypoint count
        assert!(object_keypoint_similarity(&ann, &ann, &[0.05]).is_err());
        // predicted and GT keypoint counts differ
        let short = make_annotation(vec![[1.0, 1.0]], vec![2], 10.0);
        assert!(object_keypoint_similarity(&short, &ann, &[0.05, 0.05]).is_err());
        // keypoints / visibility mismatch inside one annotation
        let ragged = make_annotation(kps.clone(), vec![2], 10.0);
        assert!(object_keypoint_similarity(&ragged, &ann, &[0.05, 0.05]).is_err());
        // non-positive GT scale
        let zero_scale = make_annotation(kps, vis, 0.0);
        assert!(object_keypoint_similarity(&ann, &zero_scale, &[0.05, 0.05]).is_err());
        // no keypoints at all
        let empty = make_annotation(vec![], vec![], 10.0);
        assert!(object_keypoint_similarity(&empty, &empty, &[]).is_err());
    }

    #[test]
    fn test_pck_all_correct() {
        let kps: Vec<[f64; 2]> = vec![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let vis = vec![2, 2, 2];
        let score = pck(&kps, &kps, &vis, 0.1, 100.0).expect("should succeed");
        assert!(
            (score - 1.0).abs() < 1e-12,
            "perfect PCK should be 1.0, got {score}"
        );
    }

    #[test]
    fn test_pck_none_correct() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[100.0, 100.0], [200.0, 200.0]];
        let vis = vec![2, 2];
        let score = pck(&pred, &gt, &vis, 0.01, 1.0).expect("should succeed");
        assert!(score.abs() < 1e-12, "expected PCK=0, got {score}");
    }

    #[test]
    fn test_pck_threshold_is_strict() {
        // threshold = 0.5 * 10 = 5; distances are 0, 5 and 50 -> only the first counts.
        let pred = vec![[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[0.0, 0.0], [3.0, 4.0], [30.0, 40.0]];
        let vis = vec![2, 2, 2];
        let score = pck(&pred, &gt, &vis, 0.5, 10.0).expect("should succeed");
        assert!(
            (score - 1.0 / 3.0).abs() < 1e-12,
            "expected PCK=1/3, got {score}"
        );
    }

    #[test]
    fn test_pck_ignores_invisible() {
        let pred = vec![[0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[0.0, 0.0], [100.0, 100.0]];
        let score = pck(&pred, &gt, &[2, 0], 0.1, 10.0).expect("should succeed");
        assert!((score - 1.0).abs() < 1e-12, "expected PCK=1, got {score}");
        let none = pck(&pred, &gt, &[0, 0], 0.1, 10.0).expect("should succeed");
        assert!(none.abs() < 1e-12, "expected PCK=0, got {none}");
    }

    #[test]
    fn test_pck_rejects_invalid_input() {
        let kps = vec![[1.0, 1.0], [2.0, 2.0]];
        let vis = vec![2, 2];
        assert!(pck(&kps, &kps, &vis, 0.0, 1.0).is_err());
        assert!(pck(&kps, &kps, &vis, 0.1, -1.0).is_err());
        assert!(pck(&kps, &kps[..1], &vis, 0.1, 1.0).is_err());
        assert!(pck(&kps, &kps, &vis[..1], 0.1, 1.0).is_err());
        assert!(pck(&[], &[], &[], 0.1, 1.0).is_err());
    }

    #[test]
    fn test_pckh_head_size_reference() {
        let pred = vec![[10.0, 10.0], [20.0, 20.0]];
        let gt = vec![[10.0, 10.0], [20.0, 20.0]];
        let vis = vec![2, 2];
        let score = pckh(&pred, &gt, &vis, 200.0, 0.5).expect("should succeed");
        assert!(
            (score - 1.0).abs() < 1e-12,
            "expected PCKh=1.0, got {score}"
        );
    }

    #[test]
    fn test_mean_oks_batch() {
        let kps = vec![[5.0, 5.0], [10.0, 10.0]];
        let vis = vec![2, 2];
        let sigmas = vec![0.05, 0.05];
        let ann = make_annotation(kps, vis, 20.0);
        let predictions = vec![ann.clone(), ann.clone()];
        let ground_truths = vec![ann.clone(), ann];
        let moks = mean_oks(&predictions, &ground_truths, &sigmas).expect("should succeed");
        assert!(
            (moks - 1.0).abs() < 1e-10,
            "mean OKS for perfect predictions should be 1.0, got {moks}"
        );
    }

    #[test]
    fn test_mean_oks_rejects_empty_and_mismatch() {
        let ann = make_annotation(vec![[1.0, 1.0]], vec![2], 10.0);
        assert!(mean_oks(&[], &[], &[0.05]).is_err());
        assert!(mean_oks(&[ann.clone()], &[], &[0.05]).is_err());
        assert!(mean_oks(&[ann.clone(), ann.clone()], &[ann], &[0.05]).is_err());
    }

    #[test]
    fn test_mean_keypoint_error_perfect() {
        let kps = vec![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let vis = vec![2, 2, 2];
        let err = mean_keypoint_error(&kps, &kps, &vis).expect("should succeed");
        assert!(
            err.abs() < 1e-12,
            "perfect predictions → error = 0, got {err}"
        );
    }

    #[test]
    fn test_mean_keypoint_error_known_value() {
        // Distances are 5 and 10 -> mean 7.5; the unlabelled third point is ignored.
        let pred = vec![[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]];
        let gt = vec![[3.0, 4.0], [6.0, 8.0], [1000.0, 1000.0]];
        let vis = vec![2, 1, 0];
        let err = mean_keypoint_error(&pred, &gt, &vis).expect("should succeed");
        assert!((err - 7.5).abs() < 1e-12, "expected 7.5, got {err}");
    }

    #[test]
    fn test_mean_keypoint_error_invalid_and_empty() {
        let kps = vec![[1.0, 1.0], [2.0, 2.0]];
        assert!(mean_keypoint_error(&kps, &kps, &[2]).is_err());
        assert!(mean_keypoint_error(&[], &[], &[]).is_err());
        let none = mean_keypoint_error(&kps, &kps, &[0, 0]).expect("should succeed");
        assert!(none.abs() < 1e-12, "expected 0 with no labelled keypoints, got {none}");
    }

    #[test]
    fn test_coco_sigmas_returns_17() {
        let s = coco_sigmas();
        assert_eq!(s.len(), 17, "COCO has 17 body keypoints, got {}", s.len());
        for (i, &sigma) in s.iter().enumerate() {
            assert!(sigma > 0.0, "sigma[{i}] must be positive, got {sigma}");
        }
    }
}