1#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5
/// Shared interface for trainable models that map an input sample `X` to a
/// prediction `Y`.
///
/// Implementors provide [`fit`](Self::fit), [`predict`](Self::predict) and
/// [`evaluate`](Self::evaluate); [`predict_batch`](Self::predict_batch) has a
/// default implementation expressed purely in terms of `predict`.
pub trait TsetlinModel<X, Y> {
    /// Trains the model on the samples `x` paired with the targets `y`.
    ///
    /// NOTE(review): the meaning of `epochs` and `seed` is left to the
    /// implementor; presumably the number of training passes and an RNG seed
    /// for reproducibility — confirm against the concrete models.
    fn fit(&mut self, x: &[X], y: &[Y], epochs: usize, seed: u64);

    /// Returns the model's prediction for a single sample `x`.
    fn predict(&self, x: &X) -> Y;

    /// Scores the model on the samples `x` against the expected targets `y`.
    ///
    /// The test mocks in this file return accuracy (fraction of correct
    /// predictions); NOTE(review): confirm real implementors use the same metric.
    fn evaluate(&self, x: &[X], y: &[Y]) -> f32;

    /// Predicts every sample in `xs`, preserving order.
    ///
    /// Calls [`predict`](Self::predict) exactly once per element, so the
    /// returned vector always has `xs.len()` entries (empty for an empty slice).
    fn predict_batch(&self, xs: &[X]) -> Vec<Y> {
        let mut predictions = Vec::with_capacity(xs.len());
        predictions.extend(xs.iter().map(|sample| self.predict(sample)));
        predictions
    }
}
54
55pub trait VotingModel<X>: TsetlinModel<X, u8> {
57 fn sum_votes(&self, x: &X) -> f32;
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared accuracy metric for the mock models below: the fraction of
    /// `(x[i], y[i])` pairs for which `model.predict(&x[i]) == y[i]`.
    ///
    /// Pairs are formed with `zip`, so a surplus tail on the longer slice is
    /// ignored, while the denominator is `x.len()`. An empty `x` yields `0.0`
    /// instead of the `NaN` that `0.0 / 0.0` would produce.
    fn accuracy<M: TsetlinModel<u8, u8>>(model: &M, x: &[u8], y: &[u8]) -> f32 {
        if x.is_empty() {
            return 0.0;
        }
        let correct = x
            .iter()
            .zip(y)
            .filter(|&(xi, yi)| model.predict(xi) == *yi)
            .count();
        correct as f32 / x.len() as f32
    }

    /// Parity classifier: predicts `x % 2`; `fit` is a no-op.
    struct MockModel;

    impl TsetlinModel<u8, u8> for MockModel {
        fn fit(&mut self, _x: &[u8], _y: &[u8], _epochs: usize, _seed: u64) {}

        fn predict(&self, x: &u8) -> u8 {
            *x % 2
        }

        fn evaluate(&self, x: &[u8], y: &[u8]) -> f32 {
            accuracy(self, x, y)
        }
    }

    #[test]
    fn predict_batch_default_impl() {
        let model = MockModel;
        assert_eq!(model.predict_batch(&[0, 1, 2, 3, 4]), [0, 1, 0, 1, 0]);

        // An empty batch yields an empty result rather than panicking.
        let empty: [u8; 0] = [];
        assert!(model.predict_batch(&empty).is_empty());
    }

    #[test]
    fn mock_model_fit() {
        let mut model = MockModel;
        model.fit(&[1, 2, 3], &[0, 1, 0], 10, 42);
        // The mock's `fit` is a no-op, so predictions must be unchanged afterwards.
        assert_eq!(model.predict_batch(&[1, 2, 3]), [1, 0, 1]);
    }

    #[test]
    fn mock_model_evaluate() {
        let model = MockModel;

        // Every label equals the parity of its input: perfect accuracy.
        let acc = model.evaluate(&[0, 1, 2], &[0, 1, 0]);
        assert!((acc - 1.0).abs() < 0.001);

        // Every label is the opposite parity: zero accuracy.
        let acc2 = model.evaluate(&[0, 1, 2, 3], &[1, 0, 1, 0]);
        assert!(acc2.abs() < 0.001);

        // Empty input is scored as 0.0 rather than NaN.
        let empty: [u8; 0] = [];
        assert_eq!(model.evaluate(&empty, &empty), 0.0);
    }

    /// Vote-based mock: the vote sum for `x` is `x - 2`, and `predict` returns
    /// `1` when that sum is non-negative (i.e. `x >= 2`), otherwise `0`.
    struct MockVotingModel;

    impl TsetlinModel<u8, u8> for MockVotingModel {
        fn fit(&mut self, _x: &[u8], _y: &[u8], _epochs: usize, _seed: u64) {}

        fn predict(&self, x: &u8) -> u8 {
            u8::from(self.sum_votes(x) >= 0.0)
        }

        fn evaluate(&self, x: &[u8], y: &[u8]) -> f32 {
            accuracy(self, x, y)
        }
    }

    impl VotingModel<u8> for MockVotingModel {
        fn sum_votes(&self, x: &u8) -> f32 {
            // `f32::from` is the lossless conversion; every `u8` is exact in `f32`.
            f32::from(*x) - 2.0
        }
    }

    #[test]
    fn voting_model_sum_votes() {
        let model = MockVotingModel;

        assert!((model.sum_votes(&0) - (-2.0)).abs() < 0.001);
        assert!(model.sum_votes(&2).abs() < 0.001);
        assert!((model.sum_votes(&5) - 3.0).abs() < 0.001);
    }

    #[test]
    fn voting_model_predict_uses_votes() {
        let model = MockVotingModel;

        // Negative vote sums (x < 2) predict class 0.
        assert_eq!(model.predict(&0), 0);
        assert_eq!(model.predict(&1), 0);

        // A zero sum (x == 2) sits on the threshold and predicts 1, as do positive sums.
        assert_eq!(model.predict(&2), 1);
        assert_eq!(model.predict(&5), 1);
    }

    #[test]
    fn voting_model_evaluate() {
        let model = MockVotingModel;
        let acc = model.evaluate(&[0, 1, 2, 3], &[0, 0, 1, 1]);
        assert!((acc - 1.0).abs() < 0.001);
    }
}