Skip to main content

tsetlin_rs/
bitwise.rs

//! Bitwise clause evaluation over packed binary features.
//!
//! Inputs and clause literals are packed into `u64` bitmasks, so each
//! word-level AND/OR checks 64 features at once instead of one at a time.
4
5#[cfg(not(feature = "std"))]
6use alloc::{vec, vec::Vec};
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::Automaton;
12
13/// # Overview
14///
15/// Bitwise clause using packed u64 bitmasks.
16///
17/// Processes 64 features per AND operation - up to 50x faster than scalar.
18#[derive(Debug, Clone)]
19#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
20#[repr(align(64))]
21pub struct BitwiseClause {
22    automata:   Vec<Automaton>,
23    include:    Vec<u64>,
24    negated:    Vec<u64>,
25    polarity:   i8,
26    n_features: usize,
27    dirty:      bool
28}
29
30impl BitwiseClause {
31    /// # Overview
32    ///
33    /// Creates clause with given features, states, and polarity.
34    #[must_use]
35    pub fn new(n_features: usize, n_states: i16, polarity: i8) -> Self {
36        debug_assert!(polarity == 1 || polarity == -1);
37        let n_words = n_features.div_ceil(64);
38        let automata = (0..2 * n_features)
39            .map(|_| Automaton::new(n_states))
40            .collect();
41
42        Self {
43            automata,
44            include: vec![0; n_words],
45            negated: vec![0; n_words],
46            polarity,
47            n_features,
48            dirty: true
49        }
50    }
51
52    #[inline(always)]
53    #[must_use]
54    pub const fn polarity(&self) -> i8 {
55        self.polarity
56    }
57
58    #[inline(always)]
59    #[must_use]
60    pub const fn n_features(&self) -> usize {
61        self.n_features
62    }
63
64    #[inline(always)]
65    #[must_use]
66    pub fn automata(&self) -> &[Automaton] {
67        &self.automata
68    }
69
70    #[inline(always)]
71    pub fn automata_mut(&mut self) -> &mut [Automaton] {
72        self.dirty = true;
73        &mut self.automata
74    }
75
76    /// # Overview
77    ///
78    /// Rebuilds bitmasks from automaton states. Call after training.
79    pub fn rebuild_masks(&mut self) {
80        if !self.dirty {
81            return;
82        }
83
84        for word in &mut self.include {
85            *word = 0;
86        }
87        for word in &mut self.negated {
88            *word = 0;
89        }
90
91        for k in 0..self.n_features {
92            let word_idx = k / 64;
93            let bit_idx = k % 64;
94
95            if self.automata[2 * k].action() {
96                self.include[word_idx] |= 1u64 << bit_idx;
97            }
98            if self.automata[2 * k + 1].action() {
99                self.negated[word_idx] |= 1u64 << bit_idx;
100            }
101        }
102
103        self.dirty = false;
104    }
105
106    /// Evaluates clause using bitwise AND operations.
107    ///
108    /// Processes 64 features per CPU instruction for massive speedup.
109    ///
110    /// # Arguments
111    ///
112    /// * `x_packed` - Input packed as u64 words via [`pack_input`]
113    ///
114    /// # Panics
115    ///
116    /// Debug-asserts that `rebuild_masks()` was called after training.
117    #[inline]
118    #[must_use]
119    pub fn evaluate_packed(&self, x_packed: &[u64]) -> bool {
120        debug_assert!(!self.dirty, "call rebuild_masks() first");
121
122        let n_words = self.include.len().min(x_packed.len());
123
124        for i in 0..n_words {
125            // SAFETY: `i < n_words <= min(include.len(), x_packed.len())`.
126            // Therefore all three accesses are within bounds.
127            let x = unsafe { *x_packed.get_unchecked(i) };
128            let inc = unsafe { *self.include.get_unchecked(i) };
129            let neg = unsafe { *self.negated.get_unchecked(i) };
130
131            // include violation: inc & !x != 0 (required bit is 0)
132            // negated violation: neg & x != 0 (forbidden bit is 1)
133            if (inc & !x) | (neg & x) != 0 {
134                return false;
135            }
136        }
137        true
138    }
139
140    /// Returns polarity if fires, 0 otherwise.
141    #[inline(always)]
142    #[must_use]
143    pub fn vote_packed(&self, x_packed: &[u64]) -> i32 {
144        if self.evaluate_packed(x_packed) {
145            self.polarity as i32
146        } else {
147            0
148        }
149    }
150
151    /// Fallback scalar evaluation (no packing needed).
152    ///
153    /// Slower than `evaluate_packed` but works with unpacked input.
154    #[inline]
155    #[must_use]
156    pub fn evaluate(&self, x: &[u8]) -> bool {
157        let n = self.n_features.min(x.len());
158
159        for k in 0..n {
160            // SAFETY: `k < n <= self.n_features`, and `automata.len() == 2 * n_features`.
161            // Therefore `2 * k + 1 < automata.len()`.
162            let include = unsafe { self.automata.get_unchecked(2 * k).action() };
163            let negated = unsafe { self.automata.get_unchecked(2 * k + 1).action() };
164
165            // SAFETY: `k < n <= x.len()`, so `k` is in bounds.
166            let xk = unsafe { *x.get_unchecked(k) };
167
168            if include && xk == 0 {
169                return false;
170            }
171            if negated && xk == 1 {
172                return false;
173            }
174        }
175        true
176    }
177}
178
/// # Overview
///
/// Converts a slice of binary features into `u64` words, 64 features per word.
///
/// Feature `k` lands in bit `k % 64` of word `k / 64`, and any non-zero byte
/// counts as a set bit. The result holds `ceil(x.len() / 64)` words; the
/// unused high bits of the last word stay zero.
#[inline]
#[must_use]
pub fn pack_input(x: &[u8]) -> Vec<u64> {
    x.chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .filter(|&(_, &byte)| byte != 0)
                .fold(0u64, |word, (bit, _)| word | (1u64 << bit))
        })
        .collect()
}
196
197/// # Overview
198///
199/// Packs multiple inputs for batch processing.
200#[inline]
201#[must_use]
202pub fn pack_batch(xs: &[Vec<u8>]) -> Vec<Vec<u64>> {
203    xs.iter().map(|x| pack_input(x)).collect()
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pack_input_basic() {
        let x = vec![1, 0, 1, 1, 0, 0, 0, 1];
        let packed = pack_input(&x);

        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0], 0b10001101); // bits 0,2,3,7 set
    }

    #[test]
    fn pack_input_empty() {
        // Zero features -> zero words.
        assert!(pack_input(&[]).is_empty());
    }

    #[test]
    fn bitwise_evaluate_empty() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks();

        let x_packed = vec![0xFFFF_FFFF_FFFF_FFFFu64];
        assert!(c.evaluate_packed(&x_packed));
    }

    #[test]
    fn bitwise_evaluate_violation() {
        let mut c = BitwiseClause::new(64, 100, 1);

        // Force include[0] to be active
        for _ in 0..200 {
            c.automata_mut()[0].increment();
        }
        c.rebuild_masks();

        // x[0] = 0, should violate
        let x_packed = vec![0u64];
        assert!(!c.evaluate_packed(&x_packed));

        // x[0] = 1, should pass
        let x_packed = vec![1u64];
        assert!(c.evaluate_packed(&x_packed));
    }

    #[test]
    fn bitwise_clause_accessors() {
        let c = BitwiseClause::new(128, 100, -1);

        assert_eq!(c.polarity(), -1);
        assert_eq!(c.n_features(), 128);
        assert_eq!(c.automata().len(), 256); // 2 * n_features
    }

    #[test]
    fn bitwise_automata_mut_sets_dirty() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks();
        assert!(!c.dirty);
        // Empty clause fires on all-zero input.
        assert!(c.evaluate_packed(&[0u64]));

        // Mere access through automata_mut must mark the masks stale.
        let _ = c.automata_mut();
        assert!(c.dirty);

        // A mutation followed by rebuild must be reflected in the masks; if the
        // dirty flag were not set, rebuild would early-exit and keep the old mask.
        for _ in 0..200 {
            c.automata_mut()[0].increment();
        }
        c.rebuild_masks();
        assert!(!c.dirty);
        assert!(!c.evaluate_packed(&[0u64]));
        assert!(c.evaluate_packed(&[1u64]));
    }

    #[test]
    fn bitwise_vote_packed() {
        let mut c = BitwiseClause::new(64, 100, 1);
        c.rebuild_masks(); // Empty clause fires

        // Empty clause fires -> returns polarity (1)
        assert_eq!(c.vote_packed(&[0u64]), 1);

        // Force include[0] active
        for _ in 0..200 {
            c.automata_mut()[0].increment();
        }
        c.rebuild_masks();

        // Violation: x[0]=0 -> returns 0
        assert_eq!(c.vote_packed(&[0u64]), 0);

        // Fires: x[0]=1 -> returns polarity (1)
        assert_eq!(c.vote_packed(&[1u64]), 1);
    }

    #[test]
    fn bitwise_vote_packed_negative_polarity() {
        let mut c = BitwiseClause::new(64, 100, -1);
        c.rebuild_masks();

        // Empty clause fires -> returns polarity (-1)
        assert_eq!(c.vote_packed(&[0u64]), -1);
    }

    #[test]
    fn bitwise_evaluate_scalar() {
        let mut c = BitwiseClause::new(4, 100, 1);

        // Force include[0] and negated[2] active
        for _ in 0..200 {
            c.automata_mut()[0].increment(); // include[0]
            c.automata_mut()[5].increment(); // negated[2]
        }

        // x[0]=1, x[2]=0 -> should fire
        assert!(c.evaluate(&[1, 0, 0, 0]));

        // x[0]=0 -> include violation
        assert!(!c.evaluate(&[0, 0, 0, 0]));

        // x[2]=1 -> negated violation
        assert!(!c.evaluate(&[1, 0, 1, 0]));
    }

    #[test]
    fn bitwise_evaluate_scalar_empty() {
        let c = BitwiseClause::new(4, 100, 1);
        // Empty clause always fires
        assert!(c.evaluate(&[0, 0, 0, 0]));
        assert!(c.evaluate(&[1, 1, 1, 1]));
    }

    #[test]
    fn packed_and_scalar_agree_on_binary_input() {
        // 70 features spans two words, so the word boundary is exercised.
        let mut c = BitwiseClause::new(70, 100, 1);
        for _ in 0..200 {
            c.automata_mut()[2 * 3].increment(); // include feature 3
            c.automata_mut()[2 * 66 + 1].increment(); // negate feature 66
        }
        c.rebuild_masks();

        let mut inputs = Vec::new();
        for &(f3, f66) in &[(0u8, 0u8), (0, 1), (1, 0), (1, 1)] {
            let mut x = vec![0u8; 70];
            x[3] = f3;
            x[66] = f66;
            inputs.push(x);
        }
        inputs.push(vec![1u8; 70]);

        for x in &inputs {
            assert_eq!(c.evaluate(x), c.evaluate_packed(&pack_input(x)));
        }
    }

    #[test]
    fn pack_batch_multiple() {
        let xs = vec![vec![1, 0, 0, 0], vec![0, 1, 0, 0], vec![1, 1, 0, 0]];
        let packed = pack_batch(&xs);

        assert_eq!(packed.len(), 3);
        assert_eq!(packed[0][0], 0b0001); // bit 0 set
        assert_eq!(packed[1][0], 0b0010); // bit 1 set
        assert_eq!(packed[2][0], 0b0011); // bits 0,1 set
    }

    #[test]
    fn pack_input_large() {
        // 128 features = 2 words
        let mut x = vec![0u8; 128];
        x[0] = 1;
        x[63] = 1;
        x[64] = 1;
        x[127] = 1;

        let packed = pack_input(&x);
        assert_eq!(packed.len(), 2);
        assert_eq!(packed[0], 1u64 | (1u64 << 63)); // bits 0, 63
        assert_eq!(packed[1], 1u64 | (1u64 << 63)); // bits 64, 127 (relative to word)
    }

    #[test]
    fn bitwise_negated_violation() {
        let mut c = BitwiseClause::new(64, 100, 1);

        // Force negated[0] active
        for _ in 0..200 {
            c.automata_mut()[1].increment();
        }
        c.rebuild_masks();

        // x[0]=1 -> negated violation
        assert!(!c.evaluate_packed(&[1u64]));

        // x[0]=0 -> fires
        assert!(c.evaluate_packed(&[0u64]));
    }

    #[test]
    fn bitwise_multi_word() {
        let mut c = BitwiseClause::new(128, 100, 1);

        // Force include[64] active (second word)
        for _ in 0..200 {
            c.automata_mut()[128].increment(); // automata[2*64] = include for feature 64
        }
        c.rebuild_masks();

        // Need bit 64 set (which is bit 0 of word 1)
        let x_packed = vec![0u64, 1u64];
        assert!(c.evaluate_packed(&x_packed));

        // Without bit 64 -> violation
        let x_packed = vec![0u64, 0u64];
        assert!(!c.evaluate_packed(&x_packed));
    }
}
384        // Without bit 64 -> violation
385        let x_packed = vec![0u64, 0u64];
386        assert!(!c.evaluate_packed(&x_packed));
387    }
388}