1#[cfg(not(feature = "std"))]
52use alloc::{vec, vec::Vec};
53
54#[cfg(feature = "serde")]
55use serde::{Deserialize, Serialize};
56use smallvec::SmallVec;
57
58use crate::{Clause, ClauseBank};
59
60const INLINE_CAPACITY: usize = 32;
65
66#[derive(Debug, Clone)]
97#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
98pub struct SparseClause {
99 include_indices: SmallVec<[u16; INLINE_CAPACITY]>,
101
102 negated_indices: SmallVec<[u16; INLINE_CAPACITY]>,
104
105 weight: f32,
107
108 polarity: i8
110}
111
112impl SparseClause {
113 #[must_use]
132 pub fn from_clause(clause: &Clause) -> Self {
133 let mut include = SmallVec::new();
134 let mut negated = SmallVec::new();
135
136 for (k, pair) in clause.automata().chunks(2).enumerate() {
137 if pair[0].action() {
138 include.push(k as u16);
139 }
140 if pair[1].action() {
141 negated.push(k as u16);
142 }
143 }
144
145 Self {
146 include_indices: include,
147 negated_indices: negated,
148 weight: clause.weight(),
149 polarity: clause.polarity()
150 }
151 }
152
153 #[must_use]
162 pub fn new(include: &[u16], negated: &[u16], weight: f32, polarity: i8) -> Self {
163 Self {
164 include_indices: SmallVec::from_slice(include),
165 negated_indices: SmallVec::from_slice(negated),
166 weight,
167 polarity
168 }
169 }
170
171 #[inline(always)]
173 #[must_use]
174 pub const fn polarity(&self) -> i8 {
175 self.polarity
176 }
177
178 #[inline(always)]
180 #[must_use]
181 pub const fn weight(&self) -> f32 {
182 self.weight
183 }
184
185 #[inline(always)]
187 #[must_use]
188 pub fn include_indices(&self) -> &[u16] {
189 &self.include_indices
190 }
191
192 #[inline(always)]
194 #[must_use]
195 pub fn negated_indices(&self) -> &[u16] {
196 &self.negated_indices
197 }
198
199 #[inline]
214 #[must_use]
215 pub fn evaluate(&self, x: &[u8]) -> bool {
216 for &idx in &self.include_indices {
217 if unsafe { *x.get_unchecked(idx as usize) } == 0 {
219 return false;
220 }
221 }
222 for &idx in &self.negated_indices {
223 if unsafe { *x.get_unchecked(idx as usize) } == 1 {
225 return false;
226 }
227 }
228 true
229 }
230
231 #[inline]
236 #[must_use]
237 pub fn evaluate_checked(&self, x: &[u8]) -> bool {
238 for &idx in &self.include_indices {
239 if x.get(idx as usize).copied().unwrap_or(0) == 0 {
240 return false;
241 }
242 }
243 for &idx in &self.negated_indices {
244 if x.get(idx as usize).copied().unwrap_or(0) == 1 {
245 return false;
246 }
247 }
248 true
249 }
250
251 #[inline]
260 #[must_use]
261 pub fn evaluate_packed(&self, x: &[u64]) -> bool {
262 for &idx in &self.include_indices {
263 let word = idx as usize >> 6; let bit = idx as usize & 63; if unsafe { *x.get_unchecked(word) } & (1u64 << bit) == 0 {
267 return false;
268 }
269 }
270 for &idx in &self.negated_indices {
271 let word = idx as usize >> 6;
272 let bit = idx as usize & 63;
273 if unsafe { *x.get_unchecked(word) } & (1u64 << bit) != 0 {
274 return false;
275 }
276 }
277 true
278 }
279
280 #[inline]
283 #[must_use]
284 pub fn vote(&self, x: &[u8]) -> f32 {
285 if self.evaluate(x) {
286 self.polarity as f32 * self.weight
287 } else {
288 0.0
289 }
290 }
291
292 #[inline]
294 #[must_use]
295 pub fn vote_unweighted(&self, x: &[u8]) -> i32 {
296 if self.evaluate(x) {
297 self.polarity as i32
298 } else {
299 0
300 }
301 }
302
303 #[must_use]
307 pub fn memory_usage(&self) -> usize {
308 let base = core::mem::size_of::<Self>();
309 let include_heap = if self.include_indices.spilled() {
310 self.include_indices.capacity() * 2
311 } else {
312 0
313 };
314 let negated_heap = if self.negated_indices.spilled() {
315 self.negated_indices.capacity() * 2
316 } else {
317 0
318 };
319 base + include_heap + negated_heap
320 }
321
322 #[inline]
324 #[must_use]
325 pub fn n_literals(&self) -> usize {
326 self.include_indices.len() + self.negated_indices.len()
327 }
328}
329
330#[derive(Debug, Clone)]
376#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
377pub struct SparseClauseBank {
378 include_indices: Vec<u16>,
380
381 include_offsets: Vec<u32>,
383
384 negated_indices: Vec<u16>,
386
387 negated_offsets: Vec<u32>,
389
390 weights: Vec<f32>,
392
393 polarities: Vec<i8>,
395
396 n_clauses: usize,
398
399 n_features: usize
401}
402
403impl SparseClauseBank {
404 #[must_use]
423 pub fn from_clause_bank(bank: &ClauseBank) -> Self {
424 let mut include_indices = Vec::new();
425 let mut include_offsets = vec![0u32];
426 let mut negated_indices = Vec::new();
427 let mut negated_offsets = vec![0u32];
428
429 let threshold = bank.n_states();
430
431 for c in 0..bank.n_clauses() {
432 let states = bank.clause_states(c);
433
434 for (k, pair) in states.chunks(2).enumerate() {
435 if pair[0] > threshold {
436 include_indices.push(k as u16);
437 }
438 if pair[1] > threshold {
439 negated_indices.push(k as u16);
440 }
441 }
442
443 include_offsets.push(include_indices.len() as u32);
444 negated_offsets.push(negated_indices.len() as u32);
445 }
446
447 Self {
448 include_indices,
449 include_offsets,
450 negated_indices,
451 negated_offsets,
452 weights: bank.weights().to_vec(),
453 polarities: bank.polarities().to_vec(),
454 n_clauses: bank.n_clauses(),
455 n_features: bank.n_features()
456 }
457 }
458
459 #[must_use]
464 pub fn from_clauses(clauses: &[SparseClause], n_features: usize) -> Self {
465 let mut include_indices = Vec::new();
466 let mut include_offsets = vec![0u32];
467 let mut negated_indices = Vec::new();
468 let mut negated_offsets = vec![0u32];
469 let mut weights = Vec::with_capacity(clauses.len());
470 let mut polarities = Vec::with_capacity(clauses.len());
471
472 for clause in clauses {
473 include_indices.extend_from_slice(&clause.include_indices);
474 include_offsets.push(include_indices.len() as u32);
475
476 negated_indices.extend_from_slice(&clause.negated_indices);
477 negated_offsets.push(negated_indices.len() as u32);
478
479 weights.push(clause.weight);
480 polarities.push(clause.polarity);
481 }
482
483 Self {
484 include_indices,
485 include_offsets,
486 negated_indices,
487 negated_offsets,
488 weights,
489 polarities,
490 n_clauses: clauses.len(),
491 n_features
492 }
493 }
494
495 #[inline(always)]
497 #[must_use]
498 pub const fn n_clauses(&self) -> usize {
499 self.n_clauses
500 }
501
502 #[inline(always)]
504 #[must_use]
505 pub const fn n_features(&self) -> usize {
506 self.n_features
507 }
508
509 #[inline(always)]
511 #[must_use]
512 pub fn weights(&self) -> &[f32] {
513 &self.weights
514 }
515
516 #[inline(always)]
518 #[must_use]
519 pub fn polarities(&self) -> &[i8] {
520 &self.polarities
521 }
522
523 #[inline]
525 #[must_use]
526 pub fn clause_n_literals(&self, clause: usize) -> usize {
527 let inc = self.include_offsets[clause + 1] - self.include_offsets[clause];
528 let neg = self.negated_offsets[clause + 1] - self.negated_offsets[clause];
529 (inc + neg) as usize
530 }
531
532 #[inline]
536 #[must_use]
537 pub fn evaluate_clause(&self, clause: usize, x: &[u8]) -> bool {
538 let inc_start = self.include_offsets[clause] as usize;
539 let inc_end = self.include_offsets[clause + 1] as usize;
540
541 for &idx in &self.include_indices[inc_start..inc_end] {
542 if unsafe { *x.get_unchecked(idx as usize) } == 0 {
544 return false;
545 }
546 }
547
548 let neg_start = self.negated_offsets[clause] as usize;
549 let neg_end = self.negated_offsets[clause + 1] as usize;
550
551 for &idx in &self.negated_indices[neg_start..neg_end] {
552 if unsafe { *x.get_unchecked(idx as usize) } == 1 {
553 return false;
554 }
555 }
556
557 true
558 }
559
560 #[inline]
564 #[must_use]
565 pub fn evaluate_clause_packed(&self, clause: usize, x: &[u64]) -> bool {
566 let inc_start = self.include_offsets[clause] as usize;
567 let inc_end = self.include_offsets[clause + 1] as usize;
568
569 for &idx in &self.include_indices[inc_start..inc_end] {
570 let word = idx as usize >> 6;
571 let bit = idx as usize & 63;
572 if unsafe { *x.get_unchecked(word) } & (1u64 << bit) == 0 {
573 return false;
574 }
575 }
576
577 let neg_start = self.negated_offsets[clause] as usize;
578 let neg_end = self.negated_offsets[clause + 1] as usize;
579
580 for &idx in &self.negated_indices[neg_start..neg_end] {
581 let word = idx as usize >> 6;
582 let bit = idx as usize & 63;
583 if unsafe { *x.get_unchecked(word) } & (1u64 << bit) != 0 {
584 return false;
585 }
586 }
587
588 true
589 }
590
591 #[must_use]
596 pub fn sum_votes(&self, x: &[u8]) -> f32 {
597 let mut sum = 0.0f32;
598 for c in 0..self.n_clauses {
599 if self.evaluate_clause(c, x) {
600 sum += unsafe {
602 *self.polarities.get_unchecked(c) as f32 * *self.weights.get_unchecked(c)
603 };
604 }
605 }
606 sum
607 }
608
609 #[must_use]
611 pub fn sum_votes_packed(&self, x: &[u64]) -> f32 {
612 let mut sum = 0.0f32;
613 for c in 0..self.n_clauses {
614 if self.evaluate_clause_packed(c, x) {
615 sum += unsafe {
616 *self.polarities.get_unchecked(c) as f32 * *self.weights.get_unchecked(c)
617 };
618 }
619 }
620 sum
621 }
622
623 #[must_use]
625 pub fn sum_votes_unweighted(&self, x: &[u8]) -> i32 {
626 let mut sum = 0i32;
627 for c in 0..self.n_clauses {
628 if self.evaluate_clause(c, x) {
629 sum += self.polarities[c] as i32;
630 }
631 }
632 sum
633 }
634
635 #[must_use]
637 pub fn memory_stats(&self) -> SparseMemoryStats {
638 SparseMemoryStats {
639 include_data: self.include_indices.len() * 2,
640 include_offsets: self.include_offsets.len() * 4,
641 negated_data: self.negated_indices.len() * 2,
642 negated_offsets: self.negated_offsets.len() * 4,
643 weights: self.weights.len() * 4,
644 polarities: self.polarities.len(),
645 total_literals: self.include_indices.len() + self.negated_indices.len(),
646 n_clauses: self.n_clauses,
647 n_features: self.n_features
648 }
649 }
650}
651
/// Memory breakdown of a sparse clause bank, plus the counts needed to derive
/// summary ratios.
///
/// When filled by `SparseClauseBank::memory_stats`, each byte field is the
/// array length times its element size. Spare capacity and struct overhead are
/// not counted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SparseMemoryStats {
    /// Bytes used by positive-literal indices.
    pub include_data: usize,

    /// Bytes used by positive-literal clause offsets.
    pub include_offsets: usize,

    /// Bytes used by negated-literal indices.
    pub negated_data: usize,

    /// Bytes used by negated-literal clause offsets.
    pub negated_offsets: usize,

    /// Bytes used by clause weights.
    pub weights: usize,

    /// Bytes used by clause polarities.
    pub polarities: usize,

    /// Number of literals (positive plus negated) across all clauses.
    pub total_literals: usize,

    /// Number of clauses.
    pub n_clauses: usize,

    /// Number of input features.
    pub n_features: usize
}

impl SparseMemoryStats {
    /// Total bytes across all storage fields (the count fields are excluded).
    #[must_use]
    pub const fn total(&self) -> usize {
        self.include_data
            + self.include_offsets
            + self.negated_data
            + self.negated_offsets
            + self.weights
            + self.polarities
    }

    /// Mean number of literals per clause, or `0.0` when there are no clauses.
    #[must_use]
    pub fn avg_literals_per_clause(&self) -> f32 {
        if self.n_clauses == 0 {
            0.0
        } else {
            self.total_literals as f32 / self.n_clauses as f32
        }
    }

    /// Estimated dense size divided by [`total`](Self::total).
    ///
    /// The dense estimate is `n_clauses * 2 * n_features * 2` bytes: two
    /// automata per feature at two bytes each. The two-byte state size is
    /// presumably the dense representation's state width; TODO confirm.
    /// Returns `0.0` when `total()` is zero.
    #[must_use]
    pub fn compression_ratio(&self, n_features: usize) -> f32 {
        let dense_size = self.n_clauses * 2 * n_features * 2;
        let sparse_size = self.total();
        if sparse_size == 0 {
            0.0
        } else {
            dense_size as f32 / sparse_size as f32
        }
    }

    /// Fraction of the `2 * n_features` possible literals per clause that are
    /// actually present, over all clauses.
    ///
    /// Despite the name this is a density: lower values mean sparser clauses.
    /// Returns `0.0` when no literals are possible.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        let max_literals = self.n_clauses * 2 * self.n_features;
        if max_literals == 0 {
            0.0
        } else {
            self.total_literals as f32 / max_literals as f32
        }
    }
}
733
734#[derive(Debug, Clone)]
759#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
760pub struct SparseTsetlinMachine {
761 clauses: SparseClauseBank,
762 threshold: f32
763}
764
765impl SparseTsetlinMachine {
766 #[must_use]
774 pub fn from_clauses(clauses: &[Clause], n_features: usize, threshold: f32) -> Self {
775 let sparse_clauses: Vec<SparseClause> =
776 clauses.iter().map(SparseClause::from_clause).collect();
777
778 Self {
779 clauses: SparseClauseBank::from_clauses(&sparse_clauses, n_features),
780 threshold
781 }
782 }
783
784 #[must_use]
786 pub fn new(clauses: SparseClauseBank, threshold: f32) -> Self {
787 Self {
788 clauses,
789 threshold
790 }
791 }
792
793 #[inline(always)]
795 #[must_use]
796 pub const fn n_clauses(&self) -> usize {
797 self.clauses.n_clauses()
798 }
799
800 #[inline(always)]
802 #[must_use]
803 pub const fn n_features(&self) -> usize {
804 self.clauses.n_features()
805 }
806
807 #[inline(always)]
809 #[must_use]
810 pub const fn threshold(&self) -> f32 {
811 self.threshold
812 }
813
814 #[inline]
816 #[must_use]
817 pub fn predict(&self, x: &[u8]) -> u8 {
818 if self.clauses.sum_votes(x) >= 0.0 {
819 1
820 } else {
821 0
822 }
823 }
824
825 #[inline]
827 #[must_use]
828 pub fn predict_packed(&self, x: &[u64]) -> u8 {
829 if self.clauses.sum_votes_packed(x) >= 0.0 {
830 1
831 } else {
832 0
833 }
834 }
835
836 #[must_use]
838 pub fn predict_batch(&self, xs: &[Vec<u8>]) -> Vec<u8> {
839 xs.iter().map(|x| self.predict(x)).collect()
840 }
841
842 #[must_use]
844 pub fn evaluate(&self, x: &[Vec<u8>], y: &[u8]) -> f32 {
845 if x.is_empty() {
846 return 0.0;
847 }
848 let correct = x
849 .iter()
850 .zip(y)
851 .filter(|(xi, yi)| self.predict(xi) == **yi)
852 .count();
853 correct as f32 / x.len() as f32
854 }
855
856 #[must_use]
858 pub fn memory_stats(&self) -> SparseMemoryStats {
859 self.clauses.memory_stats()
860 }
861
862 #[must_use]
864 pub fn compression_ratio(&self) -> f32 {
865 self.clauses
866 .memory_stats()
867 .compression_ratio(self.n_features())
868 }
869}
870
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f32 = 0.001;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Two unit-weight clauses: `+1` when feature 0 is set, `-1` when it is clear.
    fn feature0_voters() -> [SparseClause; 2] {
        [
            SparseClause::new(&[0], &[], 1.0, 1),
            SparseClause::new(&[], &[0], 1.0, -1)
        ]
    }

    #[test]
    fn sparse_clause_from_dense() {
        let mut clause = Clause::new(10, 100, 1);

        // Automaton 0 is the positive literal of feature 0, and automaton 5 is
        // the negated literal of feature 2. Increment both until they include.
        for _ in 0..200 {
            for &automaton in &[0usize, 5] {
                clause.automata_mut()[automaton].increment();
            }
        }

        let sparse = SparseClause::from_clause(&clause);
        assert_eq!(sparse.include_indices(), &[0]);
        assert_eq!(sparse.negated_indices(), &[2]);
        assert_eq!(sparse.polarity(), 1);
    }

    #[test]
    fn sparse_clause_evaluate() {
        // x0 AND NOT x1 AND x2
        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
        let cases: [(&[u8], bool); 4] = [
            (&[1, 0, 1, 0], true),
            (&[0, 0, 1, 0], false), // x0 missing
            (&[1, 1, 1, 0], false), // NOT x1 violated
            (&[1, 0, 0, 0], false)  // x2 missing
        ];
        for (x, expected) in cases.iter() {
            assert_eq!(sparse.evaluate(x), *expected, "input {:?}", x);
        }
    }

    #[test]
    fn sparse_clause_evaluate_packed() {
        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
        // Bits 0 and 2 set, bit 1 clear.
        assert!(sparse.evaluate_packed(&[0b101u64]));
        // Bit 0 clear.
        assert!(!sparse.evaluate_packed(&[0b100u64]));
    }

    #[test]
    fn sparse_clause_vote() {
        // A clause without literals fires on any input.
        let sparse = SparseClause::new(&[], &[], 2.5, -1);
        let x = [0u8, 1, 0];
        assert!(close(sparse.vote(&x), -2.5));
        assert_eq!(sparse.vote_unweighted(&x), -1);
    }

    #[test]
    fn sparse_clause_memory() {
        let sparse = SparseClause::new(&[0, 1, 2], &[3, 4], 1.0, 1);
        // Both lists are short enough to stay inline.
        assert!(sparse.memory_usage() < 200);
        assert_eq!(sparse.n_literals(), 5);
    }

    #[test]
    fn sparse_bank_from_clause_bank() {
        let bank = ClauseBank::new(10, 100, 100);
        let sparse = SparseClauseBank::from_clause_bank(&bank);

        assert_eq!((sparse.n_clauses(), sparse.n_features()), (10, 100));
        // A freshly constructed bank includes no literals.
        assert_eq!(sparse.memory_stats().total_literals, 0);
    }

    #[test]
    fn sparse_bank_evaluate() {
        let sparse = SparseClauseBank::from_clauses(&feature0_voters(), 4);

        assert!(close(sparse.sum_votes(&[1, 0, 0, 0]), 1.0));
        assert!(close(sparse.sum_votes(&[0, 0, 0, 0]), -1.0));
    }

    #[test]
    fn sparse_bank_memory_stats() {
        let clauses = [
            SparseClause::new(&[0, 1, 2], &[3], 1.0, 1),
            SparseClause::new(&[4, 5], &[6, 7, 8], 1.0, -1)
        ];
        let stats = SparseClauseBank::from_clauses(&clauses, 100).memory_stats();

        assert_eq!(stats.total_literals, 9);
        assert_eq!(stats.n_clauses, 2);
        assert!(stats.compression_ratio(100) > 10.0);
    }

    #[test]
    fn sparse_bank_packed_evaluation() {
        let clauses = [
            SparseClause::new(&[0, 63], &[], 1.0, 1),  // needs bits 0 and 63 set
            SparseClause::new(&[], &[1, 62], 1.0, -1)  // needs bits 1 and 62 clear
        ];
        let sparse = SparseClauseBank::from_clauses(&clauses, 64);

        let word = (1u64 << 63) | 1;
        // Both clauses fire, and their votes cancel out.
        assert!(close(sparse.sum_votes_packed(&[word]), 0.0));
    }

    #[test]
    fn sparse_clause_accessors() {
        let sparse = SparseClause::new(&[1, 3, 5], &[2, 4], 2.5, -1);

        assert!(close(sparse.weight(), 2.5));
        assert_eq!(sparse.include_indices(), &[1, 3, 5]);
        assert_eq!(sparse.negated_indices(), &[2, 4]);
        assert_eq!(sparse.polarity(), -1);
    }

    #[test]
    fn sparse_clause_evaluate_checked() {
        let sparse = SparseClause::new(&[0, 2], &[1], 1.0, 1);
        assert!(sparse.evaluate_checked(&[1, 0, 1, 0]));
        assert!(!sparse.evaluate_checked(&[0, 0, 1, 0]));
        assert!(!sparse.evaluate_checked(&[1, 1, 1, 0]));

        // Missing features read as 0. An included positive literal on one
        // fails the clause, and an included negated literal is satisfied.
        let oob_include = SparseClause::new(&[100], &[], 1.0, 1);
        assert!(!oob_include.evaluate_checked(&[1, 1]));
        let oob_negated = SparseClause::new(&[], &[100], 1.0, 1);
        assert!(oob_negated.evaluate_checked(&[1, 1]));
    }

    #[test]
    fn sparse_memory_stats_edge_cases() {
        // No clauses: every ratio falls back to zero.
        let empty = SparseMemoryStats {
            include_data: 0,
            include_offsets: 0,
            negated_data: 0,
            negated_offsets: 0,
            weights: 0,
            polarities: 0,
            total_literals: 0,
            n_clauses: 0,
            n_features: 100
        };
        assert!(close(empty.avg_literals_per_clause(), 0.0));
        assert!(close(empty.sparsity(), 0.0));
        assert!(close(empty.compression_ratio(100), 0.0));
        assert_eq!(empty.total(), 0);

        // No features: sparsity has no denominator and falls back to zero.
        let no_features = SparseMemoryStats {
            include_data: 10,
            include_offsets: 8,
            negated_data: 10,
            negated_offsets: 8,
            weights: 8,
            polarities: 2,
            total_literals: 5,
            n_clauses: 2,
            n_features: 0
        };
        assert!(close(no_features.sparsity(), 0.0));
        assert_eq!(no_features.total(), 46);
    }

    #[test]
    fn sparse_tm_from_clauses() {
        let clauses = [Clause::new(4, 100, 1), Clause::new(4, 100, -1)];
        let stm = SparseTsetlinMachine::from_clauses(&clauses, 4, 10.0);

        assert_eq!(stm.n_clauses(), 2);
        assert_eq!(stm.n_features(), 4);
        assert!(close(stm.threshold(), 10.0));
    }

    #[test]
    fn sparse_tm_new_and_accessors() {
        let clauses = [
            SparseClause::new(&[0], &[], 1.0, 1),
            SparseClause::new(&[], &[1], 1.0, -1)
        ];
        let stm = SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&clauses, 4), 5.0);

        assert_eq!(stm.n_clauses(), 2);
        assert_eq!(stm.n_features(), 4);
        assert!(close(stm.threshold(), 5.0));
    }

    #[test]
    fn sparse_tm_predict() {
        let stm =
            SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&feature0_voters(), 4), 5.0);

        // x0 = 1: only the positive clause fires.
        assert_eq!(stm.predict(&[1, 0, 0, 0]), 1);
        // x0 = 0: only the negative clause fires.
        assert_eq!(stm.predict(&[0, 0, 0, 0]), 0);
    }

    #[test]
    fn sparse_tm_predict_packed() {
        let clauses = [
            SparseClause::new(&[0], &[], 2.0, 1),
            SparseClause::new(&[], &[0], 1.0, -1)
        ];
        let stm = SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&clauses, 64), 5.0);

        // Bit 0 set: +2 from the first clause.
        assert_eq!(stm.predict_packed(&[1u64]), 1);
        // Bit 0 clear: -1 from the second clause.
        assert_eq!(stm.predict_packed(&[0u64]), 0);
    }

    #[test]
    fn sparse_tm_predict_batch() {
        let stm =
            SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&feature0_voters(), 2), 5.0);

        let xs = vec![vec![1, 0], vec![0, 0], vec![1, 1], vec![0, 1]];
        assert_eq!(stm.predict_batch(&xs), vec![1, 0, 1, 0]);
    }

    #[test]
    fn sparse_tm_evaluate() {
        let stm =
            SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&feature0_voters(), 2), 5.0);
        let xs = vec![vec![1, 0], vec![0, 0], vec![1, 1], vec![0, 1]];

        // Labels matching every prediction.
        let right = [1u8, 0, 1, 0];
        assert!(close(stm.evaluate(&xs, &right), 1.0));

        // Labels contradicting every prediction.
        let wrong = [0u8, 1, 0, 1];
        assert!(close(stm.evaluate(&xs, &wrong), 0.0));

        // Nothing to score.
        assert!(close(stm.evaluate(&[], &[]), 0.0));
    }

    #[test]
    fn sparse_tm_memory_and_compression() {
        let clauses = [
            SparseClause::new(&[0, 1], &[2], 1.0, 1),
            SparseClause::new(&[3], &[4, 5], 1.0, -1)
        ];
        let stm = SparseTsetlinMachine::new(SparseClauseBank::from_clauses(&clauses, 100), 5.0);

        let stats = stm.memory_stats();
        assert_eq!(stats.total_literals, 6);
        assert_eq!(stats.n_clauses, 2);
        assert!(stm.compression_ratio() > 10.0);
    }
}
1180}