Skip to main content

tsetlin_rs/
model.rs

1//! Unified trait for all Tsetlin Machine variants.
2
3#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5
/// Common API shared by every Tsetlin Machine flavour.
///
/// Binary, multi-class, regression and convolutional machines all implement
/// this trait, so callers can train, predict and score any of them through
/// the same set of methods.
///
/// # Type Parameters
///
/// * `X` - One input sample (typically `Vec<u8>` for binary features)
/// * `Y` - One label; its concrete type depends on the model (`u8`, `usize`, `i32`)
///
/// # Example
///
/// ```
/// use tsetlin_rs::{Config, TsetlinMachine, TsetlinModel};
///
/// let config = Config::builder().clauses(20).features(2).build().unwrap();
/// let mut tm = TsetlinMachine::new(config, 10);
///
/// let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
/// let y = vec![0, 1, 1, 0];
///
/// tm.fit(&x, &y, 100, 42);
/// let accuracy = tm.evaluate(&x, &y);
/// ```
pub trait TsetlinModel<X, Y> {
    /// Fits the model to labelled samples.
    ///
    /// # Arguments
    ///
    /// * `x` - Training samples
    /// * `y` - The label belonging to each sample in `x`
    /// * `epochs` - Number of training iterations to run
    /// * `seed` - Random seed, so that training runs are reproducible
    fn fit(&mut self, x: &[X], y: &[Y], epochs: usize, seed: u64);

    /// Returns the predicted label for a single sample.
    fn predict(&self, x: &X) -> Y;

    /// Scores the model (accuracy or a comparable metric) on labelled test data.
    ///
    /// The returned score lies between `0.0` and `1.0`; larger is better.
    fn evaluate(&self, x: &[X], y: &[Y]) -> f32;

    /// Predicts a label for every sample in `xs`, preserving input order.
    ///
    /// The default implementation calls [`predict`](Self::predict) once per
    /// sample; implementors may override it.
    fn predict_batch(&self, xs: &[X]) -> Vec<Y> {
        // Exactly one label per sample, so the final length is known up front.
        let mut labels = Vec::with_capacity(xs.len());
        for sample in xs {
            labels.push(self.predict(sample));
        }
        labels
    }
}
54
55/// Extension trait for models with vote-based predictions.
56pub trait VotingModel<X>: TsetlinModel<X, u8> {
57    /// Returns raw vote sum for input.
58    fn sum_votes(&self, x: &X) -> f32;
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared scoring helper for the mock models below: the fraction of
    /// `(x, y)` pairs for which `model.predict` matches the label.
    ///
    /// Pairs are formed with `zip`, so trailing items of the longer slice are
    /// ignored, while the denominator is `x.len()`. An empty `x` yields NaN;
    /// the tests never pass one.
    fn accuracy<M: TsetlinModel<u8, u8>>(model: &M, x: &[u8], y: &[u8]) -> f32 {
        let correct = x
            .iter()
            .zip(y)
            .filter(|(xi, yi)| model.predict(xi) == **yi)
            .count();
        correct as f32 / x.len() as f32
    }

    /// Deterministic mock that predicts the parity of its input (`x % 2`).
    struct MockModel;

    impl TsetlinModel<u8, u8> for MockModel {
        fn fit(&mut self, _x: &[u8], _y: &[u8], _epochs: usize, _seed: u64) {}

        fn predict(&self, x: &u8) -> u8 {
            *x % 2
        }

        fn evaluate(&self, x: &[u8], y: &[u8]) -> f32 {
            accuracy(self, x, y)
        }
    }

    #[test]
    fn predict_batch_default_impl() {
        let model = MockModel;
        let xs: [u8; 5] = [0, 1, 2, 3, 4];
        let preds = model.predict_batch(&xs);
        assert_eq!(preds, [0u8, 1, 0, 1, 0]);
    }

    #[test]
    fn mock_model_fit() {
        let mut model = MockModel;
        // fit is a no-op; this only checks that the call does not panic.
        model.fit(&[1, 2, 3], &[0, 1, 0], 10, 42);
    }

    #[test]
    fn mock_model_evaluate() {
        let model = MockModel;

        // predict(x) = x % 2, so every prediction matches its label:
        // x=0 -> 0 (y=0), x=1 -> 1 (y=1), x=2 -> 0 (y=0)
        let acc = model.evaluate(&[0, 1, 2], &[0, 1, 0]);
        assert!((acc - 1.0).abs() < 0.001);

        // 0% accuracy case: every label is the opposite of the input's parity.
        let acc_none = model.evaluate(&[0, 1, 2, 3], &[1, 0, 1, 0]);
        assert!((acc_none - 0.0).abs() < 0.001);

        // 50% accuracy case: x=0 and x=1 match, x=2 and x=3 do not.
        let acc_half = model.evaluate(&[0, 1, 2, 3], &[0, 1, 1, 0]);
        assert!((acc_half - 0.5).abs() < 0.001);
    }

    /// Mock whose prediction is driven by `sum_votes`: class 1 when the vote
    /// sum is non-negative, class 0 otherwise.
    struct MockVotingModel;

    impl TsetlinModel<u8, u8> for MockVotingModel {
        fn fit(&mut self, _x: &[u8], _y: &[u8], _epochs: usize, _seed: u64) {}

        fn predict(&self, x: &u8) -> u8 {
            if self.sum_votes(x) >= 0.0 { 1 } else { 0 }
        }

        fn evaluate(&self, x: &[u8], y: &[u8]) -> f32 {
            accuracy(self, x, y)
        }
    }

    impl VotingModel<u8> for MockVotingModel {
        fn sum_votes(&self, x: &u8) -> f32 {
            // Negative for x < 2, exactly zero at x == 2, positive for x > 2.
            f32::from(*x) - 2.0
        }
    }

    #[test]
    fn voting_model_sum_votes() {
        let model = MockVotingModel;

        assert!((model.sum_votes(&0) - (-2.0)).abs() < 0.001);
        assert!((model.sum_votes(&2) - 0.0).abs() < 0.001);
        assert!((model.sum_votes(&5) - 3.0).abs() < 0.001);
    }

    #[test]
    fn voting_model_predict_uses_votes() {
        let model = MockVotingModel;

        // x < 2: negative votes -> predict 0
        assert_eq!(model.predict(&0), 0);
        assert_eq!(model.predict(&1), 0);

        // x >= 2: non-negative votes (zero counts) -> predict 1
        assert_eq!(model.predict(&2), 1);
        assert_eq!(model.predict(&5), 1);
    }

    #[test]
    fn voting_model_evaluate() {
        let model = MockVotingModel;
        let xs: [u8; 4] = [0, 1, 2, 3];
        let ys: [u8; 4] = [0, 0, 1, 1];
        let acc = model.evaluate(&xs, &ys);
        assert!((acc - 1.0).abs() < 0.001);
    }
}