Skip to main content

tsetlin_rs/
binary.rs

1//! Binary classification Tsetlin Machine with weighted clauses and adaptive
2//! threshold.
3
4#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6
7use rand::Rng;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Clause, Config, Rule, SparseTsetlinMachine,
13    feedback::{type_i, type_ii},
14    training::{EarlyStopTracker, FitOptions, FitResult},
15    utils::rng_from_seed
16};
17
/// Configuration for advanced training features.
///
/// These options are read by [`TsetlinMachine`] during training. They
/// control adaptive thresholding, clause weight learning and dead-clause
/// pruning.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AdvancedOptions {
    /// When `true`, the threshold is adjusted by `t_lr` after every
    /// training sample.
    pub adaptive_t:      bool,
    /// Lower bound for the adaptive threshold. Must not exceed `t_max`: the
    /// threshold is clamped with [`f32::clamp`], which panics otherwise.
    pub t_min:           f32,
    /// Upper bound for the adaptive threshold.
    pub t_max:           f32,
    /// Step by which the adaptive threshold moves per training sample.
    pub t_lr:            f32,
    /// Learning rate handed to each clause's weight update.
    pub weight_lr:       f32,
    /// Lower weight limit handed to each clause's weight update.
    pub weight_min:      f32,
    /// Upper weight limit handed to each clause's weight update.
    pub weight_max:      f32,
    /// Activation limit used to decide whether a clause is dead. Pruning is
    /// skipped entirely when this is `0` and `prune_weight` is `0.0`.
    pub prune_threshold: u32,
    /// Weight limit used to decide whether a clause is dead.
    pub prune_weight:    f32
}
34
35impl Default for AdvancedOptions {
36    fn default() -> Self {
37        Self {
38            adaptive_t:      false,
39            t_min:           5.0,
40            t_max:           50.0,
41            t_lr:            0.1,
42            weight_lr:       0.05,
43            weight_min:      0.1,
44            weight_max:      2.0,
45            prune_threshold: 0,
46            prune_weight:    0.0
47        }
48    }
49}
50
51/// # Overview
52///
53/// Binary classification Tsetlin Machine with weighted clauses and adaptive
54/// threshold.
55///
56/// # Examples
57///
58/// ```
59/// use tsetlin_rs::{Config, TsetlinMachine};
60///
61/// let config = Config::builder().clauses(20).features(2).build().unwrap();
62///
63/// let mut tm = TsetlinMachine::new(config, 15);
64/// ```
65#[derive(Debug, Clone)]
66#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67pub struct TsetlinMachine {
68    clauses:  Vec<Clause>,
69    config:   Config,
70    t:        f32,
71    t_base:   f32,
72    advanced: AdvancedOptions
73}
74
75impl TsetlinMachine {
76    /// # Overview
77    ///
78    /// Creates new machine. Half clauses +1 polarity, half -1.
79    pub fn new(config: Config, threshold: i32) -> Self {
80        let clauses = (0..config.n_clauses)
81            .map(|i| {
82                let p = if i % 2 == 0 { 1 } else { -1 };
83                Clause::new(config.n_features, config.n_states, p)
84            })
85            .collect();
86
87        let t = threshold as f32;
88        Self {
89            clauses,
90            config,
91            t,
92            t_base: t,
93            advanced: AdvancedOptions::default()
94        }
95    }
96
97    /// # Overview
98    ///
99    /// Creates machine with advanced options.
100    pub fn with_advanced(config: Config, threshold: i32, advanced: AdvancedOptions) -> Self {
101        let mut tm = Self::new(config, threshold);
102        tm.advanced = advanced;
103        tm
104    }
105
106    /// Returns current threshold value (may differ from base if adaptive).
107    #[inline]
108    #[must_use]
109    pub fn threshold(&self) -> f32 {
110        self.t
111    }
112
113    /// Returns configuration.
114    #[inline]
115    pub const fn config(&self) -> &Config {
116        &self.config
117    }
118
119    /// Returns read-only access to clauses.
120    #[inline]
121    #[must_use]
122    pub fn clauses(&self) -> &[Clause] {
123        &self.clauses
124    }
125
126    /// # Overview
127    ///
128    /// Base threshold (initial value).
129    #[inline]
130    pub fn threshold_base(&self) -> f32 {
131        self.t_base
132    }
133
134    /// # Overview
135    ///
136    /// Resets threshold to base value.
137    pub fn reset_threshold(&mut self) {
138        self.t = self.t_base;
139    }
140
141    /// # Overview
142    ///
143    /// Sum of weighted clause votes for input x.
144    #[inline]
145    pub fn sum_votes(&self, x: &[u8]) -> f32 {
146        self.clauses.iter().map(|c| c.vote(x)).sum()
147    }
148
149    /// # Overview
150    ///
151    /// Predicts class (0 or 1).
152    #[inline(always)]
153    pub fn predict(&self, x: &[u8]) -> u8 {
154        if self.sum_votes(x) >= 0.0 { 1 } else { 0 }
155    }
156
157    /// # Overview
158    ///
159    /// Batch prediction for multiple samples.
160    #[inline]
161    pub fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
162        xs.iter().map(|x| self.predict(x)).collect()
163    }
164
165    /// # Overview
166    ///
167    /// Trains on single example with tracking.
168    #[inline]
169    pub fn train_one<R: Rng>(&mut self, x: &[u8], y: u8, rng: &mut R) {
170        let sum = self.sum_votes(x).clamp(-self.t, self.t);
171        let inv_2t = 1.0 / (2.0 * self.t);
172        let s = self.config.s;
173
174        let prob = if y == 1 {
175            (self.t - sum) * inv_2t
176        } else {
177            (self.t + sum) * inv_2t
178        };
179
180        let prediction = if sum >= 0.0 { 1 } else { 0 };
181        let correct = prediction == y;
182
183        for clause in &mut self.clauses {
184            let fires = clause.evaluate_tracked(x);
185            let p = clause.polarity();
186
187            // Record outcome for weight learning
188            if fires {
189                let clause_correct = (p == 1 && y == 1) || (p == -1 && y == 0);
190                clause.record_outcome(clause_correct);
191            }
192
193            if y == 1 {
194                if p == 1 && rng.random::<f32>() <= prob {
195                    type_i(clause, x, fires, s, rng);
196                } else if p == -1 && fires && rng.random::<f32>() <= prob {
197                    type_ii(clause, x);
198                }
199            } else if p == -1 && rng.random::<f32>() <= prob {
200                type_i(clause, x, fires, s, rng);
201            } else if p == 1 && fires && rng.random::<f32>() <= prob {
202                type_ii(clause, x);
203            }
204        }
205
206        // Adaptive threshold adjustment
207        if self.advanced.adaptive_t {
208            let adj = if correct {
209                self.advanced.t_lr
210            } else {
211                -self.advanced.t_lr
212            };
213            self.t = (self.t + adj).clamp(self.advanced.t_min, self.advanced.t_max);
214        }
215    }
216
217    /// # Overview
218    ///
219    /// Updates clause weights. Call at end of epoch.
220    pub fn update_weights(&mut self) {
221        let lr = self.advanced.weight_lr;
222        let min = self.advanced.weight_min;
223        let max = self.advanced.weight_max;
224
225        for clause in &mut self.clauses {
226            clause.update_weight(lr, min, max);
227        }
228    }
229
230    /// # Overview
231    ///
232    /// Prunes dead clauses (low activation or weight).
233    pub fn prune_dead_clauses(&mut self) {
234        let min_act = self.advanced.prune_threshold;
235        let min_wt = self.advanced.prune_weight;
236
237        if min_act == 0 && min_wt == 0.0 {
238            return;
239        }
240
241        for clause in &mut self.clauses {
242            if clause.is_dead(min_act, min_wt) {
243                // Reset dead clause to fresh state
244                *clause = Clause::new(
245                    self.config.n_features,
246                    self.config.n_states,
247                    clause.polarity()
248                );
249            }
250        }
251    }
252
253    /// # Overview
254    ///
255    /// Resets activation counters. Call at start of epoch.
256    pub fn reset_activations(&mut self) {
257        for clause in &mut self.clauses {
258            clause.reset_activations();
259        }
260    }
261
262    /// Simple training for given epochs.
263    ///
264    /// # Arguments
265    ///
266    /// * `x` - Training inputs (binary features)
267    /// * `y` - Binary labels (0 or 1)
268    /// * `epochs` - Number of training epochs
269    /// * `seed` - Random seed for reproducibility
270    pub fn fit(&mut self, x: &[Vec<u8>], y: &[u8], epochs: usize, seed: u64) {
271        let _ = self.fit_with_options(x, y, FitOptions::new(epochs, seed));
272    }
273
274    /// Training with full options including early stopping and callbacks.
275    ///
276    /// # Arguments
277    ///
278    /// * `x` - Training inputs (binary features)
279    /// * `y` - Binary labels (0 or 1)
280    /// * `opts` - Training options (epochs, early stopping, callback)
281    ///
282    /// # Returns
283    ///
284    /// [`FitResult`] with training statistics.
285    pub fn fit_with_options(
286        &mut self,
287        x: &[Vec<u8>],
288        y: &[u8],
289        mut opts: FitOptions
290    ) -> FitResult {
291        if x.is_empty() || x.len() != y.len() {
292            return FitResult::new(0, 0.0, false);
293        }
294
295        let mut rng = rng_from_seed(opts.seed);
296        let mut indices: Vec<usize> = (0..x.len()).collect();
297        let mut tracker = opts.early_stop.as_ref().map(EarlyStopTracker::new);
298        let mut stopped = false;
299        let mut epochs_run = 0;
300        let mut history = Vec::with_capacity(opts.epochs);
301
302        for epoch in 0..opts.epochs {
303            self.reset_activations();
304
305            if opts.shuffle {
306                crate::utils::shuffle(&mut indices, &mut rng);
307            }
308
309            for &i in &indices {
310                self.train_one(&x[i], y[i], &mut rng);
311            }
312
313            // End of epoch: update weights and prune
314            self.update_weights();
315            self.prune_dead_clauses();
316
317            epochs_run = epoch + 1;
318            let accuracy = self.evaluate(x, y);
319            history.push(accuracy);
320
321            // Callback
322            if let Some(ref mut callback) = opts.callback
323                && !callback(epoch + 1, accuracy)
324            {
325                stopped = true;
326                break;
327            }
328
329            // Early stopping
330            if let Some(ref mut t) = tracker
331                && t.update(accuracy)
332            {
333                stopped = true;
334                break;
335            }
336        }
337
338        FitResult::with_history(epochs_run, self.evaluate(x, y), stopped, history)
339    }
340
341    /// Evaluates accuracy on test data.
342    ///
343    /// Returns fraction of correct predictions (0.0 to 1.0).
344    #[inline]
345    #[must_use]
346    pub fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
347        if x.is_empty() {
348            return 0.0;
349        }
350        let correct = x
351            .iter()
352            .zip(y)
353            .filter(|(xi, yi)| self.predict(xi) == **yi)
354            .count();
355        correct as f32 / x.len() as f32
356    }
357
358    /// Extracts learned rules from all clauses.
359    #[must_use]
360    pub fn rules(&self) -> Vec<Rule> {
361        self.clauses.iter().map(Rule::from_clause).collect()
362    }
363
364    /// # Overview
365    ///
366    /// Returns clause weights for inspection.
367    pub fn clause_weights(&self) -> Vec<f32> {
368        self.clauses.iter().map(|c| c.weight()).collect()
369    }
370
371    /// Returns clause activation counts.
372    #[must_use]
373    pub fn clause_activations(&self) -> Vec<u32> {
374        self.clauses.iter().map(|c| c.activations()).collect()
375    }
376
377    /// Quick constructor with sensible defaults.
378    ///
379    /// Equivalent to
380    /// `Config::builder().clauses(n_clauses).features(n_features).build()`
381    /// followed by `TsetlinMachine::new(config, threshold)`.
382    ///
383    /// # Panics
384    ///
385    /// Panics if n_clauses is odd or zero, or n_features is zero.
386    #[must_use]
387    pub fn quick(n_clauses: usize, n_features: usize, threshold: i32) -> Self {
388        let config = Config::builder()
389            .clauses(n_clauses)
390            .features(n_features)
391            .build()
392            .expect("invalid quick config");
393        Self::new(config, threshold)
394    }
395
396    /// Converts to sparse representation for memory-efficient inference.
397    ///
398    /// Call after `fit()` to reduce memory by 5-100x depending on clause
399    /// sparsity. The sparse model supports prediction but not training.
400    ///
401    /// # Memory Savings
402    ///
403    /// | Features | Dense | Sparse (typical) | Reduction |
404    /// |----------|-------|------------------|-----------|
405    /// | 100      | 40 KB | 2 KB             | 20x       |
406    /// | 1000     | 400 KB| 8 KB             | 50x       |
407    /// | 10000    | 4 MB  | 20 KB            | 200x      |
408    ///
409    /// # Example
410    ///
411    /// ```
412    /// use tsetlin_rs::{Config, TsetlinMachine};
413    ///
414    /// let config = Config::builder().clauses(20).features(100).build().unwrap();
415    /// let mut tm = TsetlinMachine::new(config, 10);
416    ///
417    /// let x: Vec<Vec<u8>> = (0..100).map(|i| vec![(i % 2) as u8; 100]).collect();
418    /// let y: Vec<u8> = (0..100).map(|i| (i % 2) as u8).collect();
419    ///
420    /// tm.fit(&x, &y, 100, 42);
421    ///
422    /// // Convert to sparse
423    /// let sparse = tm.to_sparse();
424    ///
425    /// // Verify same predictions
426    /// for xi in &x {
427    ///     assert_eq!(tm.predict(xi), sparse.predict(xi));
428    /// }
429    ///
430    /// // Check compression
431    /// let stats = sparse.memory_stats();
432    /// println!("Compression: {}x", stats.compression_ratio(100));
433    /// ```
434    #[must_use]
435    pub fn to_sparse(&self) -> SparseTsetlinMachine {
436        SparseTsetlinMachine::from_clauses(&self.clauses, self.config.n_features, self.t)
437    }
438
439    /// Trains on a single sample (online/incremental learning).
440    ///
441    /// Unlike `fit()`, this processes one sample without requiring the full
442    /// dataset in memory. Useful for streaming data or real-time learning.
443    ///
444    /// # Arguments
445    ///
446    /// * `x` - Single input sample (binary features)
447    /// * `y` - Binary label (0 or 1)
448    /// * `seed` - Random seed for this update
449    ///
450    /// # Example
451    ///
452    /// ```
453    /// use tsetlin_rs::{Config, TsetlinMachine};
454    ///
455    /// let config = Config::builder().clauses(20).features(2).build().unwrap();
456    /// let mut tm = TsetlinMachine::new(config, 10);
457    ///
458    /// // Stream samples one at a time
459    /// tm.partial_fit(&[0, 1], 1, 42);
460    /// tm.partial_fit(&[1, 0], 1, 43);
461    /// tm.partial_fit(&[0, 0], 0, 44);
462    /// ```
463    #[inline]
464    pub fn partial_fit(&mut self, x: &[u8], y: u8, seed: u64) {
465        let mut rng = rng_from_seed(seed);
466        self.train_one(x, y, &mut rng);
467    }
468
469    /// Trains on a mini-batch of samples (online/incremental learning).
470    ///
471    /// Processes multiple samples in sequence without requiring the full
472    /// dataset. Optionally updates weights after the batch.
473    ///
474    /// # Arguments
475    ///
476    /// * `xs` - Batch of input samples
477    /// * `ys` - Batch of binary labels
478    /// * `seed` - Random seed for this batch
479    /// * `update_weights` - Whether to update clause weights after batch
480    ///
481    /// # Example
482    ///
483    /// ```
484    /// use tsetlin_rs::{Config, TsetlinMachine};
485    ///
486    /// let config = Config::builder().clauses(20).features(2).build().unwrap();
487    /// let mut tm = TsetlinMachine::new(config, 10);
488    ///
489    /// // Process mini-batches
490    /// let batch_x = vec![vec![0, 1], vec![1, 0]];
491    /// let batch_y = vec![1, 1];
492    /// tm.partial_fit_batch(&batch_x, &batch_y, 42, true);
493    /// ```
494    pub fn partial_fit_batch(
495        &mut self,
496        xs: &[Vec<u8>],
497        ys: &[u8],
498        seed: u64,
499        update_weights: bool
500    ) {
501        if xs.is_empty() || xs.len() != ys.len() {
502            return;
503        }
504
505        let mut rng = rng_from_seed(seed);
506        for (x, &y) in xs.iter().zip(ys) {
507            self.train_one(x, y, &mut rng);
508        }
509
510        if update_weights {
511            self.update_weights();
512        }
513    }
514}
515
516impl crate::model::TsetlinModel<Vec<u8>, u8> for TsetlinMachine {
517    fn fit(&mut self, x: &[Vec<u8>], y: &[u8], epochs: usize, seed: u64) {
518        TsetlinMachine::fit(self, x, y, epochs, seed);
519    }
520
521    fn predict(&self, x: &Vec<u8>) -> u8 {
522        TsetlinMachine::predict(self, x)
523    }
524
525    fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
526        TsetlinMachine::evaluate(self, x, y)
527    }
528
529    fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
530        TsetlinMachine::predict_batch(self, xs)
531    }
532}
533
534impl crate::model::VotingModel<Vec<u8>> for TsetlinMachine {
535    fn sum_votes(&self, x: &Vec<u8>) -> f32 {
536        TsetlinMachine::sum_votes(self, x)
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// XOR truth table shared by several tests.
    fn xor_data() -> (Vec<Vec<u8>>, Vec<u8>) {
        (
            vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]],
            vec![0, 1, 1, 0]
        )
    }

    #[test]
    fn xor_convergence() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);
        let (x, y) = xor_data();

        tm.fit(&x, &y, 200, 42);
        assert!(tm.evaluate(&x, &y) >= 0.75);
    }

    #[test]
    fn batch_predict() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 5);

        let xs = vec![vec![0, 0], vec![1, 1]];
        let preds = tm.predict_batch(&xs);
        assert_eq!(preds.len(), 2);
    }

    #[test]
    fn weighted_clauses() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 5);

        let weights = tm.clause_weights();
        assert!(weights.iter().all(|&w| (w - 1.0).abs() < 0.001));
    }

    #[test]
    fn adaptive_threshold() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let opts = AdvancedOptions {
            adaptive_t: true,
            t_min: 3.0,
            t_max: 20.0,
            t_lr: 0.5,
            ..Default::default()
        };
        let mut tm = TsetlinMachine::with_advanced(config, 10, opts);
        assert!((tm.threshold() - 10.0).abs() < 0.001);

        // A single update moves the threshold by exactly `t_lr`, up or down.
        tm.partial_fit(&[0, 1], 1, 7);
        assert!(((tm.threshold() - 10.0).abs() - 0.5).abs() < 0.001);

        // Further training keeps it inside [t_min, t_max] and never touches
        // the base value.
        let (x, y) = xor_data();
        tm.fit(&x, &y, 50, 42);
        assert!((3.0..=20.0).contains(&tm.threshold()));
        assert!((tm.threshold_base() - 10.0).abs() < 0.001);
    }

    #[test]
    fn fixed_threshold_without_adaptation() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 12);
        let (x, y) = xor_data();

        tm.fit(&x, &y, 10, 1);
        assert!((tm.threshold() - 12.0).abs() < 0.001);
    }

    #[test]
    fn accessors_config_clauses() {
        let config = Config::builder().clauses(20).features(4).build().unwrap();
        let tm = TsetlinMachine::new(config, 10);

        assert_eq!(tm.config().n_clauses, 20);
        assert_eq!(tm.config().n_features, 4);
        assert_eq!(tm.clauses().len(), 20);
    }

    #[test]
    fn threshold_base_and_reset() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let opts = AdvancedOptions {
            adaptive_t: true,
            t_lr: 1.0,
            ..Default::default()
        };
        let mut tm = TsetlinMachine::with_advanced(config, 15, opts);

        assert!((tm.threshold_base() - 15.0).abs() < 0.001);

        // Adaptive training moves the threshold away from its base value...
        tm.partial_fit(&[1, 1], 1, 42);
        assert!((tm.threshold() - 15.0).abs() > 0.5);

        let x = vec![vec![0, 0], vec![1, 1]];
        let y = vec![0, 1];
        tm.fit(&x, &y, 5, 42);

        // ...and reset restores it.
        tm.reset_threshold();
        assert!((tm.threshold() - 15.0).abs() < 0.001);
        assert!((tm.threshold_base() - 15.0).abs() < 0.001);
    }

    #[test]
    fn rules_weights_and_activations_cover_every_clause() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);
        let (x, y) = xor_data();
        tm.fit(&x, &y, 10, 3);

        assert_eq!(tm.rules().len(), 20);
        assert_eq!(tm.clause_weights().len(), 20);
        assert_eq!(tm.clause_activations().len(), 20);
    }

    #[test]
    fn sparse_matches_dense() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);
        let (x, y) = xor_data();
        tm.fit(&x, &y, 50, 42);

        let sparse = tm.to_sparse();
        for xi in &x {
            assert_eq!(tm.predict(xi), sparse.predict(xi));
        }
    }

    #[test]
    fn trait_impl_tsetlin_model() {
        use crate::model::TsetlinModel;

        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);
        let (x, y) = xor_data();

        TsetlinModel::fit(&mut tm, &x, &y, 100, 42);
        let pred = TsetlinModel::predict(&tm, &x[1]);
        assert!(pred == 0 || pred == 1);

        let acc = TsetlinModel::evaluate(&tm, &x, &y);
        assert!((0.0..=1.0).contains(&acc));

        let batch = TsetlinModel::predict_batch(&tm, &x);
        assert_eq!(batch.len(), 4);
    }

    #[test]
    fn trait_impl_voting_model() {
        use crate::model::VotingModel;

        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let tm = TsetlinMachine::new(config, 10);

        let votes = VotingModel::sum_votes(&tm, &vec![1, 0]);
        assert!(votes.is_finite());
        assert!((votes - tm.sum_votes(&[1, 0])).abs() < 0.001);
    }

    #[test]
    fn partial_fit_single_sample() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        tm.partial_fit(&[0, 1], 1, 42);
        tm.partial_fit(&[1, 0], 1, 43);
        tm.partial_fit(&[0, 0], 0, 44);
        tm.partial_fit(&[1, 1], 0, 45);

        let pred = tm.predict(&[0, 1]);
        assert!(pred == 0 || pred == 1);
    }

    #[test]
    fn partial_fit_batch_learning() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);
        let (x, y) = xor_data();

        for epoch in 0..100 {
            tm.partial_fit_batch(&x, &y, 42 + epoch, true);
        }

        assert!(tm.evaluate(&x, &y) >= 0.5);
    }

    #[test]
    fn partial_fit_empty_batch() {
        let config = Config::builder().clauses(10).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 5);

        tm.partial_fit_batch(&[], &[], 42, true);
        assert_eq!(tm.clauses().len(), 10);
    }

    #[test]
    fn partial_fit_streaming_simulation() {
        let config = Config::builder().clauses(20).features(2).build().unwrap();
        let mut tm = TsetlinMachine::new(config, 10);

        let samples = [
            (vec![0, 0], 0u8),
            (vec![0, 1], 1),
            (vec![1, 0], 1),
            (vec![1, 1], 0)
        ];

        for (seed, (x, y)) in samples.iter().cycle().take(400).enumerate() {
            tm.partial_fit(x, *y, seed as u64);
        }

        let x: Vec<Vec<u8>> = samples.iter().map(|(x, _)| x.clone()).collect();
        let y: Vec<u8> = samples.iter().map(|(_, y)| *y).collect();
        assert!(tm.evaluate(&x, &y) >= 0.5);
    }
}