1use ternlang_core::trit::Trit;
19use serde::{Serialize, Deserialize};
20
21pub mod spectra_compat {
24 use super::*;
25
26 pub fn import_spectra_weights(raw_data: &[f32], rows: usize, cols: usize) -> TritMatrix {
29 println!("ternlang-ml: Annexing Spectra-1.1 weights (Scale: 1.2T tokens)...");
30 TritMatrix::from_f32(rows, cols, raw_data, 0.5)
32 }
33}
34
35pub mod coherence;
36
37pub fn quantize(weights: &[f32], threshold: f32) -> Vec<Trit> {
48 weights.iter().map(|&w| {
49 if w > threshold {
50 Trit::Affirm
51 } else if w < -threshold {
52 Trit::Reject
53 } else {
54 Trit::Tend
55 }
56 }).collect()
57}
58
/// Computes a quantization threshold as half the mean absolute weight.
///
/// Returns `0.5 * mean(|w|)` over `weights`. An empty slice yields `0.0`
/// rather than the `NaN` that `0.0 / 0.0` would otherwise produce.
pub fn bitnet_threshold(weights: &[f32]) -> f32 {
    if weights.is_empty() {
        // No weights: avoid dividing zero by zero.
        return 0.0;
    }
    let mean_abs = weights.iter().map(|w| w.abs()).sum::<f32>() / weights.len() as f32;
    0.5 * mean_abs
}
64
/// A dense matrix of [`Trit`] values backed by a flat vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TritMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Elements in row-major order: element `(row, col)` is stored at
    /// index `row * cols + col`.
    pub data: Vec<Trit>,
}
74
75impl TritMatrix {
76 pub fn new(rows: usize, cols: usize) -> Self {
77 Self { rows, cols, data: vec![Trit::Tend; rows * cols] }
78 }
79
80 pub fn from_trits(rows: usize, cols: usize, data: Vec<Trit>) -> Self {
81 assert_eq!(data.len(), rows * cols);
82 Self { rows, cols, data }
83 }
84
85 pub fn from_f32(rows: usize, cols: usize, weights: &[f32], threshold: f32) -> Self {
86 Self::from_trits(rows, cols, quantize(weights, threshold))
87 }
88
89 #[inline]
90 pub fn get(&self, row: usize, col: usize) -> Trit {
91 self.data[row * self.cols + col]
92 }
93
94 #[inline]
95 pub fn set(&mut self, row: usize, col: usize, val: Trit) {
96 self.data[row * self.cols + col] = val;
97 }
98
99 pub fn sparsity(&self) -> f64 {
101 let zeros = self.data.iter().filter(|&&t| t == Trit::Tend).count();
102 zeros as f64 / self.data.len() as f64
103 }
104
105 pub fn nnz(&self) -> usize {
107 self.data.iter().filter(|&&t| t != Trit::Tend).count()
108 }
109
110 pub fn to_i8_vec(&self) -> Vec<i8> {
112 self.data.iter().map(|&t| match t {
113 Trit::Affirm => 1,
114 Trit::Reject => -1,
115 Trit::Tend => 0,
116 }).collect()
117 }
118}
119
120pub fn dense_matmul(a: &TritMatrix, b: &TritMatrix) -> TritMatrix {
126 assert_eq!(a.cols, b.rows, "matmul dimension mismatch: a.cols must equal b.rows");
127 let mut c = TritMatrix::new(a.rows, b.cols);
128 for row in 0..a.rows {
129 for col in 0..b.cols {
130 let mut acc = Trit::Tend;
131 for k in 0..a.cols {
132 let prod = a.get(row, k) * b.get(k, col);
133 let (sum, _carry) = acc + prod;
134 acc = sum;
135 }
136 c.set(row, col, acc);
137 }
138 }
139 c
140}
141
142pub fn sparse_matmul(a: &TritMatrix, b: &TritMatrix) -> (TritMatrix, usize) {
162 use rayon::prelude::*;
163
164 assert_eq!(a.cols, b.rows, "matmul dimension mismatch");
165
166 #[inline(always)]
167 fn t2i(t: Trit) -> i8 {
168 match t { Trit::Reject => -1, Trit::Tend => 0, Trit::Affirm => 1 }
169 }
170
171 let a_flat: Vec<i8> = a.data.iter().map(|&t| t2i(t)).collect();
173 let a_cols = a.cols;
174
175 let mut csc_offsets = vec![0usize; b.cols + 1];
180 for k in 0..b.rows {
182 for j in 0..b.cols {
183 if t2i(b.data[k * b.cols + j]) != 0 {
184 csc_offsets[j + 1] += 1;
185 }
186 }
187 }
188 for j in 0..b.cols {
190 csc_offsets[j + 1] += csc_offsets[j];
191 }
192 let nnz = csc_offsets[b.cols];
193 let mut csc_idx = vec![0u32; nnz];
194 let mut csc_val = vec![0i8; nnz];
195 let mut col_cursor = csc_offsets[..b.cols].to_vec(); for k in 0..b.rows {
197 for j in 0..b.cols {
198 let w = t2i(b.data[k * b.cols + j]);
199 if w != 0 {
200 let pos = col_cursor[j];
201 csc_idx[pos] = k as u32;
202 csc_val[pos] = w;
203 col_cursor[j] += 1;
204 }
205 }
206 }
207
208 let dense_ops = a.rows * b.cols * a.cols;
209 let active_ops = nnz * a.rows;
210 let skipped = dense_ops.saturating_sub(active_ops);
211
212 let mut out_flat = vec![0i8; a.rows * b.cols];
215
216 out_flat
217 .par_chunks_mut(b.cols)
218 .enumerate()
219 .for_each(|(row, row_out)| {
220 let a_row = &a_flat[row * a_cols..(row + 1) * a_cols];
221 for col in 0..b.cols {
222 let start = csc_offsets[col];
223 let end = csc_offsets[col + 1];
224 let mut acc: i32 = 0;
225 for i in start..end {
228 let k = unsafe { *csc_idx.get_unchecked(i) } as usize;
229 let w = unsafe { *csc_val.get_unchecked(i) } as i32;
230 let av = unsafe { *a_row.get_unchecked(k) } as i32;
231 acc += av * w;
232 }
233 row_out[col] = if acc > 0 { 1 } else if acc < 0 { -1 } else { 0 };
234 }
235 });
236
237 let c_data: Vec<Trit> = out_flat.into_iter().map(|v| Trit::from(v)).collect();
239 let c = TritMatrix { rows: a.rows, cols: b.cols, data: c_data };
240
241 (c, skipped)
242}
243
244pub fn linear(input: &TritMatrix, weights: &TritMatrix) -> (TritMatrix, usize) {
252 sparse_matmul(input, weights)
253}
254
/// Operation counts gathered for one sparse multiply, as filled in by `benchmark`.
pub struct BenchmarkResult {
    /// Multiplies a dense product would perform (`a.rows * a.cols * b.cols`).
    pub dense_ops: usize,
    /// Multiplies left after skipping (`dense_ops - skipped_ops`).
    pub sparse_ops: usize,
    /// Multiplies skipped, as reported by `sparse_matmul`.
    pub skipped_ops: usize,
    /// `skipped_ops / dense_ops` as a fraction.
    pub skip_rate: f64,
    /// Sparsity of the right-hand (weight) matrix, as a fraction.
    pub weight_sparsity: f64,
}
265
266impl BenchmarkResult {
267 pub fn print_summary(&self) {
268 println!("=== Ternary Sparse Matmul Benchmark ===");
269 println!(" Weight sparsity: {:.1}% zeros", self.weight_sparsity * 100.0);
270 println!(" Dense ops: {}", self.dense_ops);
271 println!(" Sparse ops: {}", self.sparse_ops);
272 println!(" Skipped ops: {}", self.skipped_ops);
273 println!(" Skip rate: {:.1}%", self.skip_rate * 100.0);
274 println!(" Ops saved: {:.1}x fewer multiplies", self.dense_ops as f64 / self.sparse_ops.max(1) as f64);
275 }
276}
277
278pub fn benchmark(a: &TritMatrix, b: &TritMatrix) -> BenchmarkResult {
279 let dense_ops = a.rows * a.cols * b.cols;
280 let (_result, skipped) = sparse_matmul(a, b);
281 let sparse_ops = dense_ops - skipped;
282 BenchmarkResult {
283 dense_ops,
284 sparse_ops,
285 skipped_ops: skipped,
286 skip_rate: skipped as f64 / dense_ops as f64,
287 weight_sparsity: b.sparsity(),
288 }
289}
290
/// Activation function applied between MLP layers; currently the identity
/// (returns `t` unchanged).
pub fn trit_activation(t: Trit) -> Trit { t }
297
298pub fn majority(trits: &[Trit]) -> Trit {
301 let sum: i32 = trits.iter().map(|&t| match t {
302 Trit::Affirm => 1,
303 Trit::Reject => -1,
304 Trit::Tend => 0,
305 }).sum();
306 match sum.signum() {
307 1 => Trit::Affirm,
308 -1 => Trit::Reject,
309 _ => Trit::Tend,
310 }
311}
312
/// A two-layer perceptron whose weights are trit matrices.
pub struct TernaryMLP {
    /// First-layer weights, shaped `in_features × hidden_size`.
    pub w1: TritMatrix,
    /// Second-layer weights, shaped `hidden_size × out_features`.
    pub w2: TritMatrix,
    /// Input width (`w1.rows`).
    pub in_features: usize,
    /// Hidden width (`w1.cols`, which must equal `w2.rows`).
    pub hidden_size: usize,
    /// Output width (`w2.cols`).
    pub out_features: usize,
}
329
330impl TernaryMLP {
331 pub fn new(w1: TritMatrix, w2: TritMatrix) -> Self {
333 let in_features = w1.rows;
334 let hidden_size = w1.cols;
335 let out_features = w2.cols;
336 assert_eq!(w2.rows, hidden_size, "w1.cols must equal w2.rows");
337 Self { w1, w2, in_features, hidden_size, out_features }
338 }
339
340 pub fn from_f32(
342 in_features: usize, hidden_size: usize, out_features: usize,
343 w1_f32: &[f32], w2_f32: &[f32],
344 ) -> Self {
345 let tau1 = bitnet_threshold(w1_f32);
346 let tau2 = bitnet_threshold(w2_f32);
347 let w1 = TritMatrix::from_f32(in_features, hidden_size, w1_f32, tau1);
348 let w2 = TritMatrix::from_f32(hidden_size, out_features, w2_f32, tau2);
349 Self::new(w1, w2)
350 }
351
352 pub fn forward(&self, input: &TritMatrix) -> (TritMatrix, usize, usize) {
356 assert_eq!(input.cols, self.in_features,
357 "input width must match in_features");
358
359 let (hidden, skip1) = sparse_matmul(input, &self.w1);
361
362 let hidden_act = TritMatrix::from_trits(
364 hidden.rows, hidden.cols,
365 hidden.data.iter().map(|&t| trit_activation(t)).collect(),
366 );
367
368 let (output, skip2) = sparse_matmul(&hidden_act, &self.w2);
370
371 (output, skip1, skip2)
372 }
373
374 pub fn predict(&self, input: &TritMatrix) -> usize {
377 let (output, _, _) = self.forward(input);
378 let row = 0;
379 let mut best_col = 0;
380 let mut best_val: i8 = -2;
381 for col in 0..self.out_features {
382 let v = match output.get(row, col) {
383 Trit::Affirm => 1,
384 Trit::Tend => 0,
385 Trit::Reject => -1,
386 };
387 if v > best_val { best_val = v; best_col = col; }
388 }
389 best_col
390 }
391
392 pub fn layer1_sparsity(&self) -> f64 { self.w1.sparsity() }
393 pub fn layer2_sparsity(&self) -> f64 { self.w2.sparsity() }
394}
395
/// One row of a timed dense-vs-sparse benchmark for an `n × n` multiply.
#[derive(Debug)]
pub struct TimedResult {
    /// Matrix side length `n`.
    pub size: usize,
    /// Multiplies a dense product performs (`n * n * n`).
    pub dense_ops: usize,
    /// Multiplies left after skipping (`dense_ops - skipped_ops`).
    pub sparse_ops: usize,
    /// Multiplies skipped, as reported by `sparse_matmul`.
    pub skipped_ops: usize,
    /// Sparsity of the right-hand matrix, as a fraction.
    pub weight_sparsity: f64,
    /// `skipped_ops / dense_ops` as a fraction.
    pub skip_rate: f64,
    /// Median dense time divided by median sparse time; falls back to the
    /// operation-count ratio when the sparse median is zero microseconds.
    pub speedup: f64,
    /// Median wall-clock time of `dense_matmul`, in microseconds.
    pub dense_us: u64,
    /// Median wall-clock time of `sparse_matmul`, in microseconds.
    pub sparse_us: u64,
}
411
412pub fn timed_benchmark(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
417 use std::time::Instant;
418
419 fn lcg_weights(n: usize, seed: u64) -> Vec<f32> {
421 let mut state = seed;
422 (0..n).map(|_| {
423 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
424 let f = ((state >> 33) as f32) / (u32::MAX as f32) * 3.0 - 1.5;
427 f
428 }).collect()
429 }
430
431 fn median_us(mut times: Vec<u64>) -> u64 {
432 times.sort_unstable();
433 times[times.len() / 2]
434 }
435
436 sizes.iter().map(|&n| {
437 let weights_a = lcg_weights(n * n, 0xdeadbeef);
438 let weights_b = lcg_weights(n * n, 0xc0ffee42);
439 let tau_a = bitnet_threshold(&weights_a);
440 let tau_b = bitnet_threshold(&weights_b);
441 let a = TritMatrix::from_f32(n, n, &weights_a, tau_a);
442
443 let b = TritMatrix::from_f32(n, n, &weights_b, tau_b);
444
445 let sparsity = b.sparsity();
446 let dense_ops = n * n * n;
447 let (_, skipped) = sparse_matmul(&a, &b); let sparse_ops = dense_ops - skipped;
449
450 let dense_times: Vec<u64> = (0..reps).map(|_| {
452 let t = Instant::now();
453 let _ = dense_matmul(&a, &b);
454 t.elapsed().as_micros() as u64
455 }).collect();
456
457 let sparse_times: Vec<u64> = (0..reps).map(|_| {
459 let t = Instant::now();
460 let _ = sparse_matmul(&a, &b);
461 t.elapsed().as_micros() as u64
462 }).collect();
463
464 let dense_us = median_us(dense_times);
465 let sparse_us = median_us(sparse_times);
466 let speedup = if sparse_us > 0 {
467 dense_us as f64 / sparse_us as f64
468 } else { dense_ops as f64 / sparse_ops.max(1) as f64 };
469
470 TimedResult {
471 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
472 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
473 speedup, dense_us, sparse_us,
474 }
475 }).collect()
476}
477
478pub fn print_benchmark_table(results: &[TimedResult]) {
480 println!("\n╔══════════════════════════════════════════════════════════════════════╗");
481 println!( "║ Ternlang Sparse Matmul Benchmark — RFI-IRFOS TIS ║");
482 println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
483 println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
484 println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
485 for r in results {
486 println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
487 r.size,
488 r.weight_sparsity * 100.0,
489 r.dense_us,
490 r.sparse_us,
491 r.speedup,
492 r.skip_rate * 100.0,
493 );
494 }
495 println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
496}
497
498pub fn bitnet_matrix(rows: usize, cols: usize, seed: u64, target_sparsity: f64) -> TritMatrix {
504 let mut state = seed;
505 let n = rows * cols;
506 let mut data = Vec::with_capacity(n);
507 for _ in 0..n {
508 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
509 let prob = (state >> 32) as f64 / (u32::MAX as f64 + 1.0);
510 if prob < target_sparsity {
511 data.push(Trit::Tend);
512 } else if (state & 1) == 0 {
513 data.push(Trit::Affirm);
514 } else {
515 data.push(Trit::Reject);
516 }
517 }
518 TritMatrix { rows, cols, data }
519}
520
521pub fn timed_benchmark_bitnet(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
525 timed_benchmark_at_sparsity(0.60, sizes, reps)
526}
527
528pub fn timed_benchmark_at_sparsity(target_sparsity: f64, sizes: &[usize], reps: usize) -> Vec<TimedResult> {
530 use std::time::Instant;
531
532 let bitnet_sparsity: f64 = target_sparsity;
533
534 fn median_us(mut v: Vec<u64>) -> u64 {
535 v.sort_unstable();
536 v[v.len() / 2]
537 }
538
539 sizes.iter().map(|&n| {
540 let a = bitnet_matrix(n, n, 0xdeadbeef, bitnet_sparsity);
541 let b = bitnet_matrix(n, n, 0xc0ffee42, bitnet_sparsity);
542
543 let sparsity = b.sparsity();
544 let dense_ops = n * n * n;
545 let (_, skipped) = sparse_matmul(&a, &b);
546 let sparse_ops = dense_ops - skipped;
547 let speedup_ops = dense_ops as f64 / sparse_ops.max(1) as f64;
548
549 let dense_times: Vec<u64> = (0..reps).map(|_| {
550 let t = Instant::now();
551 let _ = dense_matmul(&a, &b);
552 t.elapsed().as_micros() as u64
553 }).collect();
554
555 let sparse_times: Vec<u64> = (0..reps).map(|_| {
556 let t = Instant::now();
557 let _ = sparse_matmul(&a, &b);
558 t.elapsed().as_micros() as u64
559 }).collect();
560
561 let dense_us = median_us(dense_times);
562 let sparse_us = median_us(sparse_times);
563 let speedup = if sparse_us > 0 {
564 dense_us as f64 / sparse_us as f64
565 } else { speedup_ops };
566
567 TimedResult {
568 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
569 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
570 speedup, dense_us, sparse_us,
571 }
572 }).collect()
573}
574
575pub fn xor_dataset() -> Vec<(TritMatrix, usize)> {
580 let inputs = vec![
581 (vec![Trit::Reject, Trit::Reject], 0usize), (vec![Trit::Reject, Trit::Affirm], 1usize), (vec![Trit::Affirm, Trit::Reject], 1usize), (vec![Trit::Affirm, Trit::Affirm], 0usize), ];
586 inputs.into_iter().map(|(row, label)| {
587 (TritMatrix::from_trits(1, 2, row), label)
588 }).collect()
589}
590
591pub fn parity_dataset() -> Vec<(TritMatrix, usize)> {
593 (0u8..8).map(|i| {
594 let bits = vec![
595 if i & 4 != 0 { Trit::Affirm } else { Trit::Reject },
596 if i & 2 != 0 { Trit::Affirm } else { Trit::Reject },
597 if i & 1 != 0 { Trit::Affirm } else { Trit::Reject },
598 ];
599 let parity = (i.count_ones() % 2) as usize;
600 (TritMatrix::from_trits(1, 3, bits), parity)
601 }).collect()
602}
603
604pub fn evaluate(mlp: &TernaryMLP, dataset: &[(TritMatrix, usize)]) -> (usize, usize, f64) {
607 let total = dataset.len();
608 let correct = dataset.iter()
609 .filter(|(input, label)| mlp.predict(input) == *label)
610 .count();
611 let accuracy = correct as f64 / total as f64;
612 (correct, total, accuracy)
613}
614
/// Boundary between the tend zone and the affirm/reject zones of a
/// `TritScalar`: values with magnitude at or below `1/3` classify as `Tend`.
pub const TEND_BOUNDARY: f32 = 1.0 / 3.0;
630
/// A continuous evidence value that can be classified into a trit.
///
/// `TritScalar::new` clamps the value to `[-1.0, 1.0]`; the field itself is
/// public, so a value built directly is not clamped.
#[derive(Debug, Clone)]
pub struct TritScalar(pub f32);
634
635impl TritScalar {
636 pub fn new(v: f32) -> Self { TritScalar(v.clamp(-1.0, 1.0)) }
638
639 pub fn trit(&self) -> Trit {
641 if self.0 > TEND_BOUNDARY { Trit::Affirm }
642 else if self.0 < -TEND_BOUNDARY { Trit::Reject }
643 else { Trit::Tend }
644 }
645
646 pub fn label(&self) -> &'static str {
648 match self.trit() {
649 Trit::Affirm => "affirm",
650 Trit::Reject => "reject",
651 Trit::Tend => "tend",
652 }
653 }
654
655 pub fn confidence(&self) -> f32 {
660 let v = self.0.abs();
661 if v > TEND_BOUNDARY {
662 (v - TEND_BOUNDARY) / (1.0 - TEND_BOUNDARY)
663 } else {
664 1.0 - v / TEND_BOUNDARY
665 }
666 }
667
668 pub fn is_actionable(&self, min_confidence: f32) -> bool {
671 self.trit() != Trit::Tend && self.confidence() >= min_confidence
672 }
673
674 pub fn raw(&self) -> f32 { self.0 }
676
677 pub fn trit_i8(&self) -> i8 {
679 match self.trit() { Trit::Affirm => 1, Trit::Reject => -1, Trit::Tend => 0 }
680 }
681}
682
/// A set of named evidence dimensions, each with a value and a weight.
///
/// The three vectors are parallel: index `i` of each describes one dimension.
pub struct TritEvidenceVec {
    /// Name of each dimension.
    pub dimensions: Vec<String>,
    /// Evidence value per dimension (clamped to `[-1.0, 1.0]` by `new`).
    pub values: Vec<f32>,
    /// Weight per dimension, used by `aggregate`.
    pub weights: Vec<f32>,
}
701
702impl TritEvidenceVec {
703 pub fn new(dimensions: Vec<String>, values: Vec<f32>, weights: Vec<f32>) -> Self {
704 assert_eq!(dimensions.len(), values.len(), "dimensions and values must match");
705 assert_eq!(dimensions.len(), weights.len(), "dimensions and weights must match");
706 let values = values.iter().map(|&v| v.clamp(-1.0, 1.0)).collect();
707 TritEvidenceVec { dimensions, values, weights }
708 }
709
710 pub fn aggregate(&self) -> TritScalar {
712 let total_weight: f32 = self.weights.iter().sum();
713 if total_weight == 0.0 { return TritScalar::new(0.0); }
714 let weighted_sum: f32 = self.values.iter()
715 .zip(self.weights.iter())
716 .map(|(v, w)| v * w)
717 .sum();
718 TritScalar::new(weighted_sum / total_weight)
719 }
720
721 pub fn scalars(&self) -> Vec<TritScalar> {
723 self.values.iter().map(|&v| TritScalar::new(v)).collect()
724 }
725
726 pub fn dominant(&self) -> Option<(&str, TritScalar)> {
728 self.values.iter()
729 .enumerate()
730 .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap_or(std::cmp::Ordering::Equal))
731 .map(|(i, &v)| (self.dimensions[i].as_str(), TritScalar::new(v)))
732 }
733}
734
// Unit tests and benchmark-style tests for quantization, matrix multiply,
// the ternary MLP, the toy datasets, TritScalar and TritEvidenceVec.
#[cfg(test)]
mod tests {
    use super::*;

    // Values strictly beyond ±threshold become Affirm/Reject; the rest Tend.
    #[test]
    fn test_quantize_basic() {
        let weights = vec![-0.9f32, -0.2, 0.0, 0.3, 0.8];
        let threshold = 0.5;
        let trits = quantize(&weights, threshold);
        assert_eq!(trits, vec![Trit::Reject, Trit::Tend, Trit::Tend, Trit::Tend, Trit::Affirm]);
    }

    // mean(|w|) = 0.75, so the threshold is 0.5 * 0.75 = 0.375.
    #[test]
    fn test_bitnet_threshold() {
        let weights = vec![1.0f32, -1.0, 0.5, -0.5];
        let tau = bitnet_threshold(&weights);
        assert!((tau - 0.375).abs() < 1e-6);
    }

    // Identity times identity is identity.
    #[test]
    fn test_dense_matmul_identity() {
        let mut id = TritMatrix::new(2, 2);
        id.set(0, 0, Trit::Affirm);
        id.set(1, 1, Trit::Affirm);

        let result = dense_matmul(&id, &id);
        assert_eq!(result.get(0, 0), Trit::Affirm);
        assert_eq!(result.get(0, 1), Trit::Tend);
        assert_eq!(result.get(1, 0), Trit::Tend);
        assert_eq!(result.get(1, 1), Trit::Affirm);
    }

    // The input is a signed diagonal, so every output cell has at most one
    // non-zero product; dense and sparse results are compared cell by cell.
    #[test]
    fn test_sparse_matmul_matches_dense() {
        let weights = vec![0.9f32, -0.1, 0.05, -0.8, 0.0, 0.7, -0.6, 0.2, 0.0];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(3, 3, &weights, threshold);
        let mut input = TritMatrix::new(3, 3);
        input.set(0, 0, Trit::Affirm);
        input.set(1, 1, Trit::Reject);
        input.set(2, 2, Trit::Affirm);

        let dense = dense_matmul(&input, &w);
        let (sparse, skipped) = sparse_matmul(&input, &w);

        for r in 0..3 {
            for c in 0..3 {
                assert_eq!(dense.get(r, c), sparse.get(r, c),
                    "mismatch at ({}, {})", r, c);
            }
        }
        assert!(skipped > 0, "expected skips for a sparse weight matrix");
    }

    // One of three weights falls inside the threshold band, so sparsity is 1/3.
    #[test]
    fn test_sparsity_measurement() {
        let weights = vec![0.9f32, 0.1, -0.9];
        let threshold = 0.5;
        let m = TritMatrix::from_f32(1, 3, &weights, threshold);
        assert!((m.sparsity() - 1.0/3.0).abs() < 1e-9);
        assert_eq!(m.nnz(), 2);
    }

    // Majority follows the sign of the vote total; ties give Tend.
    #[test]
    fn test_majority_vote() {
        assert_eq!(majority(&[Trit::Affirm, Trit::Affirm, Trit::Reject]), Trit::Affirm);
        assert_eq!(majority(&[Trit::Reject, Trit::Reject, Trit::Affirm]), Trit::Reject);
        assert_eq!(majority(&[Trit::Affirm, Trit::Reject]), Trit::Tend);
        assert_eq!(majority(&[Trit::Tend, Trit::Tend]), Trit::Tend);
    }

    // Smoke test: a 2 -> 4 -> 2 MLP produces a 1 x 2 output.
    #[test]
    fn test_mlp_forward_runs() {
        let w1_f32: Vec<f32> = vec![
            0.9, -0.8, 0.7, -0.6,
            -0.7, 0.9, -0.5, 0.8,
        ];
        let w2_f32: Vec<f32> = vec![
            0.9, -0.9,
            -0.8, 0.8,
            0.7, -0.7,
            -0.6, 0.6,
        ];
        let mlp = TernaryMLP::from_f32(2, 4, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let (out, s1, s2) = mlp.forward(&input);
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, 2);
        // Skip counts are not asserted; this only silences unused warnings.
        let _ = (s1, s2);
    }

    // predict must return an index below out_features.
    #[test]
    fn test_mlp_predict_returns_valid_class() {
        let w1_f32: Vec<f32> = vec![0.9, -0.8, -0.7, 0.9];
        let w2_f32: Vec<f32> = vec![0.9, -0.9, -0.8, 0.8];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let pred = mlp.predict(&input);
        assert!(pred < 2, "prediction must be a valid class index");
    }

    // XOR dataset: four 1 x 2 samples with binary labels.
    #[test]
    fn test_xor_dataset_shape() {
        let ds = xor_dataset();
        assert_eq!(ds.len(), 4);
        for (input, label) in &ds {
            assert_eq!(input.rows, 1);
            assert_eq!(input.cols, 2);
            assert!(*label < 2);
        }
    }

    // Parity dataset: eight 3-wide samples with binary labels.
    #[test]
    fn test_parity_dataset_shape() {
        let ds = parity_dataset();
        assert_eq!(ds.len(), 8);
        for (input, label) in &ds {
            assert_eq!(input.cols, 3);
            assert!(*label < 2);
        }
    }

    // Hand-picked weights; only asserts at least 2 of 4 XOR samples are correct.
    #[test]
    fn test_xor_mlp_with_known_weights() {
        let w1_f32 = vec![
            1.0, -1.0,
            -1.0, 1.0,
        ];
        let w2_f32 = vec![
            -1.0, 1.0,
            -1.0, 1.0,
        ];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let ds = xor_dataset();
        let (correct, total, acc) = evaluate(&mlp, &ds);
        println!("XOR MLP: {}/{} = {:.0}%", correct, total, acc * 100.0);
        assert!(correct >= 2, "MLP should get at least half of XOR correct");
    }

    // Small timed benchmark: checks counts and that the rates are valid fractions.
    #[test]
    fn test_timed_benchmark_small() {
        let results = timed_benchmark(&[8, 16], 3);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.dense_ops > 0);
            assert!(r.weight_sparsity >= 0.0 && r.weight_sparsity <= 1.0);
            assert!(r.skip_rate >= 0.0 && r.skip_rate <= 1.0);
        }
        print_benchmark_table(&results);
    }

    // An all-Tend input against a partly sparse weight matrix must report skips.
    #[test]
    fn test_benchmark_reports_skips() {
        let weights: Vec<f32> = vec![
            0.9, 0.1, -0.9, 0.0,
            0.1, 0.8, 0.0, -0.7,
            0.0, 0.1, 0.6, 0.2,
            -0.8, 0.0, 0.1, 0.9,
        ];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(4, 4, &weights, threshold);
        let input = TritMatrix::new(4, 4);
        let result = benchmark(&input, &w);
        assert!(result.skipped_ops > 0);
        assert!(result.skip_rate > 0.0 && result.skip_rate <= 1.0);
        result.print_summary();
    }

    // Full-size benchmark run (up to 512 x 512, 5 reps); only the row count is asserted.
    #[test]
    fn test_full_benchmark() {
        let results = timed_benchmark(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        print_benchmark_table(&results);
    }

    // 60% target sparsity: prints a table and expects a skip rate of at least 50%.
    #[test]
    fn test_bitnet_benchmark() {
        let results = timed_benchmark_bitnet(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ BitNet b1.58 Realistic Benchmark — 60% Sparsity — RFI-IRFOS TIS ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.50, "Expected ≥50% skip rate at 60% sparsity, got {:.1}%", r.skip_rate * 100.0);
        }
    }

    // 99% target sparsity: prints a table and expects a skip rate of at least 95%.
    #[test]
    fn test_extreme_sparsity_99() {
        let results = timed_benchmark_at_sparsity(0.99, &[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ EXTREME SPARSITY — 99% Zeros — What Happens? ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>6.1}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.95, "Expected ≥95% skip rate at 99% sparsity");
        }
    }

    // Sweeps nine sparsity levels across five sizes and prints a speedup grid.
    // NOTE(review): the final assertion depends on measured wall-clock time,
    // so it may be sensitive to machine load — confirm it is stable in CI.
    #[test]
    fn test_sparsity_sweep() {
        let sparsities: &[f64] = &[0.25, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.99];
        let sizes: &[usize] = &[32, 64, 128, 256, 512];

        // grid[i][j] is the speedup at sparsities[i] and sizes[j].
        let mut grid: Vec<Vec<f64>> = Vec::new();
        for &sp in sparsities {
            let row: Vec<f64> = timed_benchmark_at_sparsity(sp, sizes, 3)
                .into_iter().map(|r| r.speedup).collect();
            grid.push(row);
        }

        println!();
        println!("╔══════════════ SPARSITY GOLDILOCKS SWEEP ══════════════════════════╗");
        println!("║ Speedup (sparse / dense) across sparsity × matrix size ║");
        println!("╠══════════╦═══════╦═══════╦════════╦════════╦════════╣");
        print!( "║ Sparsity ║");
        for &n in sizes { print!(" {:>4}² ║", n); }
        println!();
        println!("╠══════════╬═══════╬═══════╬════════╬════════╬════════╣");

        // Track the single best cell while printing the grid.
        let mut peak_speedup = 0f64;
        let mut peak_sp = 0f64;
        let mut peak_n = 0usize;

        for (i, &sp) in sparsities.iter().enumerate() {
            print!("║ {:>5.1}% ║", sp * 100.0);
            for (j, &speedup) in grid[i].iter().enumerate() {
                if speedup > peak_speedup {
                    peak_speedup = speedup;
                    peak_sp = sp;
                    peak_n = sizes[j];
                }
                print!(" {:>5.1}× ║", speedup);
            }
            println!();
        }

        println!("╚══════════╩═══════╩═══════╩════════╩════════╩════════╝");
        println!();
        println!(" ★ Peak: {:.1}× at {:.0}% sparsity, {}×{} matrix", peak_speedup, peak_sp * 100.0, peak_n, peak_n);

        // Sparsity level with the best speedup averaged over all sizes.
        let avg_speedups: Vec<(f64, f64)> = sparsities.iter().zip(grid.iter())
            .map(|(&sp, row)| (sp, row.iter().sum::<f64>() / row.len() as f64))
            .collect();
        let (best_sp, best_avg) = avg_speedups.iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .copied().unwrap();
        println!(" ◆ Goldilocks zone: {:.0}% sparsity → {:.1}× average across all sizes", best_sp * 100.0, best_avg);
        println!();

        // The first (smallest) size in each row is excluded from the check.
        for row in &grid {
            for &s in &row[1..] {
                assert!(s >= 1.0, "Speedup dropped below 1× — something is wrong");
            }
        }
    }

    // Zone boundaries: 0.33 is still tend, 0.34 is affirm (boundary is 1/3).
    #[test]
    fn test_trit_scalar_zones() {
        assert_eq!(TritScalar::new(0.9).label(), "affirm");
        assert_eq!(TritScalar::new(-0.9).label(), "reject");
        assert_eq!(TritScalar::new(0.0).label(), "tend");
        assert_eq!(TritScalar::new(0.33).label(), "tend");
        assert_eq!(TritScalar::new(0.34).label(), "affirm");
    }

    // Confidence is 1 at zero and at ±1, and near 0 just past the boundary.
    #[test]
    fn test_trit_scalar_confidence() {
        let s = TritScalar::new(0.0);
        assert_eq!(s.label(), "tend");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(1.0);
        assert_eq!(s.label(), "affirm");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(TEND_BOUNDARY + 0.001);
        assert_eq!(s.label(), "affirm");
        assert!(s.confidence() < 0.01);
    }

    // Actionable requires a non-tend zone and enough confidence.
    #[test]
    fn test_trit_scalar_actionable() {
        assert!(TritScalar::new(0.9).is_actionable(0.5));
        assert!(!TritScalar::new(0.35).is_actionable(0.8));
        assert!(!TritScalar::new(0.0).is_actionable(0.0));
    }

    // Out-of-range inputs are clamped to ±1.
    #[test]
    fn test_trit_scalar_clamp() {
        assert!((TritScalar::new(5.0).raw() - 1.0).abs() < 0.001);
        assert!((TritScalar::new(-5.0).raw() + 1.0).abs() < 0.001);
    }

    // Equal weights: the mean of 0.8, 0.9, 0.7 is 0.8, a confident affirm.
    #[test]
    fn test_evidence_vec_aggregate_uniform() {
        let ev = TritEvidenceVec::new(
            vec!["a".into(), "b".into(), "c".into()],
            vec![0.8, 0.9, 0.7],
            vec![1.0, 1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "affirm");
        assert!(agg.confidence() > 0.5);
    }

    // Mean of -0.9 and 0.1 is -0.4, which falls in the reject zone.
    #[test]
    fn test_evidence_vec_mixed_signals() {
        let ev = TritEvidenceVec::new(
            vec!["strong_reject".into(), "weak_affirm".into()],
            vec![-0.9, 0.1],
            vec![1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "reject");
    }

    // (-0.4 * 10 + 0.9 * 1) / 11 ≈ -0.28, which stays inside the tend zone.
    #[test]
    fn test_evidence_vec_weighted_override() {
        let ev = TritEvidenceVec::new(
            vec!["weak_reject".into(), "strong_affirm".into()],
            vec![-0.4, 0.9],
            vec![10.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "tend");
    }

    // dominant picks the largest magnitude regardless of sign.
    #[test]
    fn test_evidence_vec_dominant() {
        let ev = TritEvidenceVec::new(
            vec!["low".into(), "high".into(), "mid".into()],
            vec![0.2, -0.95, 0.5],
            vec![1.0, 1.0, 1.0],
        );
        let (label, scalar) = ev.dominant().unwrap();
        assert_eq!(label, "high");
        assert_eq!(scalar.label(), "reject");
    }
}
1145
/// The recorded state of one round of a `DeliberationEngine` run.
#[derive(Debug, Clone)]
pub struct DeliberationRound {
    /// Zero-based index of the round.
    pub round: usize,
    /// Evidence values supplied for this round (empty if none were given).
    pub new_evidence: Vec<f32>,
    /// Value of the engine's running exponential moving average after this round.
    pub cumulative_mean: f32,
    /// The running average wrapped as a `TritScalar`.
    pub scalar: TritScalar,
    /// Whether the scalar's confidence reached the engine's target in this round.
    pub converged: bool,
}
1171
/// The outcome of a `DeliberationEngine` run, taken from its last round.
#[derive(Debug, Clone)]
pub struct DeliberationResult {
    /// Final classification as `1`, `0` or `-1`.
    pub final_trit: i8,
    /// Final classification label (`"affirm"`, `"tend"` or `"reject"`).
    pub final_label: String,
    /// Confidence of the final scalar.
    pub final_confidence: f32,
    /// Whether the last round reached the target confidence.
    pub converged: bool,
    /// Number of rounds reported as used by the run.
    pub rounds_used: usize,
    /// Per-round records, in execution order.
    pub trace: Vec<DeliberationRound>,
    /// Human-readable explanation of why the run stopped.
    pub convergence_reason: String,
}
1183
/// Configuration for iterative evidence deliberation.
pub struct DeliberationEngine {
    /// Confidence at or above which a round counts as converged.
    pub target_confidence: f32,
    /// Upper bound on the number of rounds a run executes.
    pub max_rounds: usize,
    /// Weight given to the newest round's mean in the exponential moving
    /// average (`new` sets it to `0.4`; `with_alpha` clamps it to `[0.01, 1.0]`).
    pub alpha: f32,
}
1200
1201impl DeliberationEngine {
1202 pub fn new(target_confidence: f32, max_rounds: usize) -> Self {
1203 Self { target_confidence, max_rounds, alpha: 0.4 }
1204 }
1205
1206 pub fn with_alpha(mut self, alpha: f32) -> Self { self.alpha = alpha.clamp(0.01, 1.0); self }
1207
1208 pub fn run(&self, rounds_evidence: Vec<Vec<f32>>) -> DeliberationResult {
1211 let mut ema: f32 = 0.0; let mut initialized = false;
1213 let mut trace = Vec::new();
1214
1215 let rounds_to_run = self.max_rounds.min(
1216 if rounds_evidence.is_empty() { self.max_rounds } else { rounds_evidence.len() }
1217 );
1218
1219 for round in 0..rounds_to_run {
1220 let new_ev: Vec<f32> = rounds_evidence.get(round).cloned().unwrap_or_default();
1221
1222 if !new_ev.is_empty() {
1224 let round_mean = new_ev.iter().sum::<f32>() / new_ev.len() as f32;
1225 ema = if !initialized {
1226 initialized = true;
1227 round_mean
1228 } else {
1229 self.alpha * round_mean + (1.0 - self.alpha) * ema
1230 };
1231 }
1232
1233 let scalar = TritScalar::new(ema);
1234 let converged = scalar.confidence() >= self.target_confidence;
1235
1236 trace.push(DeliberationRound {
1237 round,
1238 new_evidence: new_ev,
1239 cumulative_mean: ema,
1240 scalar: scalar.clone(),
1241 converged,
1242 });
1243
1244 if converged { break; }
1245 }
1246
1247 let last = trace.last().cloned().unwrap_or_else(|| DeliberationRound {
1248 round: 0, new_evidence: vec![], cumulative_mean: 0.0,
1249 scalar: TritScalar::new(0.0), converged: false,
1250 });
1251
1252 let convergence_reason = if last.converged {
1253 format!("confidence {:.1}% ≥ target {:.1}% after {} round(s)",
1254 last.scalar.confidence() * 100.0,
1255 self.target_confidence * 100.0,
1256 last.round + 1)
1257 } else {
1258 format!("max rounds ({}) reached — confidence {:.1}% below target {:.1}%",
1259 self.max_rounds,
1260 last.scalar.confidence() * 100.0,
1261 self.target_confidence * 100.0)
1262 };
1263
1264 DeliberationResult {
1265 final_trit: last.scalar.trit_i8(),
1266 final_label: last.scalar.label().to_string(),
1267 final_confidence: last.scalar.confidence(),
1268 converged: last.converged,
1269 rounds_used: last.round + 1,
1270 trace,
1271 convergence_reason,
1272 }
1273 }
1274}
1275
/// One voter in a coalition vote.
#[derive(Debug, Clone)]
pub struct CoalitionMember {
    /// Name of the member, echoed in the vote breakdown.
    pub label: String,
    /// The member's vote (`new` clamps it to `-1..=1`).
    pub trit: i8,
    /// Confidence in the vote (`new` clamps it to `[0.0, 1.0]`).
    pub confidence: f32,
    /// Voting weight (`new` raises negative values to `0.0`).
    pub weight: f32,
}
1286
1287impl CoalitionMember {
1288 pub fn new(label: impl Into<String>, trit: i8, confidence: f32, weight: f32) -> Self {
1289 Self {
1290 label: label.into(),
1291 trit: trit.clamp(-1, 1),
1292 confidence: confidence.clamp(0.0, 1.0),
1293 weight: weight.max(0.0),
1294 }
1295 }
1296}
1297
/// The outcome of `coalition_vote`.
#[derive(Debug, Clone)]
pub struct CoalitionResult {
    /// Resulting vote as `1`, `0` or `-1`.
    pub trit: i8,
    /// Resulting vote label (`"affirm"`, `"tend"` or `"reject"`).
    pub label: String,
    /// Sum of `trit * confidence * weight` over members, divided by total weight.
    pub aggregate_score: f32,
    /// Share of total weight held by members with a non-zero vote.
    pub quorum: f32,
    /// Share of total weight held by non-zero voters whose sign differs from the result.
    pub dissent_rate: f32,
    /// `1.0 - quorum`.
    pub abstain_rate: f32,
    /// Number of members that took part.
    pub member_count: usize,
    /// Total weight of members with a non-zero vote.
    pub effective_weight: f32,
    /// Per member: `(label, trit, contribution / total weight)`.
    pub breakdown: Vec<(String, i8, f32)>,
}
1311
1312pub fn coalition_vote(members: &[CoalitionMember]) -> CoalitionResult {
1317 if members.is_empty() {
1318 return CoalitionResult {
1319 trit: 0, label: "tend".into(), aggregate_score: 0.0,
1320 quorum: 0.0, dissent_rate: 0.0, abstain_rate: 1.0,
1321 member_count: 0, effective_weight: 0.0, breakdown: vec![],
1322 };
1323 }
1324
1325 let total_weight: f32 = members.iter().map(|m| m.weight).sum();
1326 let total_weight = if total_weight == 0.0 { 1.0 } else { total_weight };
1327
1328 let mut weighted_sum: f32 = 0.0;
1329 let mut non_zero_weight: f32 = 0.0;
1330 let mut breakdown = Vec::new();
1331
1332 for m in members {
1333 let contribution = (m.trit as f32) * m.confidence * m.weight;
1334 weighted_sum += contribution;
1335 if m.trit != 0 { non_zero_weight += m.weight; }
1336 breakdown.push((m.label.clone(), m.trit, contribution / total_weight));
1337 }
1338
1339 let aggregate_score = weighted_sum / total_weight;
1340 let scalar = TritScalar::new(aggregate_score);
1341 let result_trit: i8 = scalar.trit_i8();
1342
1343 let quorum = non_zero_weight / total_weight;
1344 let abstain_rate = 1.0 - quorum;
1345 let dissent_rate = members.iter()
1346 .filter(|m| m.trit != 0 && m.trit.signum() != result_trit.signum())
1347 .map(|m| m.weight)
1348 .sum::<f32>() / total_weight;
1349
1350 CoalitionResult {
1351 trit: result_trit,
1352 label: scalar.label().to_string(),
1353 aggregate_score,
1354 quorum,
1355 dissent_rate,
1356 abstain_rate,
1357 member_count: members.len(),
1358 effective_weight: non_zero_weight,
1359 breakdown,
1360 }
1361}
1362
/// One evaluation axis fed into [`action_gate`].
#[derive(Debug, Clone)]
pub struct GateDimension {
    /// Name of the dimension; reported back in the gate result.
    pub name: String,
    /// Evidence value for this dimension; passed to `TritScalar::new` by the gate.
    pub evidence: f32,
    /// Weight of this dimension in the gate's weighted average.
    pub weight: f32,
    /// When `true`, negative evidence on this dimension blocks the action outright.
    pub hard_block: bool,
}

impl GateDimension {
    /// Creates a soft (non-blocking) dimension.
    pub fn new(name: impl Into<String>, evidence: f32, weight: f32) -> Self {
        let name = name.into();
        Self {
            name,
            evidence,
            weight,
            hard_block: false,
        }
    }

    /// Consumes the dimension and returns it marked as a hard constraint.
    pub fn hard(self) -> Self {
        Self {
            hard_block: true,
            ..self
        }
    }
}
1382
/// Final decision produced by [`action_gate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GateVerdict {
    /// The action may go ahead.
    Proceed,
    /// Evidence is inconclusive; the action should wait.
    Hold,
    /// The action must not go ahead.
    Block,
}

impl GateVerdict {
    /// Lower-case text form of the verdict: `"proceed"`, `"hold"` or `"block"`.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Block => "block",
            Self::Hold => "hold",
            Self::Proceed => "proceed",
        }
    }
}
1403
1404#[derive(Debug, Clone)]
1406pub struct GateResult {
1407 pub verdict: GateVerdict,
1408 pub aggregate: TritScalar,
1409 pub hard_blocked_by: Vec<String>, pub dim_results: Vec<(String, TritScalar, bool)>, pub explanation: String,
1412}
1413
1414pub fn action_gate(dimensions: &[GateDimension]) -> GateResult {
1421 let mut hard_blocked_by = Vec::new();
1422 let mut dim_results = Vec::new();
1423 let mut weighted_sum = 0.0f32;
1424 let mut total_weight = 0.0f32;
1425
1426 for dim in dimensions {
1427 let scalar = TritScalar::new(dim.evidence);
1428 let is_neg = matches!(scalar.trit(), Trit::Reject);
1429
1430 if dim.hard_block && is_neg {
1431 hard_blocked_by.push(dim.name.clone());
1432 }
1433
1434 weighted_sum += dim.evidence * dim.weight;
1435 total_weight += dim.weight;
1436 dim_results.push((dim.name.clone(), scalar, dim.hard_block));
1437 }
1438
1439 if !hard_blocked_by.is_empty() {
1441 let explanation = format!(
1442 "BLOCKED — hard constraint(s) violated: {}",
1443 hard_blocked_by.join(", ")
1444 );
1445 return GateResult {
1446 verdict: GateVerdict::Block,
1447 aggregate: TritScalar::new(-1.0),
1448 hard_blocked_by,
1449 dim_results,
1450 explanation,
1451 };
1452 }
1453
1454 let agg_score = if total_weight > 0.0 { weighted_sum / total_weight } else { 0.0 };
1455 let aggregate = TritScalar::new(agg_score);
1456
1457 let verdict = match aggregate.trit() {
1458 Trit::Affirm => GateVerdict::Proceed,
1459 Trit::Tend => GateVerdict::Hold,
1460 Trit::Reject => GateVerdict::Block,
1461 };
1462
1463 let explanation = match &verdict {
1464 GateVerdict::Proceed => format!(
1465 "PROCEED — all dimensions pass (aggregate confidence {:.0}%)",
1466 aggregate.confidence() * 100.0
1467 ),
1468 GateVerdict::Hold => format!(
1469 "HOLD — insufficient evidence (aggregate {:.3} within deliberation zone)",
1470 aggregate.raw()
1471 ),
1472 GateVerdict::Block => format!(
1473 "BLOCK — weighted aggregate {:.3} below threshold (confidence {:.0}%)",
1474 aggregate.raw(), aggregate.confidence() * 100.0
1475 ),
1476 };
1477
1478 GateResult { verdict, aggregate, hard_blocked_by, dim_results, explanation }
1479}
1480
/// Sampling-temperature recommendation produced by [`scalar_temperature`].
#[derive(Debug, Clone)]
pub struct ScalarTemperature {
    /// Trit of the scalar the recommendation was derived from (`-1`, `0` or `+1`).
    pub trit: i8,
    /// Confidence of that scalar.
    pub confidence: f32,
    /// Recommended temperature, rounded to three decimals.
    pub temperature: f32,
    /// Short explanation of why this temperature was chosen.
    pub reasoning: String,
    /// Instruction text matching the recommended behaviour.
    pub prompt_hint: String,
}
1504
1505pub fn scalar_temperature(scalar: &TritScalar) -> ScalarTemperature {
1506 let t = scalar.trit();
1507 let c = scalar.confidence(); let (temp, reasoning, prompt_hint) = match t {
1510 Trit::Affirm => {
1511 let temp = 0.3 - (c * 0.25); (
1514 temp.max(0.05),
1515 format!("Affirm (confidence {:.0}%) — execute precisely, minimal exploration", c * 100.0),
1516 "Be concise and direct. Evidence is clear. Do not hedge.".to_string(),
1517 )
1518 }
1519 Trit::Reject => {
1520 let temp = 0.15 - (c * 0.10); (
1523 temp.max(0.05),
1524 format!("Reject (confidence {:.0}%) — decline firmly, minimal hedging", c * 100.0),
1525 "Decline clearly. Do not offer alternatives unless explicitly asked. Evidence is against.".to_string(),
1526 )
1527 }
1528 Trit::Tend => {
1529 let temp = 0.7 + ((1.0 - c) * 0.3); (
1532 temp.min(1.0),
1533 format!("Tend (confidence {:.0}%) — evidence is conflicted, explore broadly", c * 100.0),
1534 "You are in deliberation. Present multiple perspectives. Ask clarifying questions. Do not commit.".to_string(),
1535 )
1536 }
1537 };
1538
1539 ScalarTemperature {
1540 trit: scalar.trit_i8(),
1541 confidence: c,
1542 temperature: (temp * 1000.0).round() / 1000.0,
1543 reasoning,
1544 prompt_hint,
1545 }
1546}
1547
/// Consistency assessment produced by [`hallucination_score`].
#[derive(Debug, Clone)]
pub struct HallucinationScore {
    /// Trust verdict as `-1`, `0` or `+1` (`0` when no signals were given).
    pub trust_trit: i8,
    /// Text label of the trust verdict.
    pub trust_label: String,
    /// Arithmetic mean of the signals.
    pub mean: f32,
    /// Population variance of the signals (divided by `n`, not `n - 1`).
    pub variance: f32,
    /// `1.0 - min(variance, 1.0)`.
    pub consistency: f32,
    /// Number of signals that were analysed.
    pub signal_count: usize,
    /// Human-readable summary of the assessment.
    pub explanation: String,
}
1569
1570pub fn hallucination_score(signals: &[f32]) -> HallucinationScore {
1571 if signals.is_empty() {
1572 return HallucinationScore {
1573 trust_trit: 0, trust_label: "tend".into(), mean: 0.0,
1574 variance: 0.0, consistency: 0.0, signal_count: 0,
1575 explanation: "No signals provided — cannot assess consistency.".into(),
1576 };
1577 }
1578
1579 let n = signals.len() as f32;
1580 let mean = signals.iter().sum::<f32>() / n;
1581 let variance = signals.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / n;
1582
1583 let norm_variance = variance.min(1.0);
1585 let consistency = 1.0 - norm_variance;
1586
1587 let trust_evidence = (consistency * 2.0 - 1.0) * mean.abs(); let trust = TritScalar::new(trust_evidence);
1592
1593 let explanation = if trust.trit() == Trit::Affirm {
1594 format!(
1595 "Consistent signals (variance {:.3}, consistency {:.0}%) — evidence coheres around {:.3}",
1596 variance, consistency * 100.0, mean
1597 )
1598 } else if trust.trit() == Trit::Reject {
1599 format!(
1600 "HIGH VARIANCE (variance {:.3}) — signals are internally contradictory. Possible hallucination or conflated sources.",
1601 variance
1602 )
1603 } else {
1604 format!(
1605 "Mixed consistency (variance {:.3}, mean {:.3}) — gather more evidence before relying on this claim.",
1606 variance, mean
1607 )
1608 };
1609
1610 HallucinationScore {
1611 trust_trit: trust.trit_i8(),
1612 trust_label: trust.label().to_string(),
1613 mean,
1614 variance,
1615 consistency,
1616 signal_count: signals.len(),
1617 explanation,
1618 }
1619}
1620
/// Unit tests for the reasoning helpers: deliberation, coalition voting, the action
/// gate, temperature mapping and hallucination scoring.
#[cfg(test)]
mod reasoning_tests {
    use super::*;

    // --- DeliberationEngine ---------------------------------------------------

    #[test]
    fn test_deliberation_converges_on_strong_evidence() {
        let evidence = vec![
            vec![0.85, 0.9],
            vec![0.9, 0.95],
            vec![0.92, 0.95, 0.98],
        ];
        let outcome = DeliberationEngine::new(0.7, 10).with_alpha(0.7).run(evidence);

        assert!(outcome.converged, "should converge on strong positive evidence (got confidence {:.2})", outcome.final_confidence);
        assert_eq!(outcome.final_trit, 1, "should be +1 (affirm)");
        assert!(outcome.rounds_used <= 3);
    }

    #[test]
    fn test_deliberation_holds_on_weak_evidence() {
        let evidence = vec![vec![0.1f32], vec![-0.05], vec![0.15]];
        let outcome = DeliberationEngine::new(0.95, 3).run(evidence);

        assert!(!outcome.converged, "should not converge on weak conflicting evidence");
        assert_eq!(outcome.final_trit, 0, "should stay at hold/tend");
        assert_eq!(outcome.rounds_used, 3);
    }

    #[test]
    fn test_deliberation_negative_convergence() {
        let evidence = vec![vec![-0.9f32, -0.85], vec![-0.95, -0.99]];
        let outcome = DeliberationEngine::new(0.8, 10).run(evidence);

        assert!(outcome.converged);
        assert_eq!(outcome.final_trit, -1);
    }

    // --- coalition_vote -------------------------------------------------------

    #[test]
    fn test_coalition_unanimous_affirm() {
        let voters = [
            CoalitionMember::new("safety", 1, 0.9, 3.0),
            CoalitionMember::new("utility", 1, 0.8, 1.0),
            CoalitionMember::new("alignment", 1, 0.95, 2.0),
        ];
        let tally = coalition_vote(&voters);

        assert_eq!(tally.trit, 1);
        assert_eq!(tally.label, "affirm");
        assert!(tally.quorum > 0.99, "all voted");
        assert!(tally.dissent_rate < 0.01);
    }

    #[test]
    fn test_coalition_split_vote_tends_to_hold() {
        let voters = [
            CoalitionMember::new("agent_a", 1, 0.8, 1.0),
            CoalitionMember::new("agent_b", -1, 0.8, 1.0),
            CoalitionMember::new("agent_c", 0, 0.5, 1.0),
        ];
        let tally = coalition_vote(&voters);

        assert_eq!(tally.trit, 0);
        assert!(tally.dissent_rate > 0.0, "there is dissent");
    }

    #[test]
    fn test_coalition_high_weight_overrides() {
        let voters = [
            CoalitionMember::new("expert", 1, 0.95, 10.0),
            CoalitionMember::new("novice_a", -1, 0.5, 1.0),
            CoalitionMember::new("novice_b", -1, 0.5, 1.0),
        ];
        let tally = coalition_vote(&voters);

        assert_eq!(tally.trit, 1, "high-weight expert should dominate");
    }

    // --- action_gate ----------------------------------------------------------

    #[test]
    fn test_gate_all_positive_proceeds() {
        let axes = [
            GateDimension::new("safety", 0.8, 3.0),
            GateDimension::new("utility", 0.7, 1.0),
            GateDimension::new("legality", 0.9, 2.0),
        ];
        let decision = action_gate(&axes);

        assert_eq!(decision.verdict, GateVerdict::Proceed);
    }

    #[test]
    fn test_gate_hard_block_fires() {
        let axes = [
            GateDimension::new("utility", 0.9, 1.0),
            GateDimension::new("safety", -0.8, 3.0).hard(),
            GateDimension::new("legality", 0.7, 1.0),
        ];
        let decision = action_gate(&axes);

        assert_eq!(decision.verdict, GateVerdict::Block);
        assert!(decision.hard_blocked_by.contains(&"safety".to_string()));
    }

    #[test]
    fn test_gate_mixed_soft_dims_holds() {
        let axes = [
            GateDimension::new("utility", 0.8, 1.0),
            GateDimension::new("risk", -0.7, 1.0),
        ];
        let decision = action_gate(&axes);

        assert_ne!(decision.verdict, GateVerdict::Block);
    }

    // --- scalar_temperature ---------------------------------------------------

    #[test]
    fn test_temperature_affirm_is_low() {
        let advice = scalar_temperature(&TritScalar::new(0.9));

        assert_eq!(advice.trit, 1);
        assert!(advice.temperature < 0.3, "affirm → low temperature");
    }

    #[test]
    fn test_temperature_tend_is_high() {
        let advice = scalar_temperature(&TritScalar::new(0.05));

        assert_eq!(advice.trit, 0);
        assert!(advice.temperature >= 0.7, "tend → high temperature for exploration");
    }

    #[test]
    fn test_temperature_reject_is_low() {
        let advice = scalar_temperature(&TritScalar::new(-0.9));

        assert_eq!(advice.trit, -1);
        assert!(advice.temperature < 0.15, "reject → low temperature, firm");
    }

    // --- hallucination_score --------------------------------------------------

    #[test]
    fn test_hallucination_consistent_signals_trusted() {
        let readings: [f32; 5] = [0.8, 0.82, 0.79, 0.81, 0.83];
        let report = hallucination_score(&readings);

        assert_eq!(report.trust_trit, 1, "consistent signals should be trusted");
        assert!(report.variance < 0.01);
        assert!(report.consistency > 0.99);
    }

    #[test]
    fn test_hallucination_chaotic_signals_flagged() {
        let readings: [f32; 6] = [0.9, -0.9, 0.8, -0.8, 0.95, -0.7];
        let report = hallucination_score(&readings);

        assert!(report.variance > 0.5, "should have high variance");
        assert!(report.trust_trit <= 0, "chaotic signals should not be trusted");
    }

    #[test]
    fn test_hallucination_empty_returns_hold() {
        let report = hallucination_score(&[]);

        assert_eq!(report.trust_trit, 0);
        assert_eq!(report.signal_count, 0);
    }
}
1803
1804use std::collections::HashMap;
1819use crate::coherence::ModelCoherence;
1820
/// Hyper-parameters describing the shape of a [`TritTransformer`].
///
/// Derives `Debug` and `Clone` like the other public types in this file, so a
/// configuration can be logged and reused after being handed to
/// [`TritTransformer::from_coherence`], which takes it by value.
#[derive(Debug, Clone)]
pub struct TritTransformerConfig {
    /// Width of the hidden state; also the number of trits read per token embedding
    /// and the length of every normalisation gain vector.
    pub dim: usize,
    /// Number of transformer blocks that are built from the weight map.
    pub n_layers: usize,
    /// Number of attention heads; `dim / n_heads` is the head size used for RoPE.
    pub n_heads: usize,
    /// Number of key/value heads. Not read by the code visible in this section.
    pub n_kv_heads: usize,
    /// Vocabulary size. Not read by the code visible in this section.
    pub vocab_size: usize,
    /// Rounding multiple for the feed-forward width — presumably; it is not read by
    /// the code visible in this section.
    pub multiple_of: usize,
    /// Optional feed-forward width multiplier. Not read by the code visible in this
    /// section.
    pub ffn_dim_multiplier: Option<f64>,
    /// Epsilon added under the square root in RMS normalisation.
    pub norm_eps: f32,
    /// Number of positions for which rotary (cos, sin) pairs are precomputed.
    pub max_seq_len: usize,
}

impl Default for TritTransformerConfig {
    /// Returns the built-in default shape: 2048-wide, 16 layers, 32 heads, 8 KV heads,
    /// a 128256-entry vocabulary and a 2048-token context.
    fn default() -> Self {
        Self {
            dim: 2048,
            n_layers: 16,
            n_heads: 32,
            n_kv_heads: 8,
            vocab_size: 128256,
            multiple_of: 256,
            ffn_dim_multiplier: None,
            norm_eps: 1e-5,
            max_seq_len: 2048,
        }
    }
}
1848
1849pub struct TritBlock {
1851 pub wq: TritMatrix,
1852 pub wk: TritMatrix,
1853 pub wv: TritMatrix,
1854 pub wo: TritMatrix,
1855 pub w1: TritMatrix,
1856 pub w2: TritMatrix,
1857 pub w3: TritMatrix,
1858 pub attention_norm: Vec<f32>, pub ffn_norm: Vec<f32>,
1860}
1861
1862pub struct TritTransformer {
1864 pub config: TritTransformerConfig,
1865 pub tok_embeddings: TritMatrix,
1866 pub layers: Vec<TritBlock>,
1867 pub norm: Vec<f32>,
1868 pub output: TritMatrix,
1869 pub freq_cis: Vec<(f32, f32)>, }
1871
1872impl TritTransformer {
1873 pub fn from_coherence(coherence: ModelCoherence, config: TritTransformerConfig) -> Self {
1875 println!("ternlang-ml: Building TritTransformer (Layers: {})...", config.n_layers);
1876
1877 let mut layers = Vec::with_capacity(config.n_layers);
1878 let mut layer_map: HashMap<String, TritMatrix> = HashMap::new();
1879
1880 for layer in coherence.layers {
1881 layer_map.insert(layer.name.clone(), layer.to_trit_matrix());
1882 }
1883
1884 let mut get = |name: &str| {
1886 layer_map.remove(name).unwrap_or_else(|| panic!("Missing layer: {}", name))
1887 };
1888
1889 let tok_embeddings = get("token_embd.weight");
1890 let output = get("output.weight");
1891
1892 let norm = vec![1.0; config.dim];
1897
1898 for i in 0..config.n_layers {
1899 layers.push(TritBlock {
1900 wq: get(&format!("layers.{}.attention.wq.weight", i)),
1901 wk: get(&format!("layers.{}.attention.wk.weight", i)),
1902 wv: get(&format!("layers.{}.attention.wv.weight", i)),
1903 wo: get(&format!("layers.{}.attention.wo.weight", i)),
1904 w1: get(&format!("layers.{}.feed_forward.w1.weight", i)),
1905 w2: get(&format!("layers.{}.feed_forward.w2.weight", i)),
1906 w3: get(&format!("layers.{}.feed_forward.w3.weight", i)),
1907 attention_norm: vec![1.0; config.dim],
1908 ffn_norm: vec![1.0; config.dim],
1909 });
1910 }
1911
1912 let freq_cis = precompute_freqs_cis(config.dim / config.n_heads, config.max_seq_len);
1914
1915 Self {
1916 config,
1917 tok_embeddings,
1918 layers,
1919 norm,
1920 output,
1921 freq_cis,
1922 }
1923 }
1924
1925 pub fn forward(&self, token: usize, pos: usize) -> Vec<f32> {
1928 let mut h = self.get_embedding(token);
1929
1930 for layer in &self.layers {
1931 let h_norm = rms_norm(&h, &layer.attention_norm, self.config.norm_eps);
1933 let attn_out = self.attention(layer, &h_norm, pos);
1934 for i in 0..h.len() { h[i] += attn_out[i]; }
1935
1936 let h_norm = rms_norm(&h, &layer.ffn_norm, self.config.norm_eps);
1938 let ffn_out = self.feed_forward(layer, &h_norm);
1939 for i in 0..h.len() { h[i] += ffn_out[i]; }
1940 }
1941
1942 let h = rms_norm(&h, &self.norm, self.config.norm_eps);
1943 self.project_output(&h)
1944 }
1945
1946 fn get_embedding(&self, token: usize) -> Vec<f32> {
1947 let start = token * self.config.dim;
1948 let mut embd = Vec::with_capacity(self.config.dim);
1949 for i in 0..self.config.dim {
1950 embd.push(trit_to_f32(self.tok_embeddings.data[start + i]));
1951 }
1952 embd
1953 }
1954
1955 fn attention(&self, layer: &TritBlock, x: &[f32], pos: usize) -> Vec<f32> {
1956 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
1959
1960 let (q_trit, _) = sparse_matmul(&x_trit, &layer.wq);
1961 let (k_trit, _) = sparse_matmul(&x_trit, &layer.wk);
1962 let (v_trit, _) = sparse_matmul(&x_trit, &layer.wv);
1963
1964 let mut q = q_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
1965 let mut k = k_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
1966 let v = v_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
1967
1968 apply_rope(&mut q, pos, &self.freq_cis, self.config.n_heads);
1970 apply_rope(&mut k, pos, &self.freq_cis, self.config.n_heads);
1971
1972 let v_trit = TritMatrix::from_trits(1, v.len(), v.iter().map(|&val| trit_from_f32_approx(val)).collect());
1977 let (out, _) = sparse_matmul(&v_trit, &layer.wo);
1978 out.data.iter().map(|&t| trit_to_f32(t)).collect()
1979 }
1980
1981 fn feed_forward(&self, layer: &TritBlock, x: &[f32]) -> Vec<f32> {
1982 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
1983
1984 let (w1_x, _) = sparse_matmul(&x_trit, &layer.w1);
1986 let (w3_x, _) = sparse_matmul(&x_trit, &layer.w3);
1987
1988 let mut hidden = Vec::with_capacity(w1_x.data.len());
1989 for i in 0..w1_x.data.len() {
1990 let v1 = trit_to_f32(w1_x.data[i]);
1991 let v3 = trit_to_f32(w3_x.data[i]);
1992 let silu_v3 = v3 / (1.0 + (-v3).exp());
1994 hidden.push(v1 * silu_v3);
1995 }
1996
1997 let hidden_trit = TritMatrix::from_trits(1, hidden.len(), hidden.iter().map(|&v| trit_from_f32_approx(v)).collect());
1998 let (out, _) = sparse_matmul(&hidden_trit, &layer.w2);
1999 out.data.iter().map(|&t| trit_to_f32(t)).collect()
2000 }
2001
2002 fn project_output(&self, x: &[f32]) -> Vec<f32> {
2003 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2004 let (logits, _) = sparse_matmul(&x_trit, &self.output);
2005 logits.data.iter().map(|&t| trit_to_f32(t)).collect()
2006 }
2007}
2008
/// Root-mean-square normalisation with per-element gains.
///
/// Every element of `x` is divided by `sqrt(mean(x²) + eps)` and multiplied by the
/// matching entry of `weight`. The output has as many elements as the shorter of the
/// two inputs.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_square = x.iter().map(|&v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_square + eps).sqrt();

    let mut normed = Vec::with_capacity(x.len().min(weight.len()));
    for (&value, &gain) in x.iter().zip(weight) {
        normed.push(value * scale * gain);
    }
    normed
}
2016
/// Precomputes the rotary-embedding table for `end` positions.
///
/// The result is laid out position-major: entry `pos * (dim / 2) + i` holds
/// `(cos(pos * f_i), sin(pos * f_i))` with `f_i = 1 / 10000^(2i / dim)`.
/// Returns an empty vector when `dim < 2` or `end == 0`.
fn precompute_freqs_cis(dim: usize, end: usize) -> Vec<(f32, f32)> {
    let half = dim / 2;

    // The frequencies depend only on `i`, so compute them once instead of calling
    // `powf` again for every position.
    let inv_freqs: Vec<f32> = (0..half)
        .map(|i| 1.0 / 10000.0f32.powf((i * 2) as f32 / dim as f32))
        .collect();

    let mut freqs_cis = Vec::with_capacity(end * half);
    for pos in 0..end {
        let position = pos as f32;
        freqs_cis.extend(inv_freqs.iter().map(|&freq| {
            let angle = position * freq;
            (angle.cos(), angle.sin())
        }));
    }
    freqs_cis
}
2028
/// Applies rotary position embedding to `x` in place.
///
/// `x` is treated as `n_heads` consecutive heads of `x.len() / n_heads` elements.
/// Within each head, element `i` of the first half is rotated together with element
/// `i` of the second half, using the `(cos, sin)` pair at
/// `freq_cis[pos * (head_dim / 2) + i]`. Elements beyond the last full head are left
/// untouched.
///
/// # Panics
///
/// Panics when `n_heads` is zero, or when `freq_cis` has no entries for `pos`.
fn apply_rope(x: &mut [f32], pos: usize, freq_cis: &[(f32, f32)], n_heads: usize) {
    let head_dim = x.len() / n_heads;
    let half = head_dim / 2;
    if half == 0 {
        // Heads with fewer than two elements have nothing to rotate.
        return;
    }

    let first = pos * half;
    let rotations = &freq_cis[first..first + half];

    for head in x.chunks_exact_mut(head_dim).take(n_heads) {
        let (front, back) = head.split_at_mut(half);
        for ((lo, hi), &(cos, sin)) in front.iter_mut().zip(back.iter_mut()).zip(rotations) {
            let (x0, x1) = (*lo, *hi);
            *lo = x0 * cos - x1 * sin;
            *hi = x0 * sin + x1 * cos;
        }
    }
}
2042
2043pub fn trit_to_f32(t: Trit) -> f32 {
2044 match t {
2045 Trit::Affirm => 1.0,
2046 Trit::Reject => -1.0,
2047 Trit::Tend => 0.0,
2048 }
2049}
2050
2051pub fn trit_from_f32_approx(v: f32) -> Trit {
2052 if v > 0.5 { Trit::Affirm }
2053 else if v < -0.5 { Trit::Reject }
2054 else { Trit::Tend }
2055}