1#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6
7use rand::Rng;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 Clause, Config, Rule, SparseTsetlinMachine,
13 feedback::{type_i, type_ii},
14 training::{EarlyStopTracker, FitOptions, FitResult},
15 utils::rng_from_seed
16};
17
/// Optional training behaviours layered on top of the basic Tsetlin machine:
/// adaptive threshold, clause-weight learning, and dead-clause pruning.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AdvancedOptions {
    /// When `true`, each training step raises the threshold by `t_lr` after a
    /// correct prediction and lowers it by `t_lr` after a wrong one, clamped to
    /// `[t_min, t_max]`.
    pub adaptive_t: bool,
    /// Lower bound for the adaptive threshold. Must not exceed `t_max` when
    /// `adaptive_t` is enabled (`f32::clamp` panics otherwise).
    pub t_min: f32,
    /// Upper bound for the adaptive threshold.
    pub t_max: f32,
    /// Step size applied to the threshold per training sample when `adaptive_t` is set.
    pub t_lr: f32,
    /// Learning rate forwarded to each clause's weight update.
    pub weight_lr: f32,
    /// Lower weight bound forwarded to each clause's weight update.
    pub weight_min: f32,
    /// Upper weight bound forwarded to each clause's weight update.
    pub weight_max: f32,
    /// Activation criterion forwarded to the clause "dead" check used for pruning.
    /// Pruning is skipped entirely when this and `prune_weight` are both zero.
    pub prune_threshold: u32,
    /// Weight criterion forwarded to the clause "dead" check used for pruning.
    pub prune_weight: f32
}
34
35impl Default for AdvancedOptions {
36 fn default() -> Self {
37 Self {
38 adaptive_t: false,
39 t_min: 5.0,
40 t_max: 50.0,
41 t_lr: 0.1,
42 weight_lr: 0.05,
43 weight_min: 0.1,
44 weight_max: 2.0,
45 prune_threshold: 0,
46 prune_weight: 0.0
47 }
48 }
49}
50
/// A binary Tsetlin machine: a bank of polarity-tagged clauses whose summed
/// votes decide the predicted class (`1` when the sum is non-negative, else `0`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TsetlinMachine {
    /// Clause bank; `new` assigns polarity `+1` to even indices and `-1` to odd ones.
    clauses: Vec<Clause>,
    /// Configuration the machine was built from.
    config: Config,
    /// Current threshold T; changes during training only when `advanced.adaptive_t` is set.
    t: f32,
    /// Threshold given at construction; `reset_threshold` restores `t` to this value.
    t_base: f32,
    /// Adaptive-threshold, weight-learning and pruning settings.
    advanced: AdvancedOptions
}
74
75impl TsetlinMachine {
76 pub fn new(config: Config, threshold: i32) -> Self {
80 let clauses = (0..config.n_clauses)
81 .map(|i| {
82 let p = if i % 2 == 0 { 1 } else { -1 };
83 Clause::new(config.n_features, config.n_states, p)
84 })
85 .collect();
86
87 let t = threshold as f32;
88 Self {
89 clauses,
90 config,
91 t,
92 t_base: t,
93 advanced: AdvancedOptions::default()
94 }
95 }
96
97 pub fn with_advanced(config: Config, threshold: i32, advanced: AdvancedOptions) -> Self {
101 let mut tm = Self::new(config, threshold);
102 tm.advanced = advanced;
103 tm
104 }
105
106 #[inline]
108 #[must_use]
109 pub fn threshold(&self) -> f32 {
110 self.t
111 }
112
113 #[inline]
115 pub const fn config(&self) -> &Config {
116 &self.config
117 }
118
119 #[inline]
121 #[must_use]
122 pub fn clauses(&self) -> &[Clause] {
123 &self.clauses
124 }
125
126 #[inline]
130 pub fn threshold_base(&self) -> f32 {
131 self.t_base
132 }
133
134 pub fn reset_threshold(&mut self) {
138 self.t = self.t_base;
139 }
140
141 #[inline]
145 pub fn sum_votes(&self, x: &[u8]) -> f32 {
146 self.clauses.iter().map(|c| c.vote(x)).sum()
147 }
148
149 #[inline(always)]
153 pub fn predict(&self, x: &[u8]) -> u8 {
154 if self.sum_votes(x) >= 0.0 { 1 } else { 0 }
155 }
156
157 #[inline]
161 pub fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
162 xs.iter().map(|x| self.predict(x)).collect()
163 }
164
165 #[inline]
169 pub fn train_one<R: Rng>(&mut self, x: &[u8], y: u8, rng: &mut R) {
170 let sum = self.sum_votes(x).clamp(-self.t, self.t);
171 let inv_2t = 1.0 / (2.0 * self.t);
172 let s = self.config.s;
173
174 let prob = if y == 1 {
175 (self.t - sum) * inv_2t
176 } else {
177 (self.t + sum) * inv_2t
178 };
179
180 let prediction = if sum >= 0.0 { 1 } else { 0 };
181 let correct = prediction == y;
182
183 for clause in &mut self.clauses {
184 let fires = clause.evaluate_tracked(x);
185 let p = clause.polarity();
186
187 if fires {
189 let clause_correct = (p == 1 && y == 1) || (p == -1 && y == 0);
190 clause.record_outcome(clause_correct);
191 }
192
193 if y == 1 {
194 if p == 1 && rng.random::<f32>() <= prob {
195 type_i(clause, x, fires, s, rng);
196 } else if p == -1 && fires && rng.random::<f32>() <= prob {
197 type_ii(clause, x);
198 }
199 } else if p == -1 && rng.random::<f32>() <= prob {
200 type_i(clause, x, fires, s, rng);
201 } else if p == 1 && fires && rng.random::<f32>() <= prob {
202 type_ii(clause, x);
203 }
204 }
205
206 if self.advanced.adaptive_t {
208 let adj = if correct {
209 self.advanced.t_lr
210 } else {
211 -self.advanced.t_lr
212 };
213 self.t = (self.t + adj).clamp(self.advanced.t_min, self.advanced.t_max);
214 }
215 }
216
217 pub fn update_weights(&mut self) {
221 let lr = self.advanced.weight_lr;
222 let min = self.advanced.weight_min;
223 let max = self.advanced.weight_max;
224
225 for clause in &mut self.clauses {
226 clause.update_weight(lr, min, max);
227 }
228 }
229
230 pub fn prune_dead_clauses(&mut self) {
234 let min_act = self.advanced.prune_threshold;
235 let min_wt = self.advanced.prune_weight;
236
237 if min_act == 0 && min_wt == 0.0 {
238 return;
239 }
240
241 for clause in &mut self.clauses {
242 if clause.is_dead(min_act, min_wt) {
243 *clause = Clause::new(
245 self.config.n_features,
246 self.config.n_states,
247 clause.polarity()
248 );
249 }
250 }
251 }
252
253 pub fn reset_activations(&mut self) {
257 for clause in &mut self.clauses {
258 clause.reset_activations();
259 }
260 }
261
262 pub fn fit(&mut self, x: &[Vec<u8>], y: &[u8], epochs: usize, seed: u64) {
271 let _ = self.fit_with_options(x, y, FitOptions::new(epochs, seed));
272 }
273
274 pub fn fit_with_options(
286 &mut self,
287 x: &[Vec<u8>],
288 y: &[u8],
289 mut opts: FitOptions
290 ) -> FitResult {
291 if x.is_empty() || x.len() != y.len() {
292 return FitResult::new(0, 0.0, false);
293 }
294
295 let mut rng = rng_from_seed(opts.seed);
296 let mut indices: Vec<usize> = (0..x.len()).collect();
297 let mut tracker = opts.early_stop.as_ref().map(EarlyStopTracker::new);
298 let mut stopped = false;
299 let mut epochs_run = 0;
300 let mut history = Vec::with_capacity(opts.epochs);
301
302 for epoch in 0..opts.epochs {
303 self.reset_activations();
304
305 if opts.shuffle {
306 crate::utils::shuffle(&mut indices, &mut rng);
307 }
308
309 for &i in &indices {
310 self.train_one(&x[i], y[i], &mut rng);
311 }
312
313 self.update_weights();
315 self.prune_dead_clauses();
316
317 epochs_run = epoch + 1;
318 let accuracy = self.evaluate(x, y);
319 history.push(accuracy);
320
321 if let Some(ref mut callback) = opts.callback
323 && !callback(epoch + 1, accuracy)
324 {
325 stopped = true;
326 break;
327 }
328
329 if let Some(ref mut t) = tracker
331 && t.update(accuracy)
332 {
333 stopped = true;
334 break;
335 }
336 }
337
338 FitResult::with_history(epochs_run, self.evaluate(x, y), stopped, history)
339 }
340
341 #[inline]
345 #[must_use]
346 pub fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
347 if x.is_empty() {
348 return 0.0;
349 }
350 let correct = x
351 .iter()
352 .zip(y)
353 .filter(|(xi, yi)| self.predict(xi) == **yi)
354 .count();
355 correct as f32 / x.len() as f32
356 }
357
358 #[must_use]
360 pub fn rules(&self) -> Vec<Rule> {
361 self.clauses.iter().map(Rule::from_clause).collect()
362 }
363
364 pub fn clause_weights(&self) -> Vec<f32> {
368 self.clauses.iter().map(|c| c.weight()).collect()
369 }
370
371 #[must_use]
373 pub fn clause_activations(&self) -> Vec<u32> {
374 self.clauses.iter().map(|c| c.activations()).collect()
375 }
376
377 #[must_use]
387 pub fn quick(n_clauses: usize, n_features: usize, threshold: i32) -> Self {
388 let config = Config::builder()
389 .clauses(n_clauses)
390 .features(n_features)
391 .build()
392 .expect("invalid quick config");
393 Self::new(config, threshold)
394 }
395
396 #[must_use]
435 pub fn to_sparse(&self) -> SparseTsetlinMachine {
436 SparseTsetlinMachine::from_clauses(&self.clauses, self.config.n_features, self.t)
437 }
438
439 #[inline]
464 pub fn partial_fit(&mut self, x: &[u8], y: u8, seed: u64) {
465 let mut rng = rng_from_seed(seed);
466 self.train_one(x, y, &mut rng);
467 }
468
469 pub fn partial_fit_batch(
495 &mut self,
496 xs: &[Vec<u8>],
497 ys: &[u8],
498 seed: u64,
499 update_weights: bool
500 ) {
501 if xs.is_empty() || xs.len() != ys.len() {
502 return;
503 }
504
505 let mut rng = rng_from_seed(seed);
506 for (x, &y) in xs.iter().zip(ys) {
507 self.train_one(x, y, &mut rng);
508 }
509
510 if update_weights {
511 self.update_weights();
512 }
513 }
514}
515
516impl crate::model::TsetlinModel<Vec<u8>, u8> for TsetlinMachine {
517 fn fit(&mut self, x: &[Vec<u8>], y: &[u8], epochs: usize, seed: u64) {
518 TsetlinMachine::fit(self, x, y, epochs, seed);
519 }
520
521 fn predict(&self, x: &Vec<u8>) -> u8 {
522 TsetlinMachine::predict(self, x)
523 }
524
525 fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
526 TsetlinMachine::evaluate(self, x, y)
527 }
528
529 fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
530 TsetlinMachine::predict_batch(self, xs)
531 }
532}
533
534impl crate::model::VotingModel<Vec<u8>> for TsetlinMachine {
535 fn sum_votes(&self, x: &Vec<u8>) -> f32 {
536 TsetlinMachine::sum_votes(self, x)
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xor_convergence() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        let y = vec![0, 1, 1, 0];

        tm.fit(&x, &y, 200, 42);
        assert!(tm.evaluate(&x, &y) >= 0.75);
    }

    #[test]
    fn batch_predict() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 5);

        let xs = vec![vec![0, 0], vec![1, 1]];
        let preds = tm.predict_batch(&xs);
        assert_eq!(preds.len(), 2);

        // Batch prediction must agree element-wise with single-sample prediction.
        let singles: Vec<u8> = xs.iter().map(|x| tm.predict(x)).collect();
        assert_eq!(preds, singles);
    }

    #[test]
    fn weighted_clauses() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 5);

        let weights = tm.clause_weights();
        assert!(weights.iter().all(|&w| (w - 1.0).abs() < 0.001));
    }

    #[test]
    fn adaptive_threshold() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let opts = AdvancedOptions {
            adaptive_t: true,
            t_min: 3.0,
            t_max: 20.0,
            t_lr: 0.5,
            ..Default::default()
        };
        let mut tm = TsetlinMachine::with_advanced(config, 10, opts);
        assert!((tm.threshold() - 10.0).abs() < 0.001);

        let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        let y = vec![0, 1, 1, 0];
        tm.fit(&x, &y, 20, 7);

        // Adaptation must keep the threshold inside the configured bounds.
        assert!((3.0..=20.0).contains(&tm.threshold()));
        // The base threshold is never modified by adaptation.
        assert!((tm.threshold_base() - 10.0).abs() < 0.001);
    }

    #[test]
    fn accessors_config_clauses() {
        let config = Config::builder().clauses(20).features(4).build().unwrap();
        let tm = TsetlinMachine::new(config, 10);

        assert_eq!(tm.config().n_clauses, 20);
        assert_eq!(tm.config().n_features, 4);
        assert_eq!(tm.clauses().len(), 20);
    }

    #[test]
    fn threshold_base_and_reset() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        // Adaptation is enabled so training can actually move the threshold
        // away from its base value before the reset is checked.
        let opts = AdvancedOptions {
            adaptive_t: true,
            t_min: 3.0,
            t_max: 20.0,
            t_lr: 0.5,
            ..Default::default()
        };
        let mut tm = TsetlinMachine::with_advanced(config, 15, opts);

        assert!((tm.threshold_base() - 15.0).abs() < 0.001);

        let x = vec![vec![0, 0], vec![1, 1]];
        let y = vec![0, 1];
        tm.fit(&x, &y, 5, 42);

        tm.reset_threshold();
        assert!((tm.threshold() - 15.0).abs() < 0.001);
    }

    #[test]
    fn trait_impl_tsetlin_model() {
        use crate::model::TsetlinModel;

        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        let y = vec![0, 1, 1, 0];

        TsetlinModel::fit(&mut tm, &x, &y, 100, 42);
        let pred = TsetlinModel::predict(&tm, &x[1]);
        assert!(pred == 0 || pred == 1);

        let acc = TsetlinModel::evaluate(&tm, &x, &y);
        assert!((0.0..=1.0).contains(&acc));

        let batch = TsetlinModel::predict_batch(&tm, &x);
        assert_eq!(batch.len(), 4);
    }

    #[test]
    fn trait_impl_voting_model() {
        use crate::model::VotingModel;

        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 10);

        let votes = VotingModel::sum_votes(&tm, &vec![1, 0]);
        assert!(votes.is_finite());
    }

    #[test]
    fn evaluate_empty_is_zero() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 5);

        assert_eq!(tm.evaluate(&[], &[]), 0.0);
    }

    #[test]
    fn fit_rejects_mismatched_lengths() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 5);
        let before = tm.clause_weights();

        // Two samples but one label: fit must return without training.
        tm.fit(&[vec![0, 1], vec![1, 0]], &[1], 10, 42);
        assert_eq!(tm.clause_weights(), before);
    }

    #[test]
    fn partial_fit_single_sample() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        tm.partial_fit(&[0, 1], 1, 42);
        tm.partial_fit(&[1, 0], 1, 43);
        tm.partial_fit(&[0, 0], 0, 44);
        tm.partial_fit(&[1, 1], 0, 45);

        let pred = tm.predict(&[0, 1]);
        assert!(pred == 0 || pred == 1);
    }

    #[test]
    fn partial_fit_batch_learning() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
        let y = vec![0, 1, 1, 0];

        for epoch in 0..100 {
            tm.partial_fit_batch(&x, &y, 42 + epoch, true);
        }

        assert!(tm.evaluate(&x, &y) >= 0.5);
    }

    #[test]
    fn partial_fit_empty_batch() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 5);

        tm.partial_fit_batch(&[], &[], 42, true);
        assert_eq!(tm.clauses().len(), 10);

        // Mismatched lengths are rejected without training or weight updates.
        let before = tm.clause_weights();
        tm.partial_fit_batch(&[vec![0, 1]], &[], 42, true);
        assert_eq!(tm.clause_weights(), before);
    }

    #[test]
    fn partial_fit_streaming_simulation() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        let samples = [
            (vec![0, 0], 0u8),
            (vec![0, 1], 1),
            (vec![1, 0], 1),
            (vec![1, 1], 0)
        ];

        for (seed, (x, y)) in samples.iter().cycle().take(400).enumerate() {
            tm.partial_fit(x, *y, seed as u64);
        }

        let x: Vec<Vec<u8>> = samples.iter().map(|(x, _)| x.clone()).collect();
        let y: Vec<u8> = samples.iter().map(|(_, y)| *y).collect();
        assert!(tm.evaluate(&x, &y) >= 0.5);
    }
}