rusty_machine/learning/toolkit/cost_fn.rs

//! Cost functions.
//!
//! This module contains cost (loss) functions implementing the `CostFunc`
//! trait, used when training models such as neural networks.

use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
use linalg::Vector;

/// Trait for cost functions in models.
pub trait CostFunc<T> {
    /// The cost of the outputs relative to the targets.
    fn cost(outputs: &T, targets: &T) -> f64;

    /// The gradient of the cost with respect to the outputs.
    fn grad_cost(outputs: &T, targets: &T) -> T;
}
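// Illustrative sketch, not part of the original source: any unit struct can
// serve as a cost function by implementing `CostFunc`. `MeanAbsError` below
// is a hypothetical example for `Vector<f64>` data, mirroring the style of
// the implementations that follow.
#[derive(Clone, Copy, Debug)]
pub struct MeanAbsError;

impl CostFunc<Vector<f64>> for MeanAbsError {
    fn cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> f64 {
        // Mean of the absolute differences.
        let diff = outputs - targets;
        let n = diff.size();
        diff.apply(&|x: f64| x.abs()).sum() / (n as f64)
    }

    fn grad_cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> Vector<f64> {
        // Sign of the difference (the 1/n factor is omitted, matching the
        // other gradients in this module).
        (outputs - targets).apply(&|x: f64| x.signum())
    }
}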
/// The mean squared error cost function.
#[derive(Clone, Copy, Debug)]
pub struct MeanSqError;

impl CostFunc<Matrix<f64>> for MeanSqError {
    fn cost(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
        // Sum of squared differences divided by twice the number of rows.
        let diff = outputs - targets;
        let sq_diff = &diff.elemul(&diff);

        let n = diff.rows();

        sq_diff.sum() / (2f64 * (n as f64))
    }

    fn grad_cost(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> Matrix<f64> {
        // Gradient of the squared error term (without the 1/n scaling).
        outputs - targets
    }
}

impl CostFunc<Vector<f64>> for MeanSqError {
    fn cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> f64 {
        // Sum of squared differences divided by twice the number of elements.
        let diff = outputs - targets;
        let sq_diff = &diff.elemul(&diff);

        let n = diff.size();

        sq_diff.sum() / (2f64 * (n as f64))
    }

    fn grad_cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> Vector<f64> {
        outputs - targets
    }
}
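#[cfg(test)]
mod mean_sq_error_example {
    // A minimal illustrative check, added as an example rather than taken
    // from the original source: verifies the mean squared error formula,
    // sum of squared differences over 2 * n, on a small vector.
    use super::{CostFunc, MeanSqError};
    use linalg::Vector;

    #[test]
    fn mse_cost_on_small_vector() {
        let outputs = Vector::new(vec![0.5, 0.8]);
        let targets = Vector::new(vec![0.0, 1.0]);

        // diff = [0.5, -0.2], squared sum = 0.29, n = 2 => 0.29 / 4 = 0.0725
        let cost = <MeanSqError as CostFunc<Vector<f64>>>::cost(&outputs, &targets);
        assert!((cost - 0.0725).abs() < 1e-12);
    }
}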
/// The cross-entropy cost function.
#[derive(Clone, Copy, Debug)]
pub struct CrossEntropyError;

impl CostFunc<Matrix<f64>> for CrossEntropyError {
    fn cost(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
        // -[t * ln(o) + (1 - t) * ln(1 - o)] summed over all entries,
        // averaged over the rows.
        let log_inv_output = (-outputs + 1f64).apply(&ln);
        let log_output = outputs.clone().apply(&ln);

        let mat_cost = targets.elemul(&log_output) + (-targets + 1f64).elemul(&log_inv_output);

        let n = outputs.rows();

        -(mat_cost.sum()) / (n as f64)
    }

    fn grad_cost(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> Matrix<f64> {
        // Derivative with respect to the outputs: (o - t) / (o * (1 - o)).
        (outputs - targets).elediv(&(outputs.elemul(&(-outputs + 1f64))))
    }
}
impl CostFunc<Vector<f64>> for CrossEntropyError {
    fn cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> f64 {
        let log_inv_output = (-outputs + 1f64).apply(&ln);
        let log_output = outputs.clone().apply(&ln);

        let mat_cost = targets.elemul(&log_output) + (-targets + 1f64).elemul(&log_inv_output);

        let n = outputs.size();

        -(mat_cost.sum()) / (n as f64)
    }

    fn grad_cost(outputs: &Vector<f64>, targets: &Vector<f64>) -> Vector<f64> {
        (outputs - targets).elediv(&(outputs.elemul(&(-outputs + 1f64))))
    }
}
/// Natural logarithm, for element-wise application via `apply`.
fn ln(x: f64) -> f64 {
    x.ln()
}
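#[cfg(test)]
mod cross_entropy_example {
    // A minimal illustrative check, added as an example rather than taken
    // from the original source: verifies the cross-entropy cost
    // -(t * ln(o) + (1 - t) * ln(1 - o)) averaged over the entries.
    use super::{CostFunc, CrossEntropyError};
    use linalg::Vector;

    #[test]
    fn cross_entropy_cost_on_small_vector() {
        let outputs = Vector::new(vec![0.9, 0.2]);
        let targets = Vector::new(vec![1.0, 0.0]);

        // Expected value: -(ln(0.9) + ln(0.8)) / 2.
        let expected = -(0.9f64.ln() + 0.8f64.ln()) / 2.0;
        let cost = <CrossEntropyError as CostFunc<Vector<f64>>>::cost(&outputs, &targets);
        assert!((cost - expected).abs() < 1e-12);
    }
}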