Skip to main content

tsetlin_rs/
sparse.rs

1//! Sparse clause representation for memory-efficient inference.
2//!
3//! This module provides [`SparseClause`] and [`SparseClauseBank`] for
4//! 5-100x memory reduction in trained Tsetlin Machine models.
5//!
6//! # Background
7//!
8//! After training, most automata in a clause are in the "exclude" state
9//! (action = false). A typical trained clause contains only 10-50 active
10//! literals out of 1000+ possible features. Storing all automata wastes
11//! 95-99% of memory.
12//!
13//! # Architecture
14//!
15//! This module implements sparse representations based on research:
16//! - [Sparse TM with Active Literals (arXiv:2405.02375)](https://arxiv.org/abs/2405.02375)
17//! - [Contracting TM with Absorbing Automata (arXiv:2310.11481)](https://arxiv.org/abs/2310.11481)
18//!
19//! ```text
20//! Training:  ClauseBank (dense SoA)
21//!                │
22//!                ▼ .to_sparse()
23//!                │
24//! Inference: SparseClauseBank (CSR)  ──► 5-100x memory reduction
25//! ```
26//!
27//! # Data Structures
28//!
29//! | Type | Format | Use Case |
30//! |------|--------|----------|
31//! | [`SparseClause`] | SmallVec | Single clause, inline for ≤32 literals |
32//! | [`SparseClauseBank`] | CSR | Batch inference, cache-friendly |
33//!
34//! # Example
35//!
36//! ```
37//! use tsetlin_rs::{ClauseBank, SparseClauseBank};
38//!
39//! // Train with dense representation
40//! let bank = ClauseBank::new(100, 1000, 100);
41//! // ... training ...
42//!
43//! // Convert to sparse for deployment
44//! let sparse = SparseClauseBank::from_clause_bank(&bank);
45//!
46//! // Memory reduction
47//! let stats = sparse.memory_stats();
48//! println!("Compression: {}x", stats.compression_ratio(1000));
49//! ```
50
51#[cfg(not(feature = "std"))]
52use alloc::{vec, vec::Vec};
53
54#[cfg(feature = "serde")]
55use serde::{Deserialize, Serialize};
56use smallvec::SmallVec;
57
58use crate::{Clause, ClauseBank};
59
/// Number of literal indices each `SmallVec` inside [`SparseClause`] keeps
/// inline (32 literals × 2 bytes per `u16` index = 64 bytes per vector).
///
/// Research shows trained clauses typically contain 10-50 active literals.
/// SmallVec avoids heap allocation for this common case; a vector holding
/// more than 32 indices spills to the heap.
const INLINE_CAPACITY: usize = 32;
65
/// Sparse clause representation storing only active literal indices.
///
/// Achieves 5-100x memory reduction compared to dense [`Clause`] by storing
/// only indices of features that affect evaluation.
///
/// # Memory Layout
///
/// Typical clause with 20 active literals:
/// - `include_indices`: 32 × 2 = 64 bytes (inline SmallVec)
/// - `negated_indices`: 32 × 2 = 64 bytes (inline SmallVec)
/// - `weight`: 4 bytes
/// - `polarity`: 1 byte
/// - **Total: ~133 bytes** of payload vs ~8000 bytes for dense (1000 features)
///
/// The figure above sums the payload only; `memory_usage()` reports
/// `size_of::<Self>()` plus any spilled heap capacity.
///
/// # Performance
///
/// Uses early-exit evaluation: returns `false` on first violation.
/// For sparse input data, this is often faster than dense bitmask evaluation.
///
/// # Example
///
/// ```
/// use tsetlin_rs::{Clause, SparseClause};
///
/// let mut clause = Clause::new(100, 100, 1);
/// // ... train clause ...
///
/// let sparse = SparseClause::from_clause(&clause);
/// assert!(sparse.memory_usage() < 200); // vs 800+ bytes dense
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SparseClause {
    /// Indices of features where `x[k] = 1` is required (include literals).
    /// Stored inline for up to `INLINE_CAPACITY` entries.
    include_indices: SmallVec<[u16; INLINE_CAPACITY]>,

    /// Indices of features where `x[k] = 0` is required (negated literals).
    /// Stored inline for up to `INLINE_CAPACITY` entries.
    negated_indices: SmallVec<[u16; INLINE_CAPACITY]>,

    /// Clause weight for weighted voting; multiplied by `polarity` in `vote`.
    weight: f32,

    /// Voting polarity (+1 or -1).
    polarity: i8
}
111
112impl SparseClause {
113    /// Creates a sparse clause from a dense [`Clause`] by extracting active
114    /// literals.
115    ///
116    /// Scans automata and records indices where `action() == true`.
117    ///
118    /// # Arguments
119    ///
120    /// * `clause` - Dense clause to convert
121    ///
122    /// # Example
123    ///
124    /// ```
125    /// use tsetlin_rs::{Clause, SparseClause};
126    ///
127    /// let clause = Clause::new(100, 100, 1);
128    /// let sparse = SparseClause::from_clause(&clause);
129    /// assert_eq!(sparse.n_literals(), 0); // Fresh clause has no active literals
130    /// ```
131    #[must_use]
132    pub fn from_clause(clause: &Clause) -> Self {
133        let mut include = SmallVec::new();
134        let mut negated = SmallVec::new();
135
136        for (k, pair) in clause.automata().chunks(2).enumerate() {
137            if pair[0].action() {
138                include.push(k as u16);
139            }
140            if pair[1].action() {
141                negated.push(k as u16);
142            }
143        }
144
145        Self {
146            include_indices: include,
147            negated_indices: negated,
148            weight:          clause.weight(),
149            polarity:        clause.polarity()
150        }
151    }
152
153    /// Creates a sparse clause from raw components.
154    ///
155    /// # Arguments
156    ///
157    /// * `include` - Feature indices requiring `x[k] = 1`
158    /// * `negated` - Feature indices requiring `x[k] = 0`
159    /// * `weight` - Clause weight
160    /// * `polarity` - Vote direction (+1 or -1)
161    #[must_use]
162    pub fn new(include: &[u16], negated: &[u16], weight: f32, polarity: i8) -> Self {
163        Self {
164            include_indices: SmallVec::from_slice(include),
165            negated_indices: SmallVec::from_slice(negated),
166            weight,
167            polarity
168        }
169    }
170
171    /// Returns clause polarity (+1 or -1).
172    #[inline(always)]
173    #[must_use]
174    pub const fn polarity(&self) -> i8 {
175        self.polarity
176    }
177
178    /// Returns clause weight.
179    #[inline(always)]
180    #[must_use]
181    pub const fn weight(&self) -> f32 {
182        self.weight
183    }
184
185    /// Returns indices of include literals.
186    #[inline(always)]
187    #[must_use]
188    pub fn include_indices(&self) -> &[u16] {
189        &self.include_indices
190    }
191
192    /// Returns indices of negated literals.
193    #[inline(always)]
194    #[must_use]
195    pub fn negated_indices(&self) -> &[u16] {
196        &self.negated_indices
197    }
198
199    /// Evaluates clause with early exit on first violation.
200    ///
201    /// Returns `true` if all conditions are satisfied:
202    /// - For each index in `include_indices`: `x[idx] == 1`
203    /// - For each index in `negated_indices`: `x[idx] == 0`
204    ///
205    /// # Arguments
206    ///
207    /// * `x` - Binary input vector
208    ///
209    /// # Safety
210    ///
211    /// Uses unchecked indexing for performance. Caller must ensure all
212    /// stored indices are within `x.len()`.
213    #[inline]
214    #[must_use]
215    pub fn evaluate(&self, x: &[u8]) -> bool {
216        for &idx in &self.include_indices {
217            // SAFETY: caller ensures idx < x.len()
218            if unsafe { *x.get_unchecked(idx as usize) } == 0 {
219                return false;
220            }
221        }
222        for &idx in &self.negated_indices {
223            // SAFETY: caller ensures idx < x.len()
224            if unsafe { *x.get_unchecked(idx as usize) } == 1 {
225                return false;
226            }
227        }
228        true
229    }
230
231    /// Evaluates clause with bounds checking.
232    ///
233    /// Safe version of [`evaluate`](Self::evaluate) that performs bounds
234    /// checks.
235    #[inline]
236    #[must_use]
237    pub fn evaluate_checked(&self, x: &[u8]) -> bool {
238        for &idx in &self.include_indices {
239            if x.get(idx as usize).copied().unwrap_or(0) == 0 {
240                return false;
241            }
242        }
243        for &idx in &self.negated_indices {
244            if x.get(idx as usize).copied().unwrap_or(0) == 1 {
245                return false;
246            }
247        }
248        true
249    }
250
251    /// Evaluates using packed u64 input (64 features per word).
252    ///
253    /// Optimized for cases where input is already packed. Uses bit extraction
254    /// instead of array indexing.
255    ///
256    /// # Arguments
257    ///
258    /// * `x` - Binary input packed as u64 words
259    #[inline]
260    #[must_use]
261    pub fn evaluate_packed(&self, x: &[u64]) -> bool {
262        for &idx in &self.include_indices {
263            let word = idx as usize >> 6; // / 64
264            let bit = idx as usize & 63; // % 64
265            // SAFETY: caller ensures sufficient words
266            if unsafe { *x.get_unchecked(word) } & (1u64 << bit) == 0 {
267                return false;
268            }
269        }
270        for &idx in &self.negated_indices {
271            let word = idx as usize >> 6;
272            let bit = idx as usize & 63;
273            if unsafe { *x.get_unchecked(word) } & (1u64 << bit) != 0 {
274                return false;
275            }
276        }
277        true
278    }
279
280    /// Returns weighted vote: `polarity × weight` if clause fires, `0.0`
281    /// otherwise.
282    #[inline]
283    #[must_use]
284    pub fn vote(&self, x: &[u8]) -> f32 {
285        if self.evaluate(x) {
286            self.polarity as f32 * self.weight
287        } else {
288            0.0
289        }
290    }
291
292    /// Returns unweighted vote: `polarity` if clause fires, `0` otherwise.
293    #[inline]
294    #[must_use]
295    pub fn vote_unweighted(&self, x: &[u8]) -> i32 {
296        if self.evaluate(x) {
297            self.polarity as i32
298        } else {
299            0
300        }
301    }
302
303    /// Returns approximate memory usage in bytes.
304    ///
305    /// Accounts for SmallVec inline vs heap allocation.
306    #[must_use]
307    pub fn memory_usage(&self) -> usize {
308        let base = core::mem::size_of::<Self>();
309        let include_heap = if self.include_indices.spilled() {
310            self.include_indices.capacity() * 2
311        } else {
312            0
313        };
314        let negated_heap = if self.negated_indices.spilled() {
315            self.negated_indices.capacity() * 2
316        } else {
317            0
318        };
319        base + include_heap + negated_heap
320    }
321
322    /// Returns total number of active literals.
323    #[inline]
324    #[must_use]
325    pub fn n_literals(&self) -> usize {
326        self.include_indices.len() + self.negated_indices.len()
327    }
328}
329
/// Sparse clause bank using CSR (Compressed Sparse Row) format.
///
/// Optimized for batch inference with cache-friendly memory access.
/// Stores all active literals in contiguous arrays with offset indexing.
///
/// # Memory Layout
///
/// ```text
/// Dense ClauseBank (100 clauses, 1000 features):
///   states: 100 × 2 × 1000 × 2 = 400,000 bytes
///
/// SparseClauseBank (avg 30 literals/clause):
///   include_indices: ~1500 × 2 = 3,000 bytes
///   include_offsets: 101 × 4 = 404 bytes
///   negated_indices: ~1500 × 2 = 3,000 bytes
///   negated_offsets: 101 × 4 = 404 bytes
///   weights: 100 × 4 = 400 bytes
///   polarities: 100 bytes
///   Total: ~7,300 bytes
///
/// Memory reduction: 55x
/// ```
///
/// # CSR Format
///
/// For clause `c`, active include literals are at:
/// `include_indices[include_offsets[c]..include_offsets[c+1]]`
///
/// Negated literals use `negated_indices` / `negated_offsets` in the same way.
/// Each offsets array therefore holds `n_clauses + 1` entries, starting at 0.
///
/// This format provides:
/// - O(1) clause lookup
/// - Sequential memory access (cache-friendly)
/// - Minimal overhead per clause (4 bytes offset)
///
/// # Example
///
/// ```
/// use tsetlin_rs::{ClauseBank, SparseClauseBank};
///
/// let bank = ClauseBank::new(100, 1000, 100);
/// let sparse = SparseClauseBank::from_clause_bank(&bank);
///
/// let stats = sparse.memory_stats();
/// println!("Total literals: {}", stats.total_literals);
/// println!("Memory: {} bytes", stats.total());
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SparseClauseBank {
    /// CSR data: active include literal indices for all clauses, concatenated
    /// in clause order.
    include_indices: Vec<u16>,

    /// CSR offsets: `include_indices[offsets[c]..offsets[c+1]]` for clause `c`.
    include_offsets: Vec<u32>,

    /// CSR data: active negated literal indices for all clauses, concatenated
    /// in clause order.
    negated_indices: Vec<u16>,

    /// CSR offsets for negated literals (same scheme as `include_offsets`).
    negated_offsets: Vec<u32>,

    /// Clause weights, one per clause.
    weights: Vec<f32>,

    /// Clause polarities (+1 or -1), one per clause.
    polarities: Vec<i8>,

    /// Number of clauses.
    n_clauses: usize,

    /// Number of features (for documentation/validation).
    n_features: usize
}
402
403impl SparseClauseBank {
404    /// Converts dense [`ClauseBank`] to sparse CSR format.
405    ///
406    /// Scans all clauses and extracts active literals (automata with
407    /// `state > n_states`).
408    ///
409    /// # Arguments
410    ///
411    /// * `bank` - Dense clause bank to convert
412    ///
413    /// # Example
414    ///
415    /// ```
416    /// use tsetlin_rs::{ClauseBank, SparseClauseBank};
417    ///
418    /// let bank = ClauseBank::new(100, 64, 100);
419    /// let sparse = SparseClauseBank::from_clause_bank(&bank);
420    /// assert_eq!(sparse.n_clauses(), 100);
421    /// ```
422    #[must_use]
423    pub fn from_clause_bank(bank: &ClauseBank) -> Self {
424        let mut include_indices = Vec::new();
425        let mut include_offsets = vec![0u32];
426        let mut negated_indices = Vec::new();
427        let mut negated_offsets = vec![0u32];
428
429        let threshold = bank.n_states();
430
431        for c in 0..bank.n_clauses() {
432            let states = bank.clause_states(c);
433
434            for (k, pair) in states.chunks(2).enumerate() {
435                if pair[0] > threshold {
436                    include_indices.push(k as u16);
437                }
438                if pair[1] > threshold {
439                    negated_indices.push(k as u16);
440                }
441            }
442
443            include_offsets.push(include_indices.len() as u32);
444            negated_offsets.push(negated_indices.len() as u32);
445        }
446
447        Self {
448            include_indices,
449            include_offsets,
450            negated_indices,
451            negated_offsets,
452            weights: bank.weights().to_vec(),
453            polarities: bank.polarities().to_vec(),
454            n_clauses: bank.n_clauses(),
455            n_features: bank.n_features()
456        }
457    }
458
459    /// Creates from a vector of [`SparseClause`].
460    ///
461    /// Useful when building sparse representation incrementally or from
462    /// non-ClauseBank sources.
463    #[must_use]
464    pub fn from_clauses(clauses: &[SparseClause], n_features: usize) -> Self {
465        let mut include_indices = Vec::new();
466        let mut include_offsets = vec![0u32];
467        let mut negated_indices = Vec::new();
468        let mut negated_offsets = vec![0u32];
469        let mut weights = Vec::with_capacity(clauses.len());
470        let mut polarities = Vec::with_capacity(clauses.len());
471
472        for clause in clauses {
473            include_indices.extend_from_slice(&clause.include_indices);
474            include_offsets.push(include_indices.len() as u32);
475
476            negated_indices.extend_from_slice(&clause.negated_indices);
477            negated_offsets.push(negated_indices.len() as u32);
478
479            weights.push(clause.weight);
480            polarities.push(clause.polarity);
481        }
482
483        Self {
484            include_indices,
485            include_offsets,
486            negated_indices,
487            negated_offsets,
488            weights,
489            polarities,
490            n_clauses: clauses.len(),
491            n_features
492        }
493    }
494
495    /// Returns number of clauses.
496    #[inline(always)]
497    #[must_use]
498    pub const fn n_clauses(&self) -> usize {
499        self.n_clauses
500    }
501
502    /// Returns number of features.
503    #[inline(always)]
504    #[must_use]
505    pub const fn n_features(&self) -> usize {
506        self.n_features
507    }
508
509    /// Returns clause weights.
510    #[inline(always)]
511    #[must_use]
512    pub fn weights(&self) -> &[f32] {
513        &self.weights
514    }
515
516    /// Returns clause polarities.
517    #[inline(always)]
518    #[must_use]
519    pub fn polarities(&self) -> &[i8] {
520        &self.polarities
521    }
522
523    /// Returns number of active literals for a specific clause.
524    #[inline]
525    #[must_use]
526    pub fn clause_n_literals(&self, clause: usize) -> usize {
527        let inc = self.include_offsets[clause + 1] - self.include_offsets[clause];
528        let neg = self.negated_offsets[clause + 1] - self.negated_offsets[clause];
529        (inc + neg) as usize
530    }
531
532    /// Evaluates single clause.
533    ///
534    /// Returns `true` if all active literal conditions are satisfied.
535    #[inline]
536    #[must_use]
537    pub fn evaluate_clause(&self, clause: usize, x: &[u8]) -> bool {
538        let inc_start = self.include_offsets[clause] as usize;
539        let inc_end = self.include_offsets[clause + 1] as usize;
540
541        for &idx in &self.include_indices[inc_start..inc_end] {
542            // SAFETY: indices were validated during construction
543            if unsafe { *x.get_unchecked(idx as usize) } == 0 {
544                return false;
545            }
546        }
547
548        let neg_start = self.negated_offsets[clause] as usize;
549        let neg_end = self.negated_offsets[clause + 1] as usize;
550
551        for &idx in &self.negated_indices[neg_start..neg_end] {
552            if unsafe { *x.get_unchecked(idx as usize) } == 1 {
553                return false;
554            }
555        }
556
557        true
558    }
559
560    /// Evaluates clause with packed u64 input.
561    ///
562    /// Optimized for pre-packed input where 64 features fit in one u64.
563    #[inline]
564    #[must_use]
565    pub fn evaluate_clause_packed(&self, clause: usize, x: &[u64]) -> bool {
566        let inc_start = self.include_offsets[clause] as usize;
567        let inc_end = self.include_offsets[clause + 1] as usize;
568
569        for &idx in &self.include_indices[inc_start..inc_end] {
570            let word = idx as usize >> 6;
571            let bit = idx as usize & 63;
572            if unsafe { *x.get_unchecked(word) } & (1u64 << bit) == 0 {
573                return false;
574            }
575        }
576
577        let neg_start = self.negated_offsets[clause] as usize;
578        let neg_end = self.negated_offsets[clause + 1] as usize;
579
580        for &idx in &self.negated_indices[neg_start..neg_end] {
581            let word = idx as usize >> 6;
582            let bit = idx as usize & 63;
583            if unsafe { *x.get_unchecked(word) } & (1u64 << bit) != 0 {
584                return false;
585            }
586        }
587
588        true
589    }
590
591    /// Sum of weighted votes for all clauses.
592    ///
593    /// Evaluates all clauses and accumulates `polarity × weight` for
594    /// firing clauses.
595    #[must_use]
596    pub fn sum_votes(&self, x: &[u8]) -> f32 {
597        let mut sum = 0.0f32;
598        for c in 0..self.n_clauses {
599            if self.evaluate_clause(c, x) {
600                // SAFETY: c < n_clauses
601                sum += unsafe {
602                    *self.polarities.get_unchecked(c) as f32 * *self.weights.get_unchecked(c)
603                };
604            }
605        }
606        sum
607    }
608
609    /// Sum of weighted votes with packed input.
610    #[must_use]
611    pub fn sum_votes_packed(&self, x: &[u64]) -> f32 {
612        let mut sum = 0.0f32;
613        for c in 0..self.n_clauses {
614            if self.evaluate_clause_packed(c, x) {
615                sum += unsafe {
616                    *self.polarities.get_unchecked(c) as f32 * *self.weights.get_unchecked(c)
617                };
618            }
619        }
620        sum
621    }
622
623    /// Sum of unweighted votes.
624    #[must_use]
625    pub fn sum_votes_unweighted(&self, x: &[u8]) -> i32 {
626        let mut sum = 0i32;
627        for c in 0..self.n_clauses {
628            if self.evaluate_clause(c, x) {
629                sum += self.polarities[c] as i32;
630            }
631        }
632        sum
633    }
634
635    /// Returns memory usage statistics.
636    #[must_use]
637    pub fn memory_stats(&self) -> SparseMemoryStats {
638        SparseMemoryStats {
639            include_data:    self.include_indices.len() * 2,
640            include_offsets: self.include_offsets.len() * 4,
641            negated_data:    self.negated_indices.len() * 2,
642            negated_offsets: self.negated_offsets.len() * 4,
643            weights:         self.weights.len() * 4,
644            polarities:      self.polarities.len(),
645            total_literals:  self.include_indices.len() + self.negated_indices.len(),
646            n_clauses:       self.n_clauses,
647            n_features:      self.n_features
648        }
649    }
650}
651
/// Byte-level memory breakdown of a sparse clause bank.
///
/// Every size field is expressed in bytes; the remaining fields carry the
/// counts needed to derive averages and ratios.
#[derive(Debug, Clone, Copy)]
pub struct SparseMemoryStats {
    /// Bytes used by the include-literal index array.
    pub include_data: usize,

    /// Bytes used by the include-literal offset array.
    pub include_offsets: usize,

    /// Bytes used by the negated-literal index array.
    pub negated_data: usize,

    /// Bytes used by the negated-literal offset array.
    pub negated_offsets: usize,

    /// Bytes used by the clause weights.
    pub weights: usize,

    /// Bytes used by the clause polarities.
    pub polarities: usize,

    /// Active literals summed over all clauses.
    pub total_literals: usize,

    /// Number of clauses.
    pub n_clauses: usize,

    /// Number of features.
    pub n_features: usize
}

impl SparseMemoryStats {
    /// Total bytes across all six arrays.
    #[must_use]
    pub const fn total(&self) -> usize {
        let include_bytes = self.include_data + self.include_offsets;
        let negated_bytes = self.negated_data + self.negated_offsets;
        include_bytes + negated_bytes + self.weights + self.polarities
    }

    /// Mean number of active literals per clause (`0.0` when there are no
    /// clauses).
    #[must_use]
    pub fn avg_literals_per_clause(&self) -> f32 {
        match self.n_clauses {
            0 => 0.0,
            clauses => self.total_literals as f32 / clauses as f32
        }
    }

    /// Ratio of dense storage size to sparse storage size.
    ///
    /// Dense storage is modelled as two 2-byte states per feature per clause.
    /// Returns `0.0` when the sparse total is zero.
    ///
    /// # Arguments
    ///
    /// * `n_features` - Number of features (for dense size calculation)
    #[must_use]
    pub fn compression_ratio(&self, n_features: usize) -> f32 {
        // 2 automata per feature, 2 bytes (i16) per automaton state.
        let dense_bytes = self.n_clauses * 2 * n_features * 2;
        let sparse_bytes = self.total();
        if sparse_bytes != 0 {
            dense_bytes as f32 / sparse_bytes as f32
        } else {
            0.0
        }
    }

    /// Fraction of all possible literals that are active
    /// (`total_literals / (n_clauses × 2 × n_features)`), or `0.0` when no
    /// literal is possible.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        match self.n_clauses * 2 * self.n_features {
            0 => 0.0,
            capacity => self.total_literals as f32 / capacity as f32
        }
    }
}
733
/// Sparse Tsetlin Machine for memory-efficient inference.
///
/// Inference-only model with 5-100x memory reduction compared to dense
/// `TsetlinMachine`. Create via `TsetlinMachine::to_sparse()` after training.
///
/// # Example
///
/// ```
/// use tsetlin_rs::{Config, TsetlinMachine};
///
/// let config = Config::builder().clauses(20).features(2).build().unwrap();
/// let mut tm = TsetlinMachine::new(config, 10);
///
/// let x = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
/// let y = vec![0, 1, 1, 0];
///
/// tm.fit(&x, &y, 200, 42);
///
/// // Convert to sparse for deployment
/// let sparse = tm.to_sparse();
///
/// // Same predictions, less memory
/// assert_eq!(tm.predict(&x[0]), sparse.predict(&x[0]));
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SparseTsetlinMachine {
    // Sparse (CSR) clause bank that produces the vote sum.
    clauses:   SparseClauseBank,
    // Threshold supplied at construction and returned by `threshold()`.
    threshold: f32
}
764
765impl SparseTsetlinMachine {
766    /// Creates from dense clauses.
767    ///
768    /// # Arguments
769    ///
770    /// * `clauses` - Dense clauses to convert
771    /// * `n_features` - Number of features
772    /// * `threshold` - Classification threshold
773    #[must_use]
774    pub fn from_clauses(clauses: &[Clause], n_features: usize, threshold: f32) -> Self {
775        let sparse_clauses: Vec<SparseClause> =
776            clauses.iter().map(SparseClause::from_clause).collect();
777
778        Self {
779            clauses: SparseClauseBank::from_clauses(&sparse_clauses, n_features),
780            threshold
781        }
782    }
783
784    /// Creates from pre-built sparse clause bank.
785    #[must_use]
786    pub fn new(clauses: SparseClauseBank, threshold: f32) -> Self {
787        Self {
788            clauses,
789            threshold
790        }
791    }
792
793    /// Returns number of clauses.
794    #[inline(always)]
795    #[must_use]
796    pub const fn n_clauses(&self) -> usize {
797        self.clauses.n_clauses()
798    }
799
800    /// Returns number of features.
801    #[inline(always)]
802    #[must_use]
803    pub const fn n_features(&self) -> usize {
804        self.clauses.n_features()
805    }
806
807    /// Returns threshold.
808    #[inline(always)]
809    #[must_use]
810    pub const fn threshold(&self) -> f32 {
811        self.threshold
812    }
813
814    /// Predicts class (0 or 1).
815    #[inline]
816    #[must_use]
817    pub fn predict(&self, x: &[u8]) -> u8 {
818        if self.clauses.sum_votes(x) >= 0.0 {
819            1
820        } else {
821            0
822        }
823    }
824
825    /// Predicts using packed u64 input.
826    #[inline]
827    #[must_use]
828    pub fn predict_packed(&self, x: &[u64]) -> u8 {
829        if self.clauses.sum_votes_packed(x) >= 0.0 {
830            1
831        } else {
832            0
833        }
834    }
835
836    /// Batch prediction.
837    #[must_use]
838    pub fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
839        xs.iter().map(|x| self.predict(x)).collect()
840    }
841
842    /// Evaluates accuracy on test data.
843    #[must_use]
844    pub fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
845        if x.is_empty() {
846            return 0.0;
847        }
848        let correct = x
849            .iter()
850            .zip(y)
851            .filter(|(xi, yi)| self.predict(xi) == **yi)
852            .count();
853        correct as f32 / x.len() as f32
854    }
855
856    /// Returns memory statistics.
857    #[must_use]
858    pub fn memory_stats(&self) -> SparseMemoryStats {
859        self.clauses.memory_stats()
860    }
861
862    /// Returns compression ratio compared to dense model.
863    #[must_use]
864    pub fn compression_ratio(&self) -> f32 {
865        self.clauses
866            .memory_stats()
867            .compression_ratio(self.n_features())
868    }
869}
870
871#[cfg(test)]
872mod tests {
873    use super::*;
874
875    #[test]
876    fn sparse_clause_from_dense() {
877        let mut clause = Clause::new(10, 100, 1);
878
879        // Activate literal 0 (include) and literal 2 (negated)
880        for _ in 0..200 {
881            clause.automata_mut()[0].increment(); // include[0]
882            clause.automata_mut()[5].increment(); // negated[2]
883        }
884
885        let sparse = SparseClause::from_clause(&clause);
886        assert_eq!(sparse.include_indices.len(), 1);
887        assert_eq!(sparse.negated_indices.len(), 1);
888        assert_eq!(sparse.include_indices[0], 0);
889        assert_eq!(sparse.negated_indices[0], 2);
890        assert_eq!(sparse.polarity(), 1);
891    }
892
893    #[test]
894    fn sparse_clause_evaluate() {
895        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
896
897        // x[0]=1, x[1]=0, x[2]=1 -> should fire
898        assert!(sparse.evaluate(&[1, 0, 1, 0]));
899
900        // x[0]=0 -> include violation
901        assert!(!sparse.evaluate(&[0, 0, 1, 0]));
902
903        // x[1]=1 -> negated violation
904        assert!(!sparse.evaluate(&[1, 1, 1, 0]));
905
906        // x[2]=0 -> include violation
907        assert!(!sparse.evaluate(&[1, 0, 0, 0]));
908    }
909
910    #[test]
911    fn sparse_clause_evaluate_packed() {
912        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
913
914        // Packed: bits 0,2 set, bit 1 clear -> 0b101 = 5
915        assert!(sparse.evaluate_packed(&[5u64]));
916
917        // Packed: bit 0 clear -> should fail
918        assert!(!sparse.evaluate_packed(&[4u64]));
919    }
920
921    #[test]
922    fn sparse_clause_vote() {
923        let sparse = SparseClause::new(&[], &[], 2.5, -1);
924
925        // Empty clause always fires
926        assert!((sparse.vote(&[0, 1, 0]) - (-2.5)).abs() < 0.001);
927        assert_eq!(sparse.vote_unweighted(&[0, 1, 0]), -1);
928    }
929
930    #[test]
931    fn sparse_clause_memory() {
932        let sparse = SparseClause::new(&[0, 1, 2], &[3, 4], 1.0, 1);
933
934        // Should be inline (not spilled)
935        let usage = sparse.memory_usage();
936        assert!(usage < 200);
937        assert_eq!(sparse.n_literals(), 5);
938    }
939
940    #[test]
941    fn sparse_bank_from_clause_bank() {
942        let bank = ClauseBank::new(10, 100, 100);
943        let sparse = SparseClauseBank::from_clause_bank(&bank);
944
945        assert_eq!(sparse.n_clauses(), 10);
946        assert_eq!(sparse.n_features(), 100);
947
948        // Fresh bank has no active literals
949        let stats = sparse.memory_stats();
950        assert_eq!(stats.total_literals, 0);
951    }
952
953    #[test]
954    fn sparse_bank_evaluate() {
955        // Create sparse bank manually
956        let clauses = vec![
957            SparseClause::new(&[0], &[], 1.0, 1),  // requires x[0]=1
958            SparseClause::new(&[], &[0], 1.0, -1), // requires x[0]=0
959        ];
960        let sparse = SparseClauseBank::from_clauses(&clauses, 4);
961
962        // x[0]=1: clause 0 fires (+1), clause 1 fails
963        let votes = sparse.sum_votes(&[1, 0, 0, 0]);
964        assert!((votes - 1.0).abs() < 0.001);
965
966        // x[0]=0: clause 0 fails, clause 1 fires (-1)
967        let votes = sparse.sum_votes(&[0, 0, 0, 0]);
968        assert!((votes - (-1.0)).abs() < 0.001);
969    }
970
971    #[test]
972    fn sparse_bank_memory_stats() {
973        let clauses = vec![
974            SparseClause::new(&[0, 1, 2], &[3], 1.0, 1),
975            SparseClause::new(&[4, 5], &[6, 7, 8], 1.0, -1),
976        ];
977        let sparse = SparseClauseBank::from_clauses(&clauses, 100);
978
979        let stats = sparse.memory_stats();
980        assert_eq!(stats.total_literals, 9); // 3+1 + 2+3
981        assert_eq!(stats.n_clauses, 2);
982
983        // Compression ratio should be high
984        let ratio = stats.compression_ratio(100);
985        assert!(ratio > 10.0);
986    }
987
988    #[test]
989    fn sparse_bank_packed_evaluation() {
990        let clauses = vec![
991            SparseClause::new(&[0, 63], &[], 1.0, 1),  // bits 0 and 63
992            SparseClause::new(&[], &[1, 62], 1.0, -1), // not bits 1 and 62
993        ];
994        let sparse = SparseClauseBank::from_clauses(&clauses, 64);
995
996        // Packed: bits 0 and 63 set, bits 1 and 62 clear
997        let packed = 1u64 | (1u64 << 63);
998        let votes = sparse.sum_votes_packed(&[packed]);
999        assert!((votes - 0.0).abs() < 0.001); // Both fire: +1 - 1 = 0
1000    }
1001
1002    #[test]
1003    fn sparse_clause_accessors() {
1004        let sparse = SparseClause::new(&[1, 3, 5], &[2, 4], 2.5, -1);
1005
1006        assert!((sparse.weight() - 2.5).abs() < 0.001);
1007        assert_eq!(sparse.include_indices(), &[1, 3, 5]);
1008        assert_eq!(sparse.negated_indices(), &[2, 4]);
1009        assert_eq!(sparse.polarity(), -1);
1010    }
1011
1012    #[test]
1013    fn sparse_clause_evaluate_checked() {
1014        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
1015
1016        // Normal case: should fire
1017        assert!(sparse.evaluate_checked(&[1, 0, 1, 0]));
1018
1019        // Include violation
1020        assert!(!sparse.evaluate_checked(&[0, 0, 1, 0]));
1021
1022        // Negated violation
1023        assert!(!sparse.evaluate_checked(&[1, 1, 1, 0]));
1024
1025        // Out of bounds index treated as 0
1026        let sparse_oob = SparseClause::new(&[100], &[], 1.0, 1);
1027        assert!(!sparse_oob.evaluate_checked(&[1, 1])); // idx 100 doesn't exist
1028
1029        let sparse_oob_neg = SparseClause::new(&[], &[100], 1.0, 1);
1030        assert!(sparse_oob_neg.evaluate_checked(&[1, 1])); // negated: missing = 0 = ok
1031    }
1032
1033    #[test]
1034    fn sparse_memory_stats_edge_cases() {
1035        // Empty stats
1036        let stats = SparseMemoryStats {
1037            include_data:    0,
1038            include_offsets: 0,
1039            negated_data:    0,
1040            negated_offsets: 0,
1041            weights:         0,
1042            polarities:      0,
1043            total_literals:  0,
1044            n_clauses:       0,
1045            n_features:      100
1046        };
1047
1048        assert!((stats.avg_literals_per_clause() - 0.0).abs() < 0.001);
1049        assert!((stats.sparsity() - 0.0).abs() < 0.001);
1050        assert!((stats.compression_ratio(100) - 0.0).abs() < 0.001);
1051        assert_eq!(stats.total(), 0);
1052
1053        // Zero features
1054        let stats_zero_feat = SparseMemoryStats {
1055            include_data:    10,
1056            include_offsets: 8,
1057            negated_data:    10,
1058            negated_offsets: 8,
1059            weights:         8,
1060            polarities:      2,
1061            total_literals:  5,
1062            n_clauses:       2,
1063            n_features:      0
1064        };
1065        assert!((stats_zero_feat.sparsity() - 0.0).abs() < 0.001);
1066        assert_eq!(stats_zero_feat.total(), 46);
1067    }
1068
1069    #[test]
1070    fn sparse_tm_from_clauses() {
1071        let clauses = vec![Clause::new(4, 100, 1), Clause::new(4, 100, -1)];
1072
1073        let stm = SparseTsetlinMachine::from_clauses(&clauses, 4, 10.0);
1074        assert_eq!(stm.n_clauses(), 2);
1075        assert_eq!(stm.n_features(), 4);
1076        assert!((stm.threshold() - 10.0).abs() < 0.001);
1077    }
1078
1079    #[test]
1080    fn sparse_tm_new_and_accessors() {
1081        let clauses = vec![
1082            SparseClause::new(&[0], &[], 1.0, 1),
1083            SparseClause::new(&[], &[1], 1.0, -1),
1084        ];
1085        let bank = SparseClauseBank::from_clauses(&clauses, 4);
1086        let stm = SparseTsetlinMachine::new(bank, 5.0);
1087
1088        assert_eq!(stm.n_clauses(), 2);
1089        assert_eq!(stm.n_features(), 4);
1090        assert!((stm.threshold() - 5.0).abs() < 0.001);
1091    }
1092
1093    #[test]
1094    fn sparse_tm_predict() {
1095        let clauses = vec![
1096            SparseClause::new(&[0], &[], 1.0, 1),  // x[0]=1 -> +1
1097            SparseClause::new(&[], &[0], 1.0, -1), // x[0]=0 -> -1
1098        ];
1099        let bank = SparseClauseBank::from_clauses(&clauses, 4);
1100        let stm = SparseTsetlinMachine::new(bank, 5.0);
1101
1102        // x[0]=1: +1 vote -> positive -> class 1
1103        assert_eq!(stm.predict(&[1, 0, 0, 0]), 1);
1104
1105        // x[0]=0: -1 vote -> negative -> class 0
1106        assert_eq!(stm.predict(&[0, 0, 0, 0]), 0);
1107    }
1108
1109    #[test]
1110    fn sparse_tm_predict_packed() {
1111        let clauses = vec![
1112            SparseClause::new(&[0], &[], 2.0, 1),  // x[0]=1 -> +2
1113            SparseClause::new(&[], &[0], 1.0, -1), // x[0]=0 -> -1
1114        ];
1115        let bank = SparseClauseBank::from_clauses(&clauses, 64);
1116        let stm = SparseTsetlinMachine::new(bank, 5.0);
1117
1118        // Packed bit 0 set: +2 votes
1119        assert_eq!(stm.predict_packed(&[1u64]), 1);
1120
1121        // Packed bit 0 clear: -1 vote
1122        assert_eq!(stm.predict_packed(&[0u64]), 0);
1123    }
1124
1125    #[test]
1126    fn sparse_tm_predict_batch() {
1127        let clauses = vec![
1128            SparseClause::new(&[0], &[], 1.0, 1),
1129            SparseClause::new(&[], &[0], 1.0, -1),
1130        ];
1131        let bank = SparseClauseBank::from_clauses(&clauses, 2);
1132        let stm = SparseTsetlinMachine::new(bank, 5.0);
1133
1134        let xs = vec![vec![1, 0], vec![0, 0], vec![1, 1], vec![0, 1]];
1135        let preds = stm.predict_batch(&xs);
1136
1137        assert_eq!(preds, vec![1, 0, 1, 0]);
1138    }
1139
1140    #[test]
1141    fn sparse_tm_evaluate() {
1142        let clauses = vec![
1143            SparseClause::new(&[0], &[], 1.0, 1),
1144            SparseClause::new(&[], &[0], 1.0, -1),
1145        ];
1146        let bank = SparseClauseBank::from_clauses(&clauses, 2);
1147        let stm = SparseTsetlinMachine::new(bank, 5.0);
1148
1149        let xs = vec![vec![1, 0], vec![0, 0], vec![1, 1], vec![0, 1]];
1150        let ys = vec![1, 0, 1, 0];
1151
1152        // 100% accuracy expected
1153        assert!((stm.evaluate(&xs, &ys) - 1.0).abs() < 0.001);
1154
1155        // Wrong labels: 0% accuracy
1156        let wrong_ys = vec![0, 1, 0, 1];
1157        assert!((stm.evaluate(&xs, &wrong_ys) - 0.0).abs() < 0.001);
1158
1159        // Empty input
1160        assert!((stm.evaluate(&[], &[]) - 0.0).abs() < 0.001);
1161    }
1162
1163    #[test]
1164    fn sparse_tm_memory_and_compression() {
1165        let clauses = vec![
1166            SparseClause::new(&[0, 1], &[2], 1.0, 1),
1167            SparseClause::new(&[3], &[4, 5], 1.0, -1),
1168        ];
1169        let bank = SparseClauseBank::from_clauses(&clauses, 100);
1170        let stm = SparseTsetlinMachine::new(bank, 5.0);
1171
1172        let stats = stm.memory_stats();
1173        assert_eq!(stats.total_literals, 6);
1174        assert_eq!(stats.n_clauses, 2);
1175
1176        // High compression ratio expected
1177        let ratio = stm.compression_ratio();
1178        assert!(ratio > 10.0);
1179    }
1180}