scirs2_metrics/integration/
traits.rs

//! Traits for integration with other scirs2 modules
//!
//! This module defines traits that serve as an abstraction layer for integrating
//! with other scirs2 modules without creating direct dependencies on them. These
//! traits are implemented conditionally based on feature flags.

7use crate::error::MetricsError;
8use scirs2_core::ndarray::{Array, IxDyn};
9use scirs2_core::numeric::Float;
10use std::fmt::{Debug, Display};
11
12/// Trait for metrics that can be computed on neural network predictions and targets
13pub trait MetricComputation<F: Float + Debug + Display> {
14    /// Compute the metric value from predictions and targets
15    fn compute(
16        &self,
17        predictions: &Array<F, IxDyn>,
18        targets: &Array<F, IxDyn>,
19    ) -> Result<F, MetricsError>;
20
21    /// Get the name of the metric
22    fn name(&self) -> &str;
23}
24
25/// Trait for callbacks that can track metrics during training
26pub trait MetricCallback<F: Float + Debug + Display> {
27    /// Initialize the callback at the start of training
28    fn on_train_begin(&mut self);
29
30    /// Update with batch results
31    fn on_batch_end(
32        &mut self,
33        batch: usize,
34        predictions: &Array<F, IxDyn>,
35        targets: &Array<F, IxDyn>,
36    );
37
38    /// Finalize metrics at the end of an epoch
39    fn on_epoch_end(&mut self, epoch: usize) -> Result<(), MetricsError>;
40
41    /// Clean up at the end of training
42    fn on_train_end(&mut self);
43}