Skip to main content

tsetlin_rs/
config.rs

//! Configuration and builder for a Tsetlin Machine.
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
/// # Overview
///
/// Configuration parameters for a Tsetlin Machine.
///
/// Prefer constructing it through [`Config::builder`], which fills in
/// defaults and runs [`Config::validate`]. Building the struct literally (all
/// fields are public) or deserializing it with the `serde` feature does not
/// run validation; call [`Config::validate`] yourself in those cases.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[must_use]
pub struct Config {
    /// Number of clauses; must be non-zero and even.
    pub n_clauses:  usize,
    /// Number of input features; must be non-zero.
    pub n_features: usize,
    /// States per automaton action (the builder defaults this to 100).
    pub n_states:   i16,
    /// Specificity `s`; must be greater than `1.0`. It determines
    /// [`Config::prob_strengthen`] (`(s-1)/s`) and [`Config::prob_weaken`] (`1/s`).
    pub s:          f32
}
20
21impl Config {
22    /// # Overview
23    ///
24    /// Creates a new ConfigBuilder.
25    #[inline]
26    #[must_use]
27    pub fn builder() -> ConfigBuilder {
28        ConfigBuilder::default()
29    }
30
31    /// # Overview
32    ///
33    /// Validates configuration parameters.
34    pub fn validate(&self) -> Result<()> {
35        if self.n_clauses == 0 {
36            return Err(Error::MissingClauses);
37        }
38        if !self.n_clauses.is_multiple_of(2) {
39            return Err(Error::OddClauses);
40        }
41        if self.n_features == 0 {
42            return Err(Error::MissingFeatures);
43        }
44        if self.s <= 1.0 {
45            return Err(Error::InvalidSpecificity);
46        }
47        Ok(())
48    }
49
50    /// # Overview
51    ///
52    /// Pre-computed probability for strengthening: (s-1)/s.
53    #[inline]
54    #[must_use]
55    pub fn prob_strengthen(&self) -> f32 {
56        (self.s - 1.0) / self.s
57    }
58
59    /// # Overview
60    ///
61    /// Pre-computed probability for weakening: 1/s.
62    #[inline]
63    #[must_use]
64    pub fn prob_weaken(&self) -> f32 {
65        1.0 / self.s
66    }
67
68    /// # Overview
69    ///
70    /// Pre-computed integer threshold for strengthening.
71    ///
72    /// Converts float probability to u32 for faster comparison:
73    /// `rng.random::<u32>() < threshold` is ~2x faster than float comparison.
74    #[inline]
75    #[must_use]
76    pub fn threshold_strengthen(&self) -> u32 {
77        prob_to_threshold(self.prob_strengthen())
78    }
79
80    /// # Overview
81    ///
82    /// Pre-computed integer threshold for weakening.
83    #[inline]
84    #[must_use]
85    pub fn threshold_weaken(&self) -> u32 {
86        prob_to_threshold(self.prob_weaken())
87    }
88}
89
/// Maps a probability in `[0.0, 1.0]` onto the `u32` range for integer
/// comparison.
///
/// `rng.random::<u32>() < threshold` then fires with (approximately) the same
/// probability as `rng.random::<f32>() < prob`. The comparison needs no
/// floating-point work per draw.
///
/// The final float-to-integer cast saturates. `NaN` and values below `0.0`
/// yield `0`; values above `1.0` yield `u32::MAX`.
#[inline]
#[must_use]
pub fn prob_to_threshold(prob: f32) -> u32 {
    let full_scale = f64::from(u32::MAX);
    let scaled = f64::from(prob) * full_scale;
    scaled as u32
}
99
/// # Overview
///
/// Builder for Config with validation.
///
/// Every field starts as `None` (via the derived `Default`).
/// [`ConfigBuilder::build`] requires clauses and features to have been set.
/// It supplies defaults for the remaining fields and then runs
/// [`Config::validate`].
#[derive(Debug, Default)]
pub struct ConfigBuilder {
    // Each field stays `None` until its matching setter is called.
    n_clauses:  Option<usize>,
    n_features: Option<usize>,
    n_states:   Option<i16>,
    s:          Option<f32>
}
110
111impl ConfigBuilder {
112    /// # Overview
113    ///
114    /// Sets the number of clauses (must be even).
115    pub fn clauses(mut self, n: usize) -> Self {
116        self.n_clauses = Some(n);
117        self
118    }
119
120    /// # Overview
121    ///
122    /// Sets the number of input features.
123    pub fn features(mut self, n: usize) -> Self {
124        self.n_features = Some(n);
125        self
126    }
127
128    /// # Overview
129    ///
130    /// Sets states per automaton action (default: 100).
131    pub fn states(mut self, n: i16) -> Self {
132        self.n_states = Some(n);
133        self
134    }
135
136    /// # Overview
137    ///
138    /// Sets specificity parameter s (default: 3.9).
139    pub fn specificity(mut self, s: f32) -> Self {
140        self.s = Some(s);
141        self
142    }
143
144    /// # Overview
145    ///
146    /// Builds and validates the Config.
147    pub fn build(self) -> Result<Config> {
148        let config = Config {
149            n_clauses:  self.n_clauses.ok_or(Error::MissingClauses)?,
150            n_features: self.n_features.ok_or(Error::MissingFeatures)?,
151            n_states:   self.n_states.unwrap_or(100),
152            s:          self.s.unwrap_or(3.9)
153        };
154        config.validate()?;
155        Ok(config)
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_with_defaults() {
        let config = Config::builder().clauses(20).features(4).build().unwrap();

        assert_eq!(config.n_clauses, 20);
        assert_eq!(config.n_features, 4);
        assert_eq!(config.n_states, 100);
        assert!((config.s - 3.9).abs() < 0.01);
    }

    #[test]
    fn builder_rejects_odd_clauses() {
        let result = Config::builder().clauses(21).features(4).build();

        assert_eq!(result, Err(Error::OddClauses));
    }

    #[test]
    fn builder_requires_clauses() {
        let result = Config::builder().features(4).build();

        assert_eq!(result, Err(Error::MissingClauses));
    }

    #[test]
    fn builder_requires_features() {
        let result = Config::builder().clauses(20).build();

        assert_eq!(result, Err(Error::MissingFeatures));
    }

    #[test]
    fn builder_rejects_zero_sizes() {
        assert_eq!(
            Config::builder().clauses(0).features(4).build(),
            Err(Error::MissingClauses)
        );
        assert_eq!(
            Config::builder().clauses(20).features(0).build(),
            Err(Error::MissingFeatures)
        );
    }

    #[test]
    fn builder_rejects_specificity_at_or_below_one() {
        for s in [1.0, 0.5, 0.0, -3.0] {
            let result = Config::builder()
                .clauses(20)
                .features(4)
                .specificity(s)
                .build();

            assert_eq!(result, Err(Error::InvalidSpecificity), "s = {s}");
        }
    }

    #[test]
    fn prob_precomputed() {
        let config = Config::builder()
            .clauses(20)
            .features(4)
            .specificity(4.0)
            .build()
            .unwrap();

        assert!((config.prob_strengthen() - 0.75).abs() < 0.001);
        assert!((config.prob_weaken() - 0.25).abs() < 0.001);
    }

    #[test]
    fn integer_thresholds() {
        let config = Config::builder()
            .clauses(20)
            .features(4)
            .specificity(4.0)
            .build()
            .unwrap();

        // 0.75 * u32::MAX ≈ 3221225471
        assert!(config.threshold_strengthen() > 3_000_000_000);
        assert!(config.threshold_strengthen() < 3_500_000_000);

        // 0.25 * u32::MAX ≈ 1073741823
        assert!(config.threshold_weaken() > 1_000_000_000);
        assert!(config.threshold_weaken() < 1_200_000_000);
    }

    #[test]
    fn prob_to_threshold_boundaries() {
        assert_eq!(prob_to_threshold(0.0), 0);
        assert_eq!(prob_to_threshold(1.0), u32::MAX);
        assert_eq!(prob_to_threshold(0.5), u32::MAX / 2);
    }

    #[test]
    fn prob_to_threshold_saturates_out_of_range() {
        // Float-to-int `as` casts saturate, and NaN maps to 0.
        assert_eq!(prob_to_threshold(-0.5), 0);
        assert_eq!(prob_to_threshold(2.0), u32::MAX);
        assert_eq!(prob_to_threshold(f32::NAN), 0);
    }
}