scirs2_metrics/integration/traits.rs
1//! Traits for integration with other modules
2//!
3//! This module defines traits that serve as an abstraction layer for integration
4//! with other scirs2 modules without creating direct dependencies. These traits
5//! are implemented conditionally based on feature flags.
6
7use crate::error::MetricsError;
8use scirs2_core::ndarray::{Array, IxDyn};
9use scirs2_core::numeric::Float;
10use std::fmt::{Debug, Display};
11
12/// Trait for metrics that can be computed on neural network predictions and targets
13pub trait MetricComputation<F: Float + Debug + Display> {
14 /// Compute the metric value from predictions and targets
15 fn compute(
16 &self,
17 predictions: &Array<F, IxDyn>,
18 targets: &Array<F, IxDyn>,
19 ) -> Result<F, MetricsError>;
20
21 /// Get the name of the metric
22 fn name(&self) -> &str;
23}
24
25/// Trait for callbacks that can track metrics during training
26pub trait MetricCallback<F: Float + Debug + Display> {
27 /// Initialize the callback at the start of training
28 fn on_train_begin(&mut self);
29
30 /// Update with batch results
31 fn on_batch_end(
32 &mut self,
33 batch: usize,
34 predictions: &Array<F, IxDyn>,
35 targets: &Array<F, IxDyn>,
36 );
37
38 /// Finalize metrics at the end of an epoch
39 fn on_epoch_end(&mut self, epoch: usize) -> Result<(), MetricsError>;
40
41 /// Clean up at the end of training
42 fn on_train_end(&mut self);
43}