1#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5use core::cmp::Ordering;
6
7use rand::Rng;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 Clause, Config,
13 feedback::{type_i, type_ii},
14 utils::rng_from_seed
15};
16
/// Multiclass Tsetlin machine holding one bank of clauses per class.
///
/// Prediction picks the class whose clauses cast the largest summed vote.
/// With the `serde` feature enabled the model can be serialized/deserialized.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MultiClass {
    /// One `Vec<Clause>` per class, indexed by class id.
    clauses: Vec<Vec<Clause>>,
    /// Hyper-parameters shared by every clause (clause/feature/state counts, `s`).
    config: Config,
    /// Vote-clamping threshold `T` applied to class votes during training.
    threshold: i32
}
42
43impl MultiClass {
44 #[must_use]
52 pub fn new(config: Config, n_classes: usize, threshold: i32) -> Self {
53 let clauses = (0..n_classes)
54 .map(|_| {
55 (0..config.n_clauses)
56 .map(|i| {
57 let p = if i % 2 == 0 { 1 } else { -1 };
58 Clause::new(config.n_features, config.n_states, p)
59 })
60 .collect()
61 })
62 .collect();
63
64 Self {
65 clauses,
66 config,
67 threshold
68 }
69 }
70
71 #[inline]
73 #[must_use]
74 pub fn n_classes(&self) -> usize {
75 self.clauses.len()
76 }
77
78 #[must_use]
80 pub fn class_votes(&self, x: &[u8]) -> Vec<f32> {
81 self.clauses
82 .iter()
83 .map(|cls| cls.iter().map(|c| c.vote(x)).sum())
84 .collect()
85 }
86
87 #[must_use]
91 pub fn predict(&self, x: &[u8]) -> usize {
92 self.class_votes(x)
93 .iter()
94 .enumerate()
95 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
96 .map_or(0, |(i, _)| i)
97 }
98
99 pub fn train_one<R: Rng>(&mut self, x: &[u8], y: usize, rng: &mut R) {
101 let votes = self.class_votes(x);
102 let t = self.threshold as f32;
103
104 for (class_idx, class_clauses) in self.clauses.iter_mut().enumerate() {
105 let is_target = class_idx == y;
106 let sum = votes[class_idx].clamp(-t, t);
107
108 for clause in class_clauses {
109 let fires = clause.evaluate(x);
110 let p = clause.polarity();
111
112 if is_target {
113 let prob = (t - sum) / (2.0 * t);
114 if p == 1 && rng.random::<f32>() <= prob {
115 type_i(clause, x, fires, self.config.s, rng);
116 } else if p == -1 && fires && rng.random::<f32>() <= prob {
117 type_ii(clause, x);
118 }
119 } else {
120 let prob = (t + sum) / (2.0 * t);
121 if p == -1 && rng.random::<f32>() <= prob {
122 type_i(clause, x, fires, self.config.s, rng);
123 } else if p == 1 && fires && rng.random::<f32>() <= prob {
124 type_ii(clause, x);
125 }
126 }
127 }
128 }
129 }
130
131 pub fn fit(&mut self, x: &[Vec<u8>], y: &[usize], epochs: usize, seed: u64) {
140 let mut rng = rng_from_seed(seed);
141 let mut indices: Vec<usize> = (0..x.len()).collect();
142
143 for _ in 0..epochs {
144 crate::utils::shuffle(&mut indices, &mut rng);
145 for &i in &indices {
146 self.train_one(&x[i], y[i], &mut rng);
147 }
148 }
149 }
150
151 #[must_use]
155 pub fn evaluate(&self, x: &[Vec<u8>], y: &[usize]) -> f32 {
156 if x.is_empty() {
157 return 0.0;
158 }
159 let correct = x
160 .iter()
161 .zip(y)
162 .filter(|(xi, yi)| self.predict(xi) == **yi)
163 .count();
164 correct as f32 / x.len() as f32
165 }
166
167 #[must_use]
173 pub fn quick(n_clauses: usize, n_features: usize, n_classes: usize, threshold: i32) -> Self {
174 let config = Config::builder()
175 .clauses(n_clauses)
176 .features(n_features)
177 .build()
178 .expect("invalid quick config");
179 Self::new(config, n_classes, threshold)
180 }
181
182 #[inline]
190 pub fn partial_fit(&mut self, x: &[u8], y: usize, seed: u64) {
191 let mut rng = rng_from_seed(seed);
192 self.train_one(x, y, &mut rng);
193 }
194
195 pub fn partial_fit_batch(&mut self, xs: &[Vec<u8>], ys: &[usize], seed: u64) {
203 if xs.is_empty() || xs.len() != ys.len() {
204 return;
205 }
206
207 let mut rng = rng_from_seed(seed);
208 for (x, &y) in xs.iter().zip(ys) {
209 self.train_one(x, y, &mut rng);
210 }
211 }
212}
213
214impl crate::model::TsetlinModel<Vec<u8>, usize> for MultiClass {
215 fn fit(&mut self, x: &[Vec<u8>], y: &[usize], epochs: usize, seed: u64) {
216 MultiClass::fit(self, x, y, epochs, seed);
217 }
218
219 fn predict(&self, x: &Vec<u8>) -> usize {
220 MultiClass::predict(self, x)
221 }
222
223 fn evaluate(&self, x: &[Vec<u8>], y: &[usize]) -> f32 {
224 MultiClass::evaluate(self, x, y)
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Config shared by most tests: 4 input features and `n_clauses` clauses.
    fn four_feature_config(n_clauses: usize) -> Config {
        Config::builder()
            .clauses(n_clauses)
            .features(4)
            .build()
            .unwrap()
    }

    /// Four samples: class 0 activates the first half of the features and
    /// class 1 the second half.
    fn split_halves_dataset() -> (Vec<Vec<u8>>, Vec<usize>) {
        let xs = vec![
            vec![1, 1, 0, 0],
            vec![1, 0, 0, 0],
            vec![0, 0, 1, 1],
            vec![0, 0, 0, 1],
        ];
        let ys = vec![0, 0, 1, 1];
        (xs, ys)
    }

    #[test]
    fn predict_valid_class() {
        let model = MultiClass::new(four_feature_config(10), 3, 15);
        let predicted = model.predict(&[1, 0, 1, 0]);
        assert!(predicted < 3);
    }

    #[test]
    fn n_classes_correct() {
        let model = MultiClass::new(four_feature_config(10), 5, 15);
        assert_eq!(model.n_classes(), 5);
    }

    #[test]
    fn class_votes_returns_all_classes() {
        let model = MultiClass::new(four_feature_config(10), 3, 15);
        assert_eq!(model.class_votes(&[1, 0, 1, 0]).len(), 3);
    }

    #[test]
    fn quick_constructor() {
        assert_eq!(MultiClass::quick(20, 4, 3, 15).n_classes(), 3);
    }

    #[test]
    fn evaluate_empty_returns_zero() {
        let model = MultiClass::new(four_feature_config(10), 3, 15);
        let accuracy = model.evaluate(&[], &[]);
        assert!(accuracy.abs() < 0.001);
    }

    #[test]
    fn train_one_modifies_state() {
        let mut model = MultiClass::new(four_feature_config(10), 2, 15);
        let mut rng = rng_from_seed(42);

        // One update per class, both drawing from the same RNG stream.
        model.train_one(&[1, 0, 1, 0], 0, &mut rng);
        model.train_one(&[0, 1, 0, 1], 1, &mut rng);

        assert!(model.predict(&[1, 0, 1, 0]) < 2);
    }

    #[test]
    fn fit_and_evaluate() {
        let (xs, ys) = split_halves_dataset();
        let mut model = MultiClass::quick(50, 4, 2, 25);

        model.fit(&xs, &ys, 100, 42);

        let accuracy = model.evaluate(&xs, &ys);
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn trait_impl_works() {
        use crate::model::TsetlinModel;

        let mut model = MultiClass::new(four_feature_config(20), 2, 15);
        let xs: Vec<Vec<u8>> = vec![vec![1, 1, 0, 0], vec![0, 0, 1, 1]];
        let ys: Vec<usize> = vec![0, 1];

        TsetlinModel::fit(&mut model, &xs, &ys, 50, 42);
        assert!(TsetlinModel::predict(&model, &xs[0]) < 2);

        let accuracy = TsetlinModel::evaluate(&model, &xs, &ys);
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn partial_fit_single_sample() {
        let mut model = MultiClass::quick(20, 4, 3, 15);
        let samples: [([u8; 4], usize); 3] =
            [([1, 1, 0, 0], 0), ([0, 1, 1, 0], 1), ([0, 0, 1, 1], 2)];

        // One online update per class, seeded 42, 43 and 44 respectively.
        for (seed, (sample, label)) in (42_u64..).zip(samples) {
            model.partial_fit(&sample, label, seed);
        }

        assert!(model.predict(&[1, 1, 0, 0]) < 3);
    }

    #[test]
    fn partial_fit_batch() {
        let (xs, ys) = split_halves_dataset();
        let mut model = MultiClass::quick(50, 4, 2, 25);

        // 100 ordered passes with seeds 42..=141.
        for seed in 42_u64..142 {
            model.partial_fit_batch(&xs, &ys, seed);
        }

        let accuracy = model.evaluate(&xs, &ys);
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn partial_fit_empty_batch() {
        let mut model = MultiClass::quick(10, 4, 2, 15);
        model.partial_fit_batch(&[], &[], 42);
        assert_eq!(model.n_classes(), 2);
    }
}
357}