1use crate::error::Result as MetricsResult;
84use scirs2_core::ndarray::Array1;
85use scirs2_core::numeric::Float;
86use std::fmt;
87
/// Result type for a single metric computation: a boxed, `Send + Sync`
/// error object so user-defined metrics can report arbitrary failures.
pub type MetricResult<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
90
/// A user-defined classification metric over integer class labels.
pub trait ClassificationMetric<F: Float> {
    /// Short identifier for the metric; used as the lookup key in results.
    fn name(&self) -> &'static str;

    /// Computes the metric from true vs. predicted labels.
    /// Implementations may fail (e.g. on length mismatch) via [`MetricResult`].
    fn compute(&self, y_true: &Array1<i32>, ypred: &Array1<i32>) -> MetricResult<F>;

    /// Whether larger values indicate better performance.
    fn higher_is_better(&self) -> bool;

    /// Optional human-readable description. Defaults to `None`.
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional `(min, max)` bounds on the metric's value. Defaults to `None`.
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}
112
/// A user-defined regression metric over floating-point targets.
pub trait RegressionMetric<F: Float> {
    /// Short identifier for the metric; used as the lookup key in results.
    fn name(&self) -> &'static str;

    /// Computes the metric from true vs. predicted values.
    /// Implementations may fail (e.g. on length mismatch) via [`MetricResult`].
    fn compute(&self, y_true: &Array1<F>, ypred: &Array1<F>) -> MetricResult<F>;

    /// Whether larger values indicate better performance.
    fn higher_is_better(&self) -> bool;

    /// Optional human-readable description. Defaults to `None`.
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional `(min, max)` bounds on the metric's value. Defaults to `None`.
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}
134
/// A user-defined clustering metric computed from data and cluster labels.
///
/// NOTE(review): `data` is a 1-D array, so multi-feature samples presumably
/// need to be flattened or pre-reduced by the caller — confirm intended shape.
pub trait ClusteringMetric<F: Float> {
    /// Short identifier for the metric; used as the lookup key in results.
    fn name(&self) -> &'static str;

    /// Computes the metric from the data points and their cluster labels.
    fn compute(&self, data: &Array1<F>, labels: &Array1<i32>) -> MetricResult<F>;

    /// Whether larger values indicate better performance.
    fn higher_is_better(&self) -> bool;

    /// Optional human-readable description. Defaults to `None`.
    fn description(&self) -> Option<&'static str> {
        None
    }

    /// Optional `(min, max)` bounds on the metric's value. Defaults to `None`.
    fn value_range(&self) -> Option<(F, F)> {
        None
    }
}
156
/// A collection of user-registered metrics, grouped by task type, that can be
/// evaluated in bulk. Boxed trait objects allow heterogeneous metrics per group.
pub struct CustomMetricSuite<F: Float> {
    // Metrics evaluated against integer label pairs.
    classification_metrics: Vec<Box<dyn ClassificationMetric<F> + Send + Sync>>,
    // Metrics evaluated against float target pairs.
    regression_metrics: Vec<Box<dyn RegressionMetric<F> + Send + Sync>>,
    // Metrics evaluated against data + cluster labels.
    clustering_metrics: Vec<Box<dyn ClusteringMetric<F> + Send + Sync>>,
}
163
impl<F: Float> Default for CustomMetricSuite<F> {
    /// Equivalent to [`CustomMetricSuite::new`]: an empty suite.
    fn default() -> Self {
        Self::new()
    }
}
169
170impl<F: Float> CustomMetricSuite<F> {
171 pub fn new() -> Self {
173 Self {
174 classification_metrics: Vec::new(),
175 regression_metrics: Vec::new(),
176 clustering_metrics: Vec::new(),
177 }
178 }
179
180 pub fn add_classification_metric<M>(&mut self, metric: M) -> &mut Self
182 where
183 M: ClassificationMetric<F> + Send + Sync + 'static,
184 {
185 self.classification_metrics.push(Box::new(metric));
186 self
187 }
188
189 pub fn add_regression_metric<M>(&mut self, metric: M) -> &mut Self
191 where
192 M: RegressionMetric<F> + Send + Sync + 'static,
193 {
194 self.regression_metrics.push(Box::new(metric));
195 self
196 }
197
198 pub fn add_clustering_metric<M>(&mut self, metric: M) -> &mut Self
200 where
201 M: ClusteringMetric<F> + Send + Sync + 'static,
202 {
203 self.clustering_metrics.push(Box::new(metric));
204 self
205 }
206
207 pub fn evaluate_classification(
209 &self,
210 y_true: &Array1<i32>,
211 ypred: &Array1<i32>,
212 ) -> MetricsResult<CustomMetricResults<F>> {
213 let mut results = CustomMetricResults::new("classification");
214
215 for metric in &self.classification_metrics {
216 match metric.compute(y_true, ypred) {
217 Ok(value) => {
218 results.add_result(metric.name(), value, metric.higher_is_better());
219 }
220 Err(e) => {
221 eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
222 }
223 }
224 }
225
226 Ok(results)
227 }
228
229 pub fn evaluate_regression(
231 &self,
232 y_true: &Array1<F>,
233 ypred: &Array1<F>,
234 ) -> MetricsResult<CustomMetricResults<F>> {
235 let mut results = CustomMetricResults::new("regression");
236
237 for metric in &self.regression_metrics {
238 match metric.compute(y_true, ypred) {
239 Ok(value) => {
240 results.add_result(metric.name(), value, metric.higher_is_better());
241 }
242 Err(e) => {
243 eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
244 }
245 }
246 }
247
248 Ok(results)
249 }
250
251 pub fn evaluate_clustering(
253 &self,
254 data: &Array1<F>,
255 labels: &Array1<i32>,
256 ) -> MetricsResult<CustomMetricResults<F>> {
257 let mut results = CustomMetricResults::new("clustering");
258
259 for metric in &self.clustering_metrics {
260 match metric.compute(data, labels) {
261 Ok(value) => {
262 results.add_result(metric.name(), value, metric.higher_is_better());
263 }
264 Err(e) => {
265 eprintln!("Warning: Failed to compute {}: {}", metric.name(), e);
266 }
267 }
268 }
269
270 Ok(results)
271 }
272
273 pub fn metric_names(&self) -> Vec<String> {
275 let mut names = Vec::new();
276
277 for metric in &self.classification_metrics {
278 names.push(format!("classification:{}", metric.name()));
279 }
280
281 for metric in &self.regression_metrics {
282 names.push(format!("regression:{}", metric.name()));
283 }
284
285 for metric in &self.clustering_metrics {
286 names.push(format!("clustering:{}", metric.name()));
287 }
288
289 names
290 }
291}
292
/// The values produced by evaluating one group of custom metrics.
#[derive(Debug, Clone)]
pub struct CustomMetricResults<F: Float> {
    // Group tag supplied at construction ("classification", "regression", …).
    metric_type: String,
    // Per-metric results, in the order they were added.
    results: Vec<CustomMetricResult<F>>,
}
299
/// A single computed metric value together with its optimization direction.
#[derive(Debug, Clone)]
pub struct CustomMetricResult<F: Float> {
    // Metric identifier (from the metric's `name()`).
    pub name: String,
    // Computed metric value.
    pub value: F,
    // `true` if larger values of this metric are better.
    pub higher_is_better: bool,
}
306
307impl<F: Float> CustomMetricResults<F> {
308 pub fn new(_metrictype: &str) -> Self {
310 Self {
311 metric_type: _metrictype.to_string(),
312 results: Vec::new(),
313 }
314 }
315
316 pub fn add_result(&mut self, name: &str, value: F, higher_isbetter: bool) {
318 self.results.push(CustomMetricResult {
319 name: name.to_string(),
320 value,
321 higher_is_better: higher_isbetter,
322 });
323 }
324
325 pub fn results(&self) -> &[CustomMetricResult<F>] {
327 &self.results
328 }
329
330 pub fn metric_type(&self) -> &str {
332 &self.metric_type
333 }
334
335 pub fn get(&self, name: &str) -> Option<&CustomMetricResult<F>> {
337 self.results.iter().find(|r| r.name == name)
338 }
339
340 pub fn best_result(&self) -> Option<&CustomMetricResult<F>> {
342 self.results.iter().max_by(|a, b| {
343 let a_val = if a.higher_is_better {
344 a.value
345 } else {
346 -a.value
347 };
348 let b_val = if b.higher_is_better {
349 b.value
350 } else {
351 -b.value
352 };
353 a_val
354 .partial_cmp(&b_val)
355 .unwrap_or(std::cmp::Ordering::Equal)
356 })
357 }
358}
359
360impl<F: Float + fmt::Display> fmt::Display for CustomMetricResults<F> {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 writeln!(f, "Custom {} Metrics:", self.metric_type)?;
363 writeln!(f, "{:-<50}", "")?;
364
365 for result in &self.results {
366 let direction = if result.higher_is_better {
367 "↑"
368 } else {
369 "↓"
370 };
371 writeln!(
372 f,
373 "{:<30} {:<15} {}",
374 result.name,
375 format!("{:.6}", result.value),
376 direction
377 )?;
378 }
379
380 Ok(())
381 }
382}
383
/// Declares a unit struct and implements `ClassificationMetric<f64>` for it,
/// delegating `compute` to the given closure/function expression.
///
/// NOTE(review): this exported macro names `scirs2_core` directly (not via
/// `$crate`), so callers must have `scirs2_core` as a dependency under that
/// exact name — consider routing through a `$crate` re-export; confirm.
#[macro_export]
macro_rules! classification_metric {
    ($name:ident, $metric_name:expr, $higher_is_better:expr, $compute:expr) => {
        struct $name;

        impl $crate::custom::ClassificationMetric<f64> for $name {
            fn name(&self) -> &'static str {
                $metric_name
            }

            fn compute(
                &self,
                y_true: &scirs2_core::ndarray::Array1<i32>,
                ypred: &scirs2_core::ndarray::Array1<i32>,
            ) -> $crate::custom::MetricResult<f64> {
                $compute(y_true, ypred)
            }

            fn higher_is_better(&self) -> bool {
                $higher_is_better
            }
        }
    };
}
409
/// Declares a unit struct and implements `RegressionMetric<f64>` for it,
/// delegating `compute` to the given closure/function expression.
///
/// NOTE(review): as with `classification_metric!`, the macro names
/// `scirs2_core` directly rather than via `$crate` — callers need that exact
/// dependency name; consider a `$crate` re-export; confirm.
#[macro_export]
macro_rules! regression_metric {
    ($name:ident, $metric_name:expr, $higher_is_better:expr, $compute:expr) => {
        struct $name;

        impl $crate::custom::RegressionMetric<f64> for $name {
            fn name(&self) -> &'static str {
                $metric_name
            }

            fn compute(
                &self,
                y_true: &scirs2_core::ndarray::Array1<f64>,
                ypred: &scirs2_core::ndarray::Array1<f64>,
            ) -> $crate::custom::MetricResult<f64> {
                $compute(y_true, ypred)
            }

            fn higher_is_better(&self) -> bool {
                $higher_is_better
            }
        }
    };
}
435
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Minimal accuracy metric used to exercise the classification path.
    struct TestAccuracy;

    impl ClassificationMetric<f64> for TestAccuracy {
        fn name(&self) -> &'static str {
            "test_accuracy"
        }

        fn compute(&self, y_true: &Array1<i32>, y_pred: &Array1<i32>) -> MetricResult<f64> {
            if y_true.len() != y_pred.len() {
                return Err("Length mismatch".into());
            }

            let matches = y_true
                .iter()
                .zip(y_pred.iter())
                .filter(|(t, p)| t == p)
                .count();

            Ok(matches as f64 / y_true.len() as f64)
        }

        fn higher_is_better(&self) -> bool {
            true
        }
    }

    // Minimal mean-squared-error metric used to exercise the regression path.
    struct TestMSE;

    impl RegressionMetric<f64> for TestMSE {
        fn name(&self) -> &'static str {
            "test_mse"
        }

        fn compute(&self, y_true: &Array1<f64>, y_pred: &Array1<f64>) -> MetricResult<f64> {
            if y_true.len() != y_pred.len() {
                return Err("Length mismatch".into());
            }

            let total: f64 = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (t - p).powi(2))
                .sum();

            Ok(total / y_true.len() as f64)
        }

        fn higher_is_better(&self) -> bool {
            false
        }
    }

    #[test]
    fn test_custom_classification_metric() {
        let metric = TestAccuracy;
        let truth = array![1, 0, 1, 1, 0];
        let predicted = array![1, 0, 0, 1, 0];

        // 4 of 5 labels match.
        let accuracy = metric.compute(&truth, &predicted).unwrap();
        assert_eq!(accuracy, 0.8);
        assert!(metric.higher_is_better());
    }

    #[test]
    fn test_custom_regression_metric() {
        let metric = TestMSE;
        let truth = array![1.0, 2.0, 3.0];
        let predicted = array![1.1, 2.1, 2.9];

        // Each residual is 0.1, so MSE = 0.01.
        let mse = metric.compute(&truth, &predicted).unwrap();
        assert!((mse - 0.01).abs() < 1e-10);
        assert!(!metric.higher_is_better());
    }

    #[test]
    fn test_metric_suite() {
        let mut suite = CustomMetricSuite::new();
        suite.add_classification_metric(TestAccuracy);
        suite.add_regression_metric(TestMSE);

        let cls_truth = array![1, 0, 1, 1, 0];
        let cls_predicted = array![1, 0, 0, 1, 0];
        let cls_results = suite
            .evaluate_classification(&cls_truth, &cls_predicted)
            .unwrap();

        assert_eq!(cls_results.results().len(), 1);
        assert_eq!(cls_results.get("test_accuracy").unwrap().value, 0.8);

        let reg_truth = array![1.0, 2.0, 3.0];
        let reg_predicted = array![1.1, 2.1, 2.9];
        let reg_results = suite.evaluate_regression(&reg_truth, &reg_predicted).unwrap();

        assert_eq!(reg_results.results().len(), 1);
        assert!((reg_results.get("test_mse").unwrap().value - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_metric_names() {
        let mut suite = CustomMetricSuite::new();
        suite.add_classification_metric(TestAccuracy);
        suite.add_regression_metric(TestMSE);

        let names = suite.metric_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"classification:test_accuracy".to_string()));
        assert!(names.contains(&"regression:test_mse".to_string()));
    }

    #[test]
    fn test_best_result() {
        let mut results = CustomMetricResults::new("test");
        // Direction-adjusted scores: metric1 → 0.8, metric2 → -0.2.
        results.add_result("metric1", 0.8, true);
        results.add_result("metric2", 0.2, false);

        let best = results.best_result().unwrap();
        assert_eq!(best.name, "metric1");
    }
}