scirs2_metrics/custom/mod.rs

//! Custom metric definition framework
//!
//! This module provides a framework for defining custom metrics that integrate
//! seamlessly with the rest of the scirs2-metrics ecosystem.
//!
//! # Features
//!
//! - **Trait-based design**: Define custom metrics by implementing simple traits
//! - **Type safety**: Leverage Rust's type system for metric validation
//! - **Integration**: Custom metrics work with evaluation pipelines and visualization
//! - **Performance**: Metrics used directly are statically dispatched and inlinable; only [`CustomMetricSuite`] uses dynamic dispatch
//! - **Composability**: Combine custom metrics with built-in metrics, or group them in a [`CustomMetricSuite`]
//!
//! # Examples
//!
//! ## Defining a Custom Classification Metric
//!
//! ```
//! use scirs2_metrics::custom::{ClassificationMetric, MetricResult};
//! use scirs2_core::ndarray::Array1;
//!
//! struct CustomAccuracy;
//!
//! impl ClassificationMetric<f64> for CustomAccuracy {
//!     fn name(&self) -> &'static str {
//!         "custom_accuracy"
//!     }
//!
//!     fn compute(&self, y_true: &Array1<i32>, y_pred: &Array1<i32>) -> MetricResult<f64> {
//!         if y_true.len() != y_pred.len() {
//!             return Err("Arrays must have the same length".into());
//!         }
//!
//!         let correct = y_true.iter()
//!             .zip(y_pred.iter())
//!             .filter(|(true_val, pred_val)| true_val == pred_val)
//!             .count();
//!
//!         Ok(correct as f64 / y_true.len() as f64)
//!     }
//!
//!     fn higher_is_better(&self) -> bool {
//!         true
//!     }
//! }
//! ```
//!
//! ## Defining a Custom Regression Metric
//!
//! ```
//! use scirs2_metrics::custom::{RegressionMetric, MetricResult};
//! use scirs2_core::ndarray::Array1;
//!
//! struct LogCoshError;
//!
//! impl RegressionMetric<f64> for LogCoshError {
//!     fn name(&self) -> &'static str {
//!         "log_cosh_error"
//!     }
//!
//!     fn compute(&self, y_true: &Array1<f64>, y_pred: &Array1<f64>) -> MetricResult<f64> {
//!         if y_true.len() != y_pred.len() {
//!             return Err("Arrays must have the same length".into());
//!         }
//!
//!         let error: f64 = y_true.iter()
//!             .zip(y_pred.iter())
//!             .map(|(true_val, pred_val)| {
//!                 let diff = pred_val - true_val;
//!                 diff.cosh().ln()
//!             })
//!             .sum();
//!
//!         Ok(error / y_true.len() as f64)
//!     }
//!
//!     fn higher_is_better(&self) -> bool {
//!         false
//!     }
//! }
//! ```
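//!
//! ## Combining Metrics in a Suite
//!
//! Metrics can be grouped in a [`CustomMetricSuite`] and evaluated together. A
//! minimal sketch, reusing the `CustomAccuracy` type from the example above:
//!
//! ```ignore
//! use scirs2_metrics::custom::CustomMetricSuite;
//! use scirs2_core::ndarray::array;
//!
//! let mut suite = CustomMetricSuite::<f64>::new();
//! suite.add_classification_metric(CustomAccuracy);
//!
//! let results = suite
//!     .evaluate_classification(&array![1, 0, 1], &array![1, 1, 1])
//!     .unwrap();
//! println!("{}", results);
//! ```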

use crate::error::Result as MetricsResult;
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use std::fmt;

/// Result type for custom metric computations
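///
/// Errors are boxed trait objects, so string messages convert with `.into()`.
/// A minimal sketch of the usual length check:
///
/// ```
/// use scirs2_metrics::custom::MetricResult;
///
/// fn check_lengths(a: usize, b: usize) -> MetricResult<()> {
///     if a != b {
///         return Err("Arrays must have the same length".into());
///     }
///     Ok(())
/// }
/// ```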
pub type MetricResult<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

/// Trait for defining custom classification metrics
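///
/// Implementors provide `name`, `compute`, and `higher_is_better`; the
/// `description` and `value_range` methods default to `None`. A minimal sketch
/// of overriding the optional methods (the `BalancedFlag` type here is
/// hypothetical):
///
/// ```ignore
/// impl ClassificationMetric<f64> for BalancedFlag {
///     fn name(&self) -> &'static str {
///         "balanced_flag"
///     }
///
///     fn compute(&self, _y_true: &Array1<i32>, y_pred: &Array1<i32>) -> MetricResult<f64> {
///         // 1.0 when both classes appear among the predictions, else 0.0.
///         let has_pos = y_pred.iter().any(|&v| v == 1);
///         let has_neg = y_pred.iter().any(|&v| v == 0);
///         Ok(if has_pos && has_neg { 1.0 } else { 0.0 })
///     }
///
///     fn higher_is_better(&self) -> bool {
///         true
///     }
///
///     fn description(&self) -> Option<&'static str> {
///         Some("1.0 when both classes appear among the predictions, else 0.0")
///     }
///
///     fn value_range(&self) -> Option<(f64, f64)> {
///         Some((0.0, 1.0))
///     }
/// }
/// ```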
pub trait ClassificationMetric<F: Float> {
    /// Returns the name of the metric
    fn name(&self) -> &'static str;

    /// Computes the metric value given true and predicted labels
    fn compute(&self, y_true: &Array1<i32>, y_pred: &Array1<i32>) -> MetricResult<F>;

    /// Returns whether higher values indicate better performance
    fn higher_is_better(&self) -> bool;

    /// Optional: Returns a description of the metric
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional: Returns the valid range of the metric (min, max)
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}

/// Trait for defining custom regression metrics
pub trait RegressionMetric<F: Float> {
    /// Returns the name of the metric
    fn name(&self) -> &'static str;

    /// Computes the metric value given true and predicted values
    fn compute(&self, y_true: &Array1<F>, y_pred: &Array1<F>) -> MetricResult<F>;

    /// Returns whether higher values indicate better performance
    fn higher_is_better(&self) -> bool;

    /// Optional: Returns a description of the metric
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional: Returns the valid range of the metric (min, max)
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}

/// Trait for defining custom clustering metrics
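///
/// # Examples
///
/// A minimal sketch for one-dimensional data (matching the `Array1<F>` signature
/// of `compute`): mean absolute deviation of each point from its cluster mean,
/// where lower is better:
///
/// ```
/// use scirs2_metrics::custom::{ClusteringMetric, MetricResult};
/// use scirs2_core::ndarray::Array1;
/// use std::collections::HashMap;
///
/// struct WithinClusterDeviation;
///
/// impl ClusteringMetric<f64> for WithinClusterDeviation {
///     fn name(&self) -> &'static str {
///         "within_cluster_deviation"
///     }
///
///     fn compute(&self, data: &Array1<f64>, labels: &Array1<i32>) -> MetricResult<f64> {
///         if data.len() != labels.len() {
///             return Err("Arrays must have the same length".into());
///         }
///         if data.is_empty() {
///             return Err("Arrays must be non-empty".into());
///         }
///
///         // Per-cluster sums and counts, used to compute cluster means.
///         let mut sums: HashMap<i32, (f64, usize)> = HashMap::new();
///         for (x, &label) in data.iter().zip(labels.iter()) {
///             let entry = sums.entry(label).or_insert((0.0, 0));
///             entry.0 += x;
///             entry.1 += 1;
///         }
///
///         // Mean absolute deviation from each point's cluster mean.
///         let total: f64 = data
///             .iter()
///             .zip(labels.iter())
///             .map(|(x, label)| {
///                 let (sum, count) = sums[label];
///                 (x - sum / count as f64).abs()
///             })
///             .sum();
///
///         Ok(total / data.len() as f64)
///     }
///
///     fn higher_is_better(&self) -> bool {
///         false
///     }
/// }
/// ```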
pub trait ClusteringMetric<F: Float> {
    /// Returns the name of the metric
    fn name(&self) -> &'static str;

    /// Computes the metric value given data points and cluster labels
    fn compute(&self, data: &Array1<F>, labels: &Array1<i32>) -> MetricResult<F>;

    /// Returns whether higher values indicate better performance
    fn higher_is_better(&self) -> bool;

    /// Optional: Returns a description of the metric
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional: Returns the valid range of the metric (min, max)
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}

/// A wrapper that combines multiple custom metrics into a single evaluator
///
/// See the module-level example for typical usage.
pub struct CustomMetricSuite<F: Float> {
    classification_metrics: Vec<Box<dyn ClassificationMetric<F> + Send + Sync>>,
    regression_metrics: Vec<Box<dyn RegressionMetric<F> + Send + Sync>>,
    clustering_metrics: Vec<Box<dyn ClusteringMetric<F> + Send + Sync>>,
}

impl<F: Float> Default for CustomMetricSuite<F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float> CustomMetricSuite<F> {
    /// Creates a new empty metric suite
    pub fn new() -> Self {
        Self {
            classification_metrics: Vec::new(),
            regression_metrics: Vec::new(),
            clustering_metrics: Vec::new(),
        }
    }

    /// Adds a classification metric to the suite
    pub fn add_classification_metric<M>(&mut self, metric: M) -> &mut Self
    where
        M: ClassificationMetric<F> + Send + Sync + 'static,
    {
        self.classification_metrics.push(Box::new(metric));
        self
    }

    /// Adds a regression metric to the suite
    pub fn add_regression_metric<M>(&mut self, metric: M) -> &mut Self
    where
        M: RegressionMetric<F> + Send + Sync + 'static,
    {
        self.regression_metrics.push(Box::new(metric));
        self
    }

    /// Adds a clustering metric to the suite
    pub fn add_clustering_metric<M>(&mut self, metric: M) -> &mut Self
    where
        M: ClusteringMetric<F> + Send + Sync + 'static,
    {
        self.clustering_metrics.push(Box::new(metric));
        self
    }

    /// Evaluates all classification metrics, skipping (with a warning) any that fail
    pub fn evaluate_classification(
        &self,
        y_true: &Array1<i32>,
        y_pred: &Array1<i32>,
    ) -> MetricsResult<CustomMetricResults<F>> {
        let mut results = CustomMetricResults::new("classification");

        for metric in &self.classification_metrics {
            match metric.compute(y_true, y_pred) {
                Ok(value) => {
                    results.add_result(metric.name(), value, metric.higher_is_better());
                }
                Err(e) => {
                    eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
                }
            }
        }

        Ok(results)
    }

    /// Evaluates all regression metrics, skipping (with a warning) any that fail
    pub fn evaluate_regression(
        &self,
        y_true: &Array1<F>,
        y_pred: &Array1<F>,
    ) -> MetricsResult<CustomMetricResults<F>> {
        let mut results = CustomMetricResults::new("regression");

        for metric in &self.regression_metrics {
            match metric.compute(y_true, y_pred) {
                Ok(value) => {
                    results.add_result(metric.name(), value, metric.higher_is_better());
                }
                Err(e) => {
                    eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
                }
            }
        }

        Ok(results)
    }

    /// Evaluates all clustering metrics, skipping (with a warning) any that fail
    pub fn evaluate_clustering(
        &self,
        data: &Array1<F>,
        labels: &Array1<i32>,
    ) -> MetricsResult<CustomMetricResults<F>> {
        let mut results = CustomMetricResults::new("clustering");

        for metric in &self.clustering_metrics {
            match metric.compute(data, labels) {
                Ok(value) => {
                    results.add_result(metric.name(), value, metric.higher_is_better());
                }
                Err(e) => {
                    eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
                }
            }
        }

        Ok(results)
    }

    /// Gets the names of all registered metrics, prefixed with their task type
    pub fn metric_names(&self) -> Vec<String> {
        let mut names = Vec::new();

        for metric in &self.classification_metrics {
            names.push(format!("classification:{}", metric.name()));
        }

        for metric in &self.regression_metrics {
            names.push(format!("regression:{}", metric.name()));
        }

        for metric in &self.clustering_metrics {
            names.push(format!("clustering:{}", metric.name()));
        }

        names
    }
}

/// Container for custom metric evaluation results
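///
/// # Examples
///
/// A minimal sketch of building results by hand and printing them via the
/// `Display` implementation:
///
/// ```
/// use scirs2_metrics::custom::CustomMetricResults;
///
/// let mut results = CustomMetricResults::<f64>::new("classification");
/// results.add_result("accuracy", 0.8, true);
/// results.add_result("error_rate", 0.2, false);
///
/// assert_eq!(results.metric_type(), "classification");
/// assert_eq!(results.get("accuracy").unwrap().value, 0.8);
/// println!("{}", results);
/// ```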
#[derive(Debug, Clone)]
pub struct CustomMetricResults<F: Float> {
    metric_type: String,
    results: Vec<CustomMetricResult<F>>,
}

/// A single named metric value, together with its optimization direction
#[derive(Debug, Clone)]
pub struct CustomMetricResult<F: Float> {
    /// Name of the metric that produced this value
    pub name: String,
    /// Computed metric value
    pub value: F,
    /// Whether higher values indicate better performance
    pub higher_is_better: bool,
}

impl<F: Float> CustomMetricResults<F> {
    /// Creates a new results container
    pub fn new(metric_type: &str) -> Self {
        Self {
            metric_type: metric_type.to_string(),
            results: Vec::new(),
        }
    }

    /// Adds a metric result
    pub fn add_result(&mut self, name: &str, value: F, higher_is_better: bool) {
        self.results.push(CustomMetricResult {
            name: name.to_string(),
            value,
            higher_is_better,
        });
    }

    /// Gets all results
    pub fn results(&self) -> &[CustomMetricResult<F>] {
        &self.results
    }

    /// Gets the metric type
    pub fn metric_type(&self) -> &str {
        &self.metric_type
    }

    /// Gets a specific result by name
    pub fn get(&self, name: &str) -> Option<&CustomMetricResult<F>> {
        self.results.iter().find(|r| r.name == name)
    }

    /// Gets the best result, scoring each metric in its own optimization direction
    ///
    /// Lower-is-better values are negated before comparison, so the result is only
    /// meaningful when the metrics are on comparable scales.
    pub fn best_result(&self) -> Option<&CustomMetricResult<F>> {
        self.results.iter().max_by(|a, b| {
            let a_val = if a.higher_is_better {
                a.value
            } else {
                -a.value
            };
            let b_val = if b.higher_is_better {
                b.value
            } else {
                -b.value
            };
            a_val
                .partial_cmp(&b_val)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }
}

impl<F: Float + fmt::Display> fmt::Display for CustomMetricResults<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Custom {} Metrics:", self.metric_type)?;
        writeln!(f, "{:-<50}", "")?;

        for result in &self.results {
            let direction = if result.higher_is_better {
                "↑"
            } else {
                "↓"
            };
            writeln!(
                f,
                "{:<30} {:<15} {}",
                result.name,
                format!("{:.6}", result.value),
                direction
            )?;
        }

        Ok(())
    }
}

/// Macro for easy classification metric implementation
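///
/// Expands to a unit struct implementing [`ClassificationMetric<f64>`](crate::custom::ClassificationMetric).
/// A minimal sketch of usage (the closure receives the label arrays and returns a
/// [`MetricResult`](crate::custom::MetricResult)):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::classification_metric;
/// use scirs2_metrics::custom::ClassificationMetric;
///
/// classification_metric!(ErrorRate, "error_rate", false, |y_true: &Array1<i32>,
///                                                         y_pred: &Array1<i32>| {
///     let wrong = y_true.iter().zip(y_pred.iter()).filter(|(a, b)| a != b).count();
///     Ok(wrong as f64 / y_true.len() as f64)
/// });
///
/// let metric = ErrorRate;
/// assert_eq!(metric.name(), "error_rate");
/// assert!(!metric.higher_is_better());
/// ```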
#[macro_export]
macro_rules! classification_metric {
    ($name:ident, $metric_name:expr, $higher_is_better:expr, $compute:expr) => {
        struct $name;

        impl $crate::custom::ClassificationMetric<f64> for $name {
            fn name(&self) -> &'static str {
                $metric_name
            }

            fn compute(
                &self,
                y_true: &::scirs2_core::ndarray::Array1<i32>,
                y_pred: &::scirs2_core::ndarray::Array1<i32>,
            ) -> $crate::custom::MetricResult<f64> {
                $compute(y_true, y_pred)
            }

            fn higher_is_better(&self) -> bool {
                $higher_is_better
            }
        }
    };
}

/// Macro for easy regression metric implementation
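///
/// Used like [`classification_metric!`](crate::classification_metric) above, but
/// the closure receives `&Array1<f64>` arrays. A minimal sketch (imports as in the
/// example above):
///
/// ```ignore
/// regression_metric!(MeanAbsoluteError, "mean_absolute_error", false, |y_true: &Array1<f64>,
///                                                                      y_pred: &Array1<f64>| {
///     let total: f64 = y_true.iter().zip(y_pred.iter()).map(|(a, b)| (a - b).abs()).sum();
///     Ok(total / y_true.len() as f64)
/// });
/// ```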
#[macro_export]
macro_rules! regression_metric {
    ($name:ident, $metric_name:expr, $higher_is_better:expr, $compute:expr) => {
        struct $name;

        impl $crate::custom::RegressionMetric<f64> for $name {
            fn name(&self) -> &'static str {
                $metric_name
            }

            fn compute(
                &self,
                y_true: &::scirs2_core::ndarray::Array1<f64>,
                y_pred: &::scirs2_core::ndarray::Array1<f64>,
            ) -> $crate::custom::MetricResult<f64> {
                $compute(y_true, y_pred)
            }

            fn higher_is_better(&self) -> bool {
                $higher_is_better
            }
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    struct TestAccuracy;

    impl ClassificationMetric<f64> for TestAccuracy {
        fn name(&self) -> &'static str {
            "test_accuracy"
        }

        fn compute(&self, y_true: &Array1<i32>, y_pred: &Array1<i32>) -> MetricResult<f64> {
            if y_true.len() != y_pred.len() {
                return Err("Length mismatch".into());
            }

            let correct = y_true
                .iter()
                .zip(y_pred.iter())
                .filter(|(a, b)| a == b)
                .count();

            Ok(correct as f64 / y_true.len() as f64)
        }

        fn higher_is_better(&self) -> bool {
            true
        }
    }

    struct TestMSE;

    impl RegressionMetric<f64> for TestMSE {
        fn name(&self) -> &'static str {
            "test_mse"
        }

        fn compute(&self, y_true: &Array1<f64>, y_pred: &Array1<f64>) -> MetricResult<f64> {
            if y_true.len() != y_pred.len() {
                return Err("Length mismatch".into());
            }

            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;

            Ok(mse)
        }

        fn higher_is_better(&self) -> bool {
            false
        }
    }

    #[test]
    fn test_custom_classification_metric() {
        let metric = TestAccuracy;
        let y_true = array![1, 0, 1, 1, 0];
        let y_pred = array![1, 0, 0, 1, 0];

        let result = metric.compute(&y_true, &y_pred).unwrap();
        assert_eq!(result, 0.8);
        assert!(metric.higher_is_better());
    }

    #[test]
    fn test_custom_regression_metric() {
        let metric = TestMSE;
        let y_true = array![1.0, 2.0, 3.0];
        let y_pred = array![1.1, 2.1, 2.9];

        let result = metric.compute(&y_true, &y_pred).unwrap();
        // MSE = ((1.0-1.1)² + (2.0-2.1)² + (3.0-2.9)²) / 3 = (0.01 + 0.01 + 0.01) / 3 = 0.01
        assert!((result - 0.01).abs() < 1e-10);
        assert!(!metric.higher_is_better());
    }

    #[test]
    fn test_metric_suite() {
        let mut suite = CustomMetricSuite::new();
        suite.add_classification_metric(TestAccuracy);
        suite.add_regression_metric(TestMSE);

        // Test classification
        let y_true_cls = array![1, 0, 1, 1, 0];
        let y_pred_cls = array![1, 0, 0, 1, 0];
        let cls_results = suite
            .evaluate_classification(&y_true_cls, &y_pred_cls)
            .unwrap();

        assert_eq!(cls_results.results().len(), 1);
        assert_eq!(cls_results.get("test_accuracy").unwrap().value, 0.8);

        // Test regression
        let y_true_reg = array![1.0, 2.0, 3.0];
        let y_pred_reg = array![1.1, 2.1, 2.9];
        let reg_results = suite.evaluate_regression(&y_true_reg, &y_pred_reg).unwrap();

        assert_eq!(reg_results.results().len(), 1);
        assert!((reg_results.get("test_mse").unwrap().value - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_metric_names() {
        let mut suite = CustomMetricSuite::new();
        suite.add_classification_metric(TestAccuracy);
        suite.add_regression_metric(TestMSE);

        let names = suite.metric_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"classification:test_accuracy".to_string()));
        assert!(names.contains(&"regression:test_mse".to_string()));
    }

    #[test]
    fn test_best_result() {
        let mut results = CustomMetricResults::new("test");
        results.add_result("metric1", 0.8, true); // higher is better
        results.add_result("metric2", 0.2, false); // lower is better

        let best = results.best_result().unwrap();
        // metric1 scores 0.8; metric2 is lower-is-better, so it scores -0.2.
        // metric1 therefore wins the comparison outright.
        assert_eq!(best.name, "metric1");
    }
}