1use ternlang_core::trit::Trit;
19use serde::{Serialize, Deserialize};
20
21pub mod spectra_compat {
24 use super::*;
25
26 pub fn import_spectra_weights(raw_data: &[f32], rows: usize, cols: usize) -> TritMatrix {
29 println!("ternlang-ml: Annexing Spectra-1.1 weights (Scale: 1.2T tokens)...");
30 TritMatrix::from_f32(rows, cols, raw_data, 0.5)
32 }
33}
34
35pub mod coherence;
36pub mod qat;
37pub mod perplexity;
38pub mod tritfloat;
39pub mod tritfloat_tensor;
40pub use tritfloat::TritFloat;
41pub use tritfloat_tensor::TritFloatTensor;
42
43pub fn quantize(weights: &[f32], threshold: f32) -> Vec<Trit> {
54 weights.iter().map(|&w| {
55 if w > threshold {
56 Trit::Affirm
57 } else if w < -threshold {
58 Trit::Reject
59 } else {
60 Trit::Tend
61 }
62 }).collect()
63}
64
/// Per-tensor quantisation threshold: half the mean absolute weight.
///
/// Returns `0.0` for an empty slice. Previously this divided 0 by 0 and
/// returned NaN, which made every later `>`/`<` comparison silently false.
pub fn bitnet_threshold(weights: &[f32]) -> f32 {
    if weights.is_empty() {
        return 0.0;
    }
    let mean_abs = weights.iter().map(|w| w.abs()).sum::<f32>() / weights.len() as f32;
    0.5 * mean_abs
}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TritMatrix {
76 pub rows: usize,
77 pub cols: usize,
78 pub data: Vec<Trit>,
79}
80
81impl TritMatrix {
82 pub fn new(rows: usize, cols: usize) -> Self {
83 Self { rows, cols, data: vec![Trit::Tend; rows * cols] }
84 }
85
86 pub fn from_trits(rows: usize, cols: usize, data: Vec<Trit>) -> Self {
87 assert_eq!(data.len(), rows * cols);
88 Self { rows, cols, data }
89 }
90
91 pub fn from_f32(rows: usize, cols: usize, weights: &[f32], threshold: f32) -> Self {
92 Self::from_trits(rows, cols, quantize(weights, threshold))
93 }
94
95 #[inline]
96 pub fn get(&self, row: usize, col: usize) -> Trit {
97 self.data[row * self.cols + col]
98 }
99
100 #[inline]
101 pub fn set(&mut self, row: usize, col: usize, val: Trit) {
102 self.data[row * self.cols + col] = val;
103 }
104
105 pub fn sparsity(&self) -> f64 {
107 let zeros = self.data.iter().filter(|&&t| t == Trit::Tend).count();
108 zeros as f64 / self.data.len() as f64
109 }
110
111 pub fn nnz(&self) -> usize {
113 self.data.iter().filter(|&&t| t != Trit::Tend).count()
114 }
115
116 pub fn to_i8_vec(&self) -> Vec<i8> {
118 self.data.iter().map(|&t| match t {
119 Trit::Affirm => 1,
120 Trit::Reject => -1,
121 Trit::Tend => 0,
122 }).collect()
123 }
124}
125
126pub fn dense_matmul(a: &TritMatrix, b: &TritMatrix) -> TritMatrix {
132 assert_eq!(a.cols, b.rows, "matmul dimension mismatch: a.cols must equal b.rows");
133 let mut c = TritMatrix::new(a.rows, b.cols);
134 for row in 0..a.rows {
135 for col in 0..b.cols {
136 let mut acc = Trit::Tend;
137 for k in 0..a.cols {
138 let prod = a.get(row, k) * b.get(k, col);
139 let (sum, _carry) = acc + prod;
140 acc = sum;
141 }
142 c.set(row, col, acc);
143 }
144 }
145 c
146}
147
148pub fn sparse_matmul(a: &TritMatrix, b: &TritMatrix) -> (TritMatrix, usize) {
168 use rayon::prelude::*;
169
170 assert_eq!(a.cols, b.rows, "matmul dimension mismatch");
171
172 #[inline(always)]
173 fn t2i(t: Trit) -> i8 {
174 match t { Trit::Reject => -1, Trit::Tend => 0, Trit::Affirm => 1 }
175 }
176
177 let a_flat: Vec<i8> = a.data.iter().map(|&t| t2i(t)).collect();
179 let a_cols = a.cols;
180
181 let mut csc_offsets = vec![0usize; b.cols + 1];
186 for k in 0..b.rows {
188 for j in 0..b.cols {
189 if t2i(b.data[k * b.cols + j]) != 0 {
190 csc_offsets[j + 1] += 1;
191 }
192 }
193 }
194 for j in 0..b.cols {
196 csc_offsets[j + 1] += csc_offsets[j];
197 }
198 let nnz = csc_offsets[b.cols];
199 let mut csc_idx = vec![0u32; nnz];
200 let mut csc_val = vec![0i8; nnz];
201 let mut col_cursor = csc_offsets[..b.cols].to_vec(); for k in 0..b.rows {
203 for j in 0..b.cols {
204 let w = t2i(b.data[k * b.cols + j]);
205 if w != 0 {
206 let pos = col_cursor[j];
207 csc_idx[pos] = k as u32;
208 csc_val[pos] = w;
209 col_cursor[j] += 1;
210 }
211 }
212 }
213
214 let dense_ops = a.rows * b.cols * a.cols;
215 let active_ops = nnz * a.rows;
216 let skipped = dense_ops.saturating_sub(active_ops);
217
218 let mut out_flat = vec![0i8; a.rows * b.cols];
221
222 out_flat
223 .par_chunks_mut(b.cols)
224 .enumerate()
225 .for_each(|(row, row_out)| {
226 let a_row = &a_flat[row * a_cols..(row + 1) * a_cols];
227 for col in 0..b.cols {
228 let start = csc_offsets[col];
229 let end = csc_offsets[col + 1];
230 let mut acc: i32 = 0;
231 for i in start..end {
234 let k = unsafe { *csc_idx.get_unchecked(i) } as usize;
235 let w = unsafe { *csc_val.get_unchecked(i) } as i32;
236 let av = unsafe { *a_row.get_unchecked(k) } as i32;
237 acc += av * w;
238 }
239 row_out[col] = if acc > 0 { 1 } else if acc < 0 { -1 } else { 0 };
240 }
241 });
242
243 let c_data: Vec<Trit> = out_flat.into_iter().map(|v| Trit::from(v)).collect();
245 let c = TritMatrix { rows: a.rows, cols: b.cols, data: c_data };
246
247 (c, skipped)
248}
249
250pub fn linear_confident(
258 activations: &TritFloatTensor,
259 weights: &TritMatrix,
260) -> (TritFloatTensor, usize) {
261 TritFloatTensor::matmul_trit(activations, weights)
262}
263
264pub fn linear(input: &TritMatrix, weights: &TritMatrix) -> (TritMatrix, usize) {
272 sparse_matmul(input, weights)
273}
274
/// Operation counts from a single sparse-vs-dense matmul comparison.
///
/// Now derives `Debug`, `Clone` and `PartialEq` so results can be logged,
/// stored and compared like other public result types in this crate.
#[derive(Debug, Clone, PartialEq)]
pub struct BenchmarkResult {
    /// Multiply-accumulates a dense product would perform.
    pub dense_ops: usize,
    /// Multiply-accumulates actually performed by the sparse kernel.
    pub sparse_ops: usize,
    /// `dense_ops - sparse_ops`.
    pub skipped_ops: usize,
    /// `skipped_ops / dense_ops`.
    pub skip_rate: f64,
    /// Fraction of zero entries in the weight matrix.
    pub weight_sparsity: f64,
}

impl BenchmarkResult {
    /// Prints a human-readable summary to stdout.
    pub fn print_summary(&self) {
        println!("=== Ternary Sparse Matmul Benchmark ===");
        println!(" Weight sparsity: {:.1}% zeros", self.weight_sparsity * 100.0);
        println!(" Dense ops: {}", self.dense_ops);
        println!(" Sparse ops: {}", self.sparse_ops);
        println!(" Skipped ops: {}", self.skipped_ops);
        println!(" Skip rate: {:.1}%", self.skip_rate * 100.0);
        // `.max(1)` avoids dividing by zero when every op was skipped.
        println!(" Ops saved: {:.1}x fewer multiplies", self.dense_ops as f64 / self.sparse_ops.max(1) as f64);
    }
}
297
298pub fn benchmark(a: &TritMatrix, b: &TritMatrix) -> BenchmarkResult {
299 let dense_ops = a.rows * a.cols * b.cols;
300 let (_result, skipped) = sparse_matmul(a, b);
301 let sparse_ops = dense_ops - skipped;
302 BenchmarkResult {
303 dense_ops,
304 sparse_ops,
305 skipped_ops: skipped,
306 skip_rate: skipped as f64 / dense_ops as f64,
307 weight_sparsity: b.sparsity(),
308 }
309}
310
311pub fn trit_activation(t: Trit) -> Trit { t }
317
318pub fn majority(trits: &[Trit]) -> Trit {
321 let sum: i32 = trits.iter().map(|&t| match t {
322 Trit::Affirm => 1,
323 Trit::Reject => -1,
324 Trit::Tend => 0,
325 }).sum();
326 match sum.signum() {
327 1 => Trit::Affirm,
328 -1 => Trit::Reject,
329 _ => Trit::Tend,
330 }
331}
332
333pub struct TernaryMLP {
343 pub w1: TritMatrix, pub w2: TritMatrix, pub in_features: usize,
346 pub hidden_size: usize,
347 pub out_features: usize,
348}
349
350impl TernaryMLP {
351 pub fn new(w1: TritMatrix, w2: TritMatrix) -> Self {
353 let in_features = w1.rows;
354 let hidden_size = w1.cols;
355 let out_features = w2.cols;
356 assert_eq!(w2.rows, hidden_size, "w1.cols must equal w2.rows");
357 Self { w1, w2, in_features, hidden_size, out_features }
358 }
359
360 pub fn from_f32(
362 in_features: usize, hidden_size: usize, out_features: usize,
363 w1_f32: &[f32], w2_f32: &[f32],
364 ) -> Self {
365 let tau1 = bitnet_threshold(w1_f32);
366 let tau2 = bitnet_threshold(w2_f32);
367 let w1 = TritMatrix::from_f32(in_features, hidden_size, w1_f32, tau1);
368 let w2 = TritMatrix::from_f32(hidden_size, out_features, w2_f32, tau2);
369 Self::new(w1, w2)
370 }
371
372 pub fn forward(&self, input: &TritMatrix) -> (TritMatrix, usize, usize) {
376 assert_eq!(input.cols, self.in_features,
377 "input width must match in_features");
378
379 let (hidden, skip1) = sparse_matmul(input, &self.w1);
381
382 let hidden_act = TritMatrix::from_trits(
384 hidden.rows, hidden.cols,
385 hidden.data.iter().map(|&t| trit_activation(t)).collect(),
386 );
387
388 let (output, skip2) = sparse_matmul(&hidden_act, &self.w2);
390
391 (output, skip1, skip2)
392 }
393
394 pub fn predict(&self, input: &TritMatrix) -> usize {
397 let (output, _, _) = self.forward(input);
398 let row = 0;
399 let mut best_col = 0;
400 let mut best_val: i8 = -2;
401 for col in 0..self.out_features {
402 let v = match output.get(row, col) {
403 Trit::Affirm => 1,
404 Trit::Tend => 0,
405 Trit::Reject => -1,
406 };
407 if v > best_val { best_val = v; best_col = col; }
408 }
409 best_col
410 }
411
412 pub fn layer1_sparsity(&self) -> f64 { self.w1.sparsity() }
413 pub fn layer2_sparsity(&self) -> f64 { self.w2.sparsity() }
414
415 pub fn forward_logits(&self, input: &[f32]) -> Vec<f32> {
422 assert_eq!(input.len(), self.in_features);
423 let (inf, hs, outf) = (self.in_features, self.hidden_size, self.out_features);
424
425 let w1_f: Vec<f32> = self.w1.to_i8_vec().iter().map(|&v| v as f32).collect();
427 let w2_f: Vec<f32> = self.w2.to_i8_vec().iter().map(|&v| v as f32).collect();
428
429 let mut hidden = vec![0.0f32; hs];
431 for j in 0..hs {
432 for i in 0..inf {
433 hidden[j] += input[i] * w1_f[i * hs + j];
434 }
435 }
436
437 let hidden_act: Vec<f32> = hidden.iter().map(|&h| {
439 if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
440 }).collect();
441
442 let mut output = vec![0.0f32; outf];
444 for j in 0..outf {
445 for i in 0..hs {
446 output[j] += hidden_act[i] * w2_f[i * outf + j];
447 }
448 }
449 output
450 }
451}
452
/// Wall-clock and operation-count comparison of dense vs sparse matmul at one size.
#[derive(Debug)]
pub struct TimedResult {
    /// Square matrix dimension `n` (matrices are `n × n`).
    pub size: usize,
    /// Multiply-accumulates a dense product performs (`n³`).
    pub dense_ops: usize,
    /// Multiply-accumulates the sparse kernel performs.
    pub sparse_ops: usize,
    /// `dense_ops - sparse_ops`.
    pub skipped_ops: usize,
    /// Fraction of zero entries in the right-hand matrix.
    pub weight_sparsity: f64,
    /// `skipped_ops / dense_ops`.
    pub skip_rate: f64,
    /// `dense_us / sparse_us`, or the op-count ratio if the sparse time rounded to 0.
    pub speedup: f64,
    /// Median dense time in microseconds.
    pub dense_us: u64,
    /// Median sparse time in microseconds.
    pub sparse_us: u64,
}
468
469pub fn timed_benchmark(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
474 use std::time::Instant;
475
476 fn lcg_weights(n: usize, seed: u64) -> Vec<f32> {
478 let mut state = seed;
479 (0..n).map(|_| {
480 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
481 let f = ((state >> 33) as f32) / (u32::MAX as f32) * 3.0 - 1.5;
484 f
485 }).collect()
486 }
487
488 fn median_us(mut times: Vec<u64>) -> u64 {
489 times.sort_unstable();
490 times[times.len() / 2]
491 }
492
493 sizes.iter().map(|&n| {
494 let weights_a = lcg_weights(n * n, 0xdeadbeef);
495 let weights_b = lcg_weights(n * n, 0xc0ffee42);
496 let tau_a = bitnet_threshold(&weights_a);
497 let tau_b = bitnet_threshold(&weights_b);
498 let a = TritMatrix::from_f32(n, n, &weights_a, tau_a);
499
500 let b = TritMatrix::from_f32(n, n, &weights_b, tau_b);
501
502 let sparsity = b.sparsity();
503 let dense_ops = n * n * n;
504 let (_, skipped) = sparse_matmul(&a, &b); let sparse_ops = dense_ops - skipped;
506
507 let dense_times: Vec<u64> = (0..reps).map(|_| {
509 let t = Instant::now();
510 let _ = dense_matmul(&a, &b);
511 t.elapsed().as_micros() as u64
512 }).collect();
513
514 let sparse_times: Vec<u64> = (0..reps).map(|_| {
516 let t = Instant::now();
517 let _ = sparse_matmul(&a, &b);
518 t.elapsed().as_micros() as u64
519 }).collect();
520
521 let dense_us = median_us(dense_times);
522 let sparse_us = median_us(sparse_times);
523 let speedup = if sparse_us > 0 {
524 dense_us as f64 / sparse_us as f64
525 } else { dense_ops as f64 / sparse_ops.max(1) as f64 };
526
527 TimedResult {
528 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
529 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
530 speedup, dense_us, sparse_us,
531 }
532 }).collect()
533}
534
535pub fn print_benchmark_table(results: &[TimedResult]) {
537 println!("\n╔══════════════════════════════════════════════════════════════════════╗");
538 println!( "║ Ternlang Sparse Matmul Benchmark — RFI-IRFOS TIS ║");
539 println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
540 println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
541 println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
542 for r in results {
543 println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
544 r.size,
545 r.weight_sparsity * 100.0,
546 r.dense_us,
547 r.sparse_us,
548 r.speedup,
549 r.skip_rate * 100.0,
550 );
551 }
552 println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
553}
554
555pub fn bitnet_matrix(rows: usize, cols: usize, seed: u64, target_sparsity: f64) -> TritMatrix {
561 let mut state = seed;
562 let n = rows * cols;
563 let mut data = Vec::with_capacity(n);
564 for _ in 0..n {
565 state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
566 let prob = (state >> 32) as f64 / (u32::MAX as f64 + 1.0);
567 if prob < target_sparsity {
568 data.push(Trit::Tend);
569 } else if (state & 1) == 0 {
570 data.push(Trit::Affirm);
571 } else {
572 data.push(Trit::Reject);
573 }
574 }
575 TritMatrix { rows, cols, data }
576}
577
578pub fn timed_benchmark_bitnet(sizes: &[usize], reps: usize) -> Vec<TimedResult> {
582 timed_benchmark_at_sparsity(0.60, sizes, reps)
583}
584
585pub fn timed_benchmark_at_sparsity(target_sparsity: f64, sizes: &[usize], reps: usize) -> Vec<TimedResult> {
587 use std::time::Instant;
588
589 let bitnet_sparsity: f64 = target_sparsity;
590
591 fn median_us(mut v: Vec<u64>) -> u64 {
592 v.sort_unstable();
593 v[v.len() / 2]
594 }
595
596 sizes.iter().map(|&n| {
597 let a = bitnet_matrix(n, n, 0xdeadbeef, bitnet_sparsity);
598 let b = bitnet_matrix(n, n, 0xc0ffee42, bitnet_sparsity);
599
600 let sparsity = b.sparsity();
601 let dense_ops = n * n * n;
602 let (_, skipped) = sparse_matmul(&a, &b);
603 let sparse_ops = dense_ops - skipped;
604 let speedup_ops = dense_ops as f64 / sparse_ops.max(1) as f64;
605
606 let dense_times: Vec<u64> = (0..reps).map(|_| {
607 let t = Instant::now();
608 let _ = dense_matmul(&a, &b);
609 t.elapsed().as_micros() as u64
610 }).collect();
611
612 let sparse_times: Vec<u64> = (0..reps).map(|_| {
613 let t = Instant::now();
614 let _ = sparse_matmul(&a, &b);
615 t.elapsed().as_micros() as u64
616 }).collect();
617
618 let dense_us = median_us(dense_times);
619 let sparse_us = median_us(sparse_times);
620 let speedup = if sparse_us > 0 {
621 dense_us as f64 / sparse_us as f64
622 } else { speedup_ops };
623
624 TimedResult {
625 size: n, dense_ops, sparse_ops, skipped_ops: skipped,
626 weight_sparsity: sparsity, skip_rate: skipped as f64 / dense_ops as f64,
627 speedup, dense_us, sparse_us,
628 }
629 }).collect()
630}
631
632pub fn xor_dataset() -> Vec<(TritMatrix, usize)> {
637 let inputs = vec![
638 (vec![Trit::Reject, Trit::Reject], 0usize), (vec![Trit::Reject, Trit::Affirm], 1usize), (vec![Trit::Affirm, Trit::Reject], 1usize), (vec![Trit::Affirm, Trit::Affirm], 0usize), ];
643 inputs.into_iter().map(|(row, label)| {
644 (TritMatrix::from_trits(1, 2, row), label)
645 }).collect()
646}
647
648pub fn parity_dataset() -> Vec<(TritMatrix, usize)> {
650 (0u8..8).map(|i| {
651 let bits = vec![
652 if i & 4 != 0 { Trit::Affirm } else { Trit::Reject },
653 if i & 2 != 0 { Trit::Affirm } else { Trit::Reject },
654 if i & 1 != 0 { Trit::Affirm } else { Trit::Reject },
655 ];
656 let parity = (i.count_ones() % 2) as usize;
657 (TritMatrix::from_trits(1, 3, bits), parity)
658 }).collect()
659}
660
661pub fn evaluate(mlp: &TernaryMLP, dataset: &[(TritMatrix, usize)]) -> (usize, usize, f64) {
664 let total = dataset.len();
665 let correct = dataset.iter()
666 .filter(|(input, label)| mlp.predict(input) == *label)
667 .count();
668 let accuracy = correct as f64 / total as f64;
669 (correct, total, accuracy)
670}
671
/// Magnitude boundary of the `Tend` zone for [`TritScalar`].
///
/// Values strictly above `+TEND_BOUNDARY` are `Affirm` and values strictly
/// below `-TEND_BOUNDARY` are `Reject`. This splits `[-1, 1]` into three
/// equal-width zones.
pub const TEND_BOUNDARY: f32 = 1.0_f32 / 3.0_f32;
687
688#[derive(Debug, Clone)]
690pub struct TritScalar(pub f32);
691
692impl TritScalar {
693 pub fn new(v: f32) -> Self { TritScalar(v.clamp(-1.0, 1.0)) }
695
696 pub fn trit(&self) -> Trit {
698 if self.0 > TEND_BOUNDARY { Trit::Affirm }
699 else if self.0 < -TEND_BOUNDARY { Trit::Reject }
700 else { Trit::Tend }
701 }
702
703 pub fn label(&self) -> &'static str {
705 match self.trit() {
706 Trit::Affirm => "affirm",
707 Trit::Reject => "reject",
708 Trit::Tend => "tend",
709 }
710 }
711
712 pub fn confidence(&self) -> f32 {
717 let v = self.0.abs();
718 if v > TEND_BOUNDARY {
719 (v - TEND_BOUNDARY) / (1.0 - TEND_BOUNDARY)
720 } else {
721 1.0 - v / TEND_BOUNDARY
722 }
723 }
724
725 pub fn is_actionable(&self, min_confidence: f32) -> bool {
728 self.trit() != Trit::Tend && self.confidence() >= min_confidence
729 }
730
731 pub fn raw(&self) -> f32 { self.0 }
733
734 pub fn trit_i8(&self) -> i8 {
736 match self.trit() { Trit::Affirm => 1, Trit::Reject => -1, Trit::Tend => 0 }
737 }
738}
739
740pub struct TritEvidenceVec {
754 pub dimensions: Vec<String>,
755 pub values: Vec<f32>, pub weights: Vec<f32>, }
758
759impl TritEvidenceVec {
760 pub fn new(dimensions: Vec<String>, values: Vec<f32>, weights: Vec<f32>) -> Self {
761 assert_eq!(dimensions.len(), values.len(), "dimensions and values must match");
762 assert_eq!(dimensions.len(), weights.len(), "dimensions and weights must match");
763 let values = values.iter().map(|&v| v.clamp(-1.0, 1.0)).collect();
764 TritEvidenceVec { dimensions, values, weights }
765 }
766
767 pub fn aggregate(&self) -> TritScalar {
769 let total_weight: f32 = self.weights.iter().sum();
770 if total_weight == 0.0 { return TritScalar::new(0.0); }
771 let weighted_sum: f32 = self.values.iter()
772 .zip(self.weights.iter())
773 .map(|(v, w)| v * w)
774 .sum();
775 TritScalar::new(weighted_sum / total_weight)
776 }
777
778 pub fn scalars(&self) -> Vec<TritScalar> {
780 self.values.iter().map(|&v| TritScalar::new(v)).collect()
781 }
782
783 pub fn dominant(&self) -> Option<(&str, TritScalar)> {
785 self.values.iter()
786 .enumerate()
787 .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap_or(std::cmp::Ordering::Equal))
788 .map(|(i, &v)| (self.dimensions[i].as_str(), TritScalar::new(v)))
789 }
790}
791
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_basic() {
        let weights = vec![-0.9f32, -0.2, 0.0, 0.3, 0.8];
        let threshold = 0.5;
        let trits = quantize(&weights, threshold);
        assert_eq!(trits, vec![Trit::Reject, Trit::Tend, Trit::Tend, Trit::Tend, Trit::Affirm]);
    }

    #[test]
    fn test_bitnet_threshold() {
        let weights = vec![1.0f32, -1.0, 0.5, -0.5];
        let tau = bitnet_threshold(&weights);
        // mean |w| = 0.75, threshold = 0.5 * 0.75
        assert!((tau - 0.375).abs() < 1e-6);
    }

    #[test]
    fn test_dense_matmul_identity() {
        let mut id = TritMatrix::new(2, 2);
        id.set(0, 0, Trit::Affirm);
        id.set(1, 1, Trit::Affirm);

        let result = dense_matmul(&id, &id);
        assert_eq!(result.get(0, 0), Trit::Affirm);
        assert_eq!(result.get(0, 1), Trit::Tend);
        assert_eq!(result.get(1, 0), Trit::Tend);
        assert_eq!(result.get(1, 1), Trit::Affirm);
    }

    #[test]
    fn test_sparse_matmul_matches_dense() {
        let weights = vec![0.9f32, -0.1, 0.05, -0.8, 0.0, 0.7, -0.6, 0.2, 0.0];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(3, 3, &weights, threshold);
        let mut input = TritMatrix::new(3, 3);
        input.set(0, 0, Trit::Affirm);
        input.set(1, 1, Trit::Reject);
        input.set(2, 2, Trit::Affirm);

        let dense = dense_matmul(&input, &w);
        let (sparse, skipped) = sparse_matmul(&input, &w);

        for r in 0..3 {
            for c in 0..3 {
                assert_eq!(dense.get(r, c), sparse.get(r, c),
                    "mismatch at ({}, {})", r, c);
            }
        }
        assert!(skipped > 0, "expected skips for a sparse weight matrix");
    }

    #[test]
    fn test_sparsity_measurement() {
        let weights = vec![0.9f32, 0.1, -0.9];
        let threshold = 0.5;
        let m = TritMatrix::from_f32(1, 3, &weights, threshold);
        assert!((m.sparsity() - 1.0 / 3.0).abs() < 1e-9);
        assert_eq!(m.nnz(), 2);
    }

    #[test]
    fn test_majority_vote() {
        assert_eq!(majority(&[Trit::Affirm, Trit::Affirm, Trit::Reject]), Trit::Affirm);
        assert_eq!(majority(&[Trit::Reject, Trit::Reject, Trit::Affirm]), Trit::Reject);
        assert_eq!(majority(&[Trit::Affirm, Trit::Reject]), Trit::Tend);
        assert_eq!(majority(&[Trit::Tend, Trit::Tend]), Trit::Tend);
    }

    #[test]
    fn test_mlp_forward_runs() {
        let w1_f32: Vec<f32> = vec![
            0.9, -0.8, 0.7, -0.6,
            -0.7, 0.9, -0.5, 0.8,
        ];
        let w2_f32: Vec<f32> = vec![
            0.9, -0.9,
            -0.8, 0.8,
            0.7, -0.7,
            -0.6, 0.6,
        ];
        let mlp = TernaryMLP::from_f32(2, 4, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let (out, s1, s2) = mlp.forward(&input);
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, 2);
        let _ = (s1, s2);
    }

    #[test]
    fn test_mlp_predict_returns_valid_class() {
        let w1_f32: Vec<f32> = vec![0.9, -0.8, -0.7, 0.9];
        let w2_f32: Vec<f32> = vec![0.9, -0.9, -0.8, 0.8];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let input = TritMatrix::from_trits(1, 2, vec![Trit::Affirm, Trit::Reject]);
        let pred = mlp.predict(&input);
        assert!(pred < 2, "prediction must be a valid class index");
    }

    #[test]
    fn test_xor_dataset_shape() {
        let ds = xor_dataset();
        assert_eq!(ds.len(), 4);
        for (input, label) in &ds {
            assert_eq!(input.rows, 1);
            assert_eq!(input.cols, 2);
            assert!(*label < 2);
        }
    }

    #[test]
    fn test_parity_dataset_shape() {
        let ds = parity_dataset();
        assert_eq!(ds.len(), 8);
        for (input, label) in &ds {
            assert_eq!(input.cols, 3);
            assert!(*label < 2);
        }
    }

    #[test]
    fn test_xor_mlp_with_known_weights() {
        let w1_f32 = vec![
            1.0, -1.0,
            -1.0, 1.0,
        ];
        let w2_f32 = vec![
            -1.0, 1.0,
            -1.0, 1.0,
        ];
        let mlp = TernaryMLP::from_f32(2, 2, 2, &w1_f32, &w2_f32);
        let ds = xor_dataset();
        let (correct, total, acc) = evaluate(&mlp, &ds);
        println!("XOR MLP: {}/{} = {:.0}%", correct, total, acc * 100.0);
        assert!(correct >= 2, "MLP should get at least half of XOR correct");
    }

    #[test]
    fn test_timed_benchmark_small() {
        let results = timed_benchmark(&[8, 16], 3);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert!(r.dense_ops > 0);
            assert!(r.weight_sparsity >= 0.0 && r.weight_sparsity <= 1.0);
            assert!(r.skip_rate >= 0.0 && r.skip_rate <= 1.0);
        }
        print_benchmark_table(&results);
    }

    #[test]
    fn test_benchmark_reports_skips() {
        let weights: Vec<f32> = vec![
            0.9, 0.1, -0.9, 0.0,
            0.1, 0.8, 0.0, -0.7,
            0.0, 0.1, 0.6, 0.2,
            -0.8, 0.0, 0.1, 0.9,
        ];
        let threshold = 0.5;
        let w = TritMatrix::from_f32(4, 4, &weights, threshold);
        let input = TritMatrix::new(4, 4);
        let result = benchmark(&input, &w);
        assert!(result.skipped_ops > 0);
        assert!(result.skip_rate > 0.0 && result.skip_rate <= 1.0);
        result.print_summary();
    }

    // The benchmarks below time up to 512×512 dense products several times over.
    // They are far too slow for the default `cargo test` run, and the sweep
    // asserts wall-clock ratios, which is flaky under load. Run them with
    // `cargo test --release -- --ignored`.

    #[test]
    #[ignore = "slow wall-clock benchmark; run with `cargo test --release -- --ignored`"]
    fn test_full_benchmark() {
        let results = timed_benchmark(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        print_benchmark_table(&results);
    }

    #[test]
    #[ignore = "slow wall-clock benchmark; run with `cargo test --release -- --ignored`"]
    fn test_bitnet_benchmark() {
        let results = timed_benchmark_bitnet(&[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ BitNet b1.58 Realistic Benchmark — 60% Sparsity — RFI-IRFOS TIS ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>5.2}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.50, "Expected ≥50% skip rate at 60% sparsity, got {:.1}%", r.skip_rate * 100.0);
        }
    }

    #[test]
    #[ignore = "slow wall-clock benchmark; run with `cargo test --release -- --ignored`"]
    fn test_extreme_sparsity_99() {
        let results = timed_benchmark_at_sparsity(0.99, &[32, 64, 128, 256, 512], 5);
        assert_eq!(results.len(), 5);
        println!("\n╔══════════════════════════════════════════════════════════════════════╗");
        println!( "║ EXTREME SPARSITY — 99% Zeros — What Happens? ║");
        println!( "╠════════╦══════════╦═══════════╦══════════╦══════════╦═════════════╣");
        println!( "║ Size ║ Sparsity ║ Dense μs ║ Sparse μs║ Speedup ║ Skip rate ║");
        println!( "╠════════╬══════════╬═══════════╬══════════╬══════════╬═════════════╣");
        for r in &results {
            println!("║ {:>4}² ║ {:>5.1}% ║ {:>7} ║ {:>7} ║ {:>6.1}× ║ {:>6.1}% ║",
                r.size,
                r.weight_sparsity * 100.0,
                r.dense_us,
                r.sparse_us,
                r.speedup,
                r.skip_rate * 100.0,
            );
        }
        println!( "╚════════╩══════════╩═══════════╩══════════╩══════════╩═════════════╝");
        for r in &results {
            assert!(r.skip_rate >= 0.95, "Expected ≥95% skip rate at 99% sparsity");
        }
    }

    #[test]
    #[ignore = "slow, timing-sensitive sweep; run with `cargo test --release -- --ignored`"]
    fn test_sparsity_sweep() {
        let sparsities: &[f64] = &[0.25, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.99];
        let sizes: &[usize] = &[32, 64, 128, 256, 512];

        let mut grid: Vec<Vec<f64>> = Vec::new();
        for &sp in sparsities {
            let row: Vec<f64> = timed_benchmark_at_sparsity(sp, sizes, 3)
                .into_iter().map(|r| r.speedup).collect();
            grid.push(row);
        }

        println!();
        println!("╔══════════════ SPARSITY GOLDILOCKS SWEEP ══════════════════════════╗");
        println!("║ Speedup (sparse / dense) across sparsity × matrix size ║");
        println!("╠══════════╦═══════╦═══════╦════════╦════════╦════════╣");
        print!( "║ Sparsity ║");
        for &n in sizes { print!(" {:>4}² ║", n); }
        println!();
        println!("╠══════════╬═══════╬═══════╬════════╬════════╬════════╣");

        let mut peak_speedup = 0f64;
        let mut peak_sp = 0f64;
        let mut peak_n = 0usize;

        for (i, &sp) in sparsities.iter().enumerate() {
            print!("║ {:>5.1}% ║", sp * 100.0);
            for (j, &speedup) in grid[i].iter().enumerate() {
                if speedup > peak_speedup {
                    peak_speedup = speedup;
                    peak_sp = sp;
                    peak_n = sizes[j];
                }
                print!(" {:>5.1}× ║", speedup);
            }
            println!();
        }

        println!("╚══════════╩═══════╩═══════╩════════╩════════╩════════╝");
        println!();
        println!(" ★ Peak: {:.1}× at {:.0}% sparsity, {}×{} matrix", peak_speedup, peak_sp * 100.0, peak_n, peak_n);

        let avg_speedups: Vec<(f64, f64)> = sparsities.iter().zip(grid.iter())
            .map(|(&sp, row)| (sp, row.iter().sum::<f64>() / row.len() as f64))
            .collect();
        let (best_sp, best_avg) = avg_speedups.iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .copied().unwrap();
        println!(" ◆ Goldilocks zone: {:.0}% sparsity → {:.1}× average across all sizes", best_sp * 100.0, best_avg);
        println!();

        // The smallest size is excluded: its timings are dominated by fixed overhead.
        for row in &grid {
            for &s in &row[1..] {
                assert!(s >= 1.0, "Speedup dropped below 1× — something is wrong");
            }
        }
    }

    #[test]
    fn test_trit_scalar_zones() {
        assert_eq!(TritScalar::new(0.9).label(), "affirm");
        assert_eq!(TritScalar::new(-0.9).label(), "reject");
        assert_eq!(TritScalar::new(0.0).label(), "tend");
        // 0.33 is just inside the tend zone (< 1/3); 0.34 is just outside it.
        assert_eq!(TritScalar::new(0.33).label(), "tend");
        assert_eq!(TritScalar::new(0.34).label(), "affirm");
    }

    #[test]
    fn test_trit_scalar_confidence() {
        let s = TritScalar::new(0.0);
        assert_eq!(s.label(), "tend");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(1.0);
        assert_eq!(s.label(), "affirm");
        assert!((s.confidence() - 1.0).abs() < 0.01);

        let s = TritScalar::new(TEND_BOUNDARY + 0.001);
        assert_eq!(s.label(), "affirm");
        assert!(s.confidence() < 0.01);
    }

    #[test]
    fn test_trit_scalar_actionable() {
        assert!(TritScalar::new(0.9).is_actionable(0.5));
        assert!(!TritScalar::new(0.35).is_actionable(0.8));
        assert!(!TritScalar::new(0.0).is_actionable(0.0));
    }

    #[test]
    fn test_trit_scalar_clamp() {
        assert!((TritScalar::new(5.0).raw() - 1.0).abs() < 0.001);
        assert!((TritScalar::new(-5.0).raw() + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_evidence_vec_aggregate_uniform() {
        let ev = TritEvidenceVec::new(
            vec!["a".into(), "b".into(), "c".into()],
            vec![0.8, 0.9, 0.7],
            vec![1.0, 1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "affirm");
        assert!(agg.confidence() > 0.5);
    }

    #[test]
    fn test_evidence_vec_mixed_signals() {
        let ev = TritEvidenceVec::new(
            vec!["strong_reject".into(), "weak_affirm".into()],
            vec![-0.9, 0.1],
            vec![1.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "reject");
    }

    #[test]
    fn test_evidence_vec_weighted_override() {
        // (-0.4 * 10 + 0.9 * 1) / 11 ≈ -0.28, which lies inside the tend zone.
        let ev = TritEvidenceVec::new(
            vec!["weak_reject".into(), "strong_affirm".into()],
            vec![-0.4, 0.9],
            vec![10.0, 1.0],
        );
        let agg = ev.aggregate();
        assert_eq!(agg.label(), "tend");
    }

    #[test]
    fn test_evidence_vec_dominant() {
        let ev = TritEvidenceVec::new(
            vec!["low".into(), "high".into(), "mid".into()],
            vec![0.2, -0.95, 0.5],
            vec![1.0, 1.0, 1.0],
        );
        let (label, scalar) = ev.dominant().unwrap();
        assert_eq!(label, "high");
        assert_eq!(scalar.label(), "reject");
    }
}
1202
1203#[derive(Debug, Clone)]
1221pub struct DeliberationRound {
1222 pub round: usize,
1223 pub new_evidence: Vec<f32>, pub cumulative_mean: f32, pub scalar: TritScalar,
1226 pub converged: bool, }
1228
1229#[derive(Debug, Clone)]
1231pub struct DeliberationResult {
1232 pub final_trit: i8,
1233 pub final_label: String,
1234 pub final_confidence: f32,
1235 pub converged: bool,
1236 pub rounds_used: usize,
1237 pub trace: Vec<DeliberationRound>,
1238 pub convergence_reason: String,
1239}
1240
1241pub struct DeliberationEngine {
1250 pub target_confidence: f32,
1252 pub max_rounds: usize,
1254 pub alpha: f32,
1256}
1257
1258impl DeliberationEngine {
1259 pub fn new(target_confidence: f32, max_rounds: usize) -> Self {
1260 Self { target_confidence, max_rounds, alpha: 0.4 }
1261 }
1262
1263 pub fn with_alpha(mut self, alpha: f32) -> Self { self.alpha = alpha.clamp(0.01, 1.0); self }
1264
1265 pub fn run(&self, rounds_evidence: Vec<Vec<f32>>) -> DeliberationResult {
1268 let mut ema: f32 = 0.0; let mut initialized = false;
1270 let mut trace = Vec::new();
1271
1272 let rounds_to_run = self.max_rounds.min(
1273 if rounds_evidence.is_empty() { self.max_rounds } else { rounds_evidence.len() }
1274 );
1275
1276 for round in 0..rounds_to_run {
1277 let new_ev: Vec<f32> = rounds_evidence.get(round).cloned().unwrap_or_default();
1278
1279 if !new_ev.is_empty() {
1281 let round_mean = new_ev.iter().sum::<f32>() / new_ev.len() as f32;
1282 ema = if !initialized {
1283 initialized = true;
1284 round_mean
1285 } else {
1286 self.alpha * round_mean + (1.0 - self.alpha) * ema
1287 };
1288 }
1289
1290 let scalar = TritScalar::new(ema);
1291 let converged = scalar.confidence() >= self.target_confidence;
1292
1293 trace.push(DeliberationRound {
1294 round,
1295 new_evidence: new_ev,
1296 cumulative_mean: ema,
1297 scalar: scalar.clone(),
1298 converged,
1299 });
1300
1301 if converged { break; }
1302 }
1303
1304 let last = trace.last().cloned().unwrap_or_else(|| DeliberationRound {
1305 round: 0, new_evidence: vec![], cumulative_mean: 0.0,
1306 scalar: TritScalar::new(0.0), converged: false,
1307 });
1308
1309 let convergence_reason = if last.converged {
1310 format!("confidence {:.1}% ≥ target {:.1}% after {} round(s)",
1311 last.scalar.confidence() * 100.0,
1312 self.target_confidence * 100.0,
1313 last.round + 1)
1314 } else {
1315 format!("max rounds ({}) reached — confidence {:.1}% below target {:.1}%",
1316 self.max_rounds,
1317 last.scalar.confidence() * 100.0,
1318 self.target_confidence * 100.0)
1319 };
1320
1321 DeliberationResult {
1322 final_trit: last.scalar.trit_i8(),
1323 final_label: last.scalar.label().to_string(),
1324 final_confidence: last.scalar.confidence(),
1325 converged: last.converged,
1326 rounds_used: last.round + 1,
1327 trace,
1328 convergence_reason,
1329 }
1330 }
1331}
1332
/// One voter in a [`coalition_vote`].
#[derive(Debug, Clone)]
pub struct CoalitionMember {
    /// Identifier reported back in the vote breakdown.
    pub label: String,
    /// The member's vote: `-1`, `0` (abstain) or `+1` when built via `new`.
    pub trit: i8,
    /// How sure the member is, in `[0, 1]` when built via `new`.
    pub confidence: f32,
    /// Relative voting power, non-negative when built via `new`.
    pub weight: f32,
}

impl CoalitionMember {
    /// Builds a member, clamping `trit` to `[-1, 1]`, `confidence` to
    /// `[0, 1]`, and flooring `weight` at `0`.
    pub fn new(label: impl Into<String>, trit: i8, confidence: f32, weight: f32) -> Self {
        let label: String = label.into();
        let vote = trit.clamp(-1, 1);
        let sureness = confidence.clamp(0.0, 1.0);
        let power = weight.max(0.0);
        Self {
            label,
            trit: vote,
            confidence: sureness,
            weight: power,
        }
    }
}
1354
/// Outcome of a [`coalition_vote`].
#[derive(Debug, Clone)]
pub struct CoalitionResult {
    /// Aggregated verdict as an `i8` trit.
    pub trit: i8,
    /// Text label of the verdict.
    pub label: String,
    /// Σ(trit · confidence · weight) / Σweight.
    pub aggregate_score: f32,
    /// Share of total weight held by members that did not abstain.
    pub quorum: f32,
    /// Share of total weight held by non-abstaining members whose vote sign
    /// differs from the verdict's sign.
    pub dissent_rate: f32,
    /// `1 - quorum`.
    pub abstain_rate: f32,
    /// Number of members that took part in the vote.
    pub member_count: usize,
    /// Total weight of the non-abstaining members.
    pub effective_weight: f32,
    /// Per member: `(label, trit, contribution / total weight)`.
    pub breakdown: Vec<(String, i8, f32)>,
}
1368
1369pub fn coalition_vote(members: &[CoalitionMember]) -> CoalitionResult {
1374 if members.is_empty() {
1375 return CoalitionResult {
1376 trit: 0, label: "tend".into(), aggregate_score: 0.0,
1377 quorum: 0.0, dissent_rate: 0.0, abstain_rate: 1.0,
1378 member_count: 0, effective_weight: 0.0, breakdown: vec![],
1379 };
1380 }
1381
1382 let total_weight: f32 = members.iter().map(|m| m.weight).sum();
1383 let total_weight = if total_weight == 0.0 { 1.0 } else { total_weight };
1384
1385 let mut weighted_sum: f32 = 0.0;
1386 let mut non_zero_weight: f32 = 0.0;
1387 let mut breakdown = Vec::new();
1388
1389 for m in members {
1390 let contribution = (m.trit as f32) * m.confidence * m.weight;
1391 weighted_sum += contribution;
1392 if m.trit != 0 { non_zero_weight += m.weight; }
1393 breakdown.push((m.label.clone(), m.trit, contribution / total_weight));
1394 }
1395
1396 let aggregate_score = weighted_sum / total_weight;
1397 let scalar = TritScalar::new(aggregate_score);
1398 let result_trit: i8 = scalar.trit_i8();
1399
1400 let quorum = non_zero_weight / total_weight;
1401 let abstain_rate = 1.0 - quorum;
1402 let dissent_rate = members.iter()
1403 .filter(|m| m.trit != 0 && m.trit.signum() != result_trit.signum())
1404 .map(|m| m.weight)
1405 .sum::<f32>() / total_weight;
1406
1407 CoalitionResult {
1408 trit: result_trit,
1409 label: scalar.label().to_string(),
1410 aggregate_score,
1411 quorum,
1412 dissent_rate,
1413 abstain_rate,
1414 member_count: members.len(),
1415 effective_weight: non_zero_weight,
1416 breakdown,
1417 }
1418}
1419
/// One scored dimension fed into [`action_gate`].
#[derive(Debug, Clone)]
pub struct GateDimension {
    /// Dimension name, echoed back in gate results.
    pub name: String,
    /// Signed evidence score for this dimension.
    pub evidence: f32,
    /// Relative weight in the gate's weighted aggregate.
    pub weight: f32,
    /// When set, a Reject reading on this dimension blocks the action outright.
    pub hard_block: bool,
}

impl GateDimension {
    /// Creates a soft (non-blocking) dimension.
    pub fn new(name: impl Into<String>, evidence: f32, weight: f32) -> Self {
        let name: String = name.into();
        Self {
            hard_block: false,
            weight,
            evidence,
            name,
        }
    }

    /// Marks this dimension as a hard constraint.
    pub fn hard(self) -> Self {
        Self { hard_block: true, ..self }
    }
}
1439
/// Decision produced by [`action_gate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GateVerdict {
    /// The weighted aggregate reads as Affirm.
    Proceed,
    /// The weighted aggregate reads as Tend: defer and gather more evidence.
    Hold,
    /// A hard constraint fired, or the weighted aggregate reads as Reject.
    Block,
}

impl GateVerdict {
    /// Lower-case text form of the verdict.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Block => "block",
            Self::Hold => "hold",
            Self::Proceed => "proceed",
        }
    }
}
1460
1461#[derive(Debug, Clone)]
1463pub struct GateResult {
1464 pub verdict: GateVerdict,
1465 pub aggregate: TritScalar,
1466 pub hard_blocked_by: Vec<String>, pub dim_results: Vec<(String, TritScalar, bool)>, pub explanation: String,
1469}
1470
1471pub fn action_gate(dimensions: &[GateDimension]) -> GateResult {
1478 let mut hard_blocked_by = Vec::new();
1479 let mut dim_results = Vec::new();
1480 let mut weighted_sum = 0.0f32;
1481 let mut total_weight = 0.0f32;
1482
1483 for dim in dimensions {
1484 let scalar = TritScalar::new(dim.evidence);
1485 let is_neg = matches!(scalar.trit(), Trit::Reject);
1486
1487 if dim.hard_block && is_neg {
1488 hard_blocked_by.push(dim.name.clone());
1489 }
1490
1491 weighted_sum += dim.evidence * dim.weight;
1492 total_weight += dim.weight;
1493 dim_results.push((dim.name.clone(), scalar, dim.hard_block));
1494 }
1495
1496 if !hard_blocked_by.is_empty() {
1498 let explanation = format!(
1499 "BLOCKED — hard constraint(s) violated: {}",
1500 hard_blocked_by.join(", ")
1501 );
1502 return GateResult {
1503 verdict: GateVerdict::Block,
1504 aggregate: TritScalar::new(-1.0),
1505 hard_blocked_by,
1506 dim_results,
1507 explanation,
1508 };
1509 }
1510
1511 let agg_score = if total_weight > 0.0 { weighted_sum / total_weight } else { 0.0 };
1512 let aggregate = TritScalar::new(agg_score);
1513
1514 let verdict = match aggregate.trit() {
1515 Trit::Affirm => GateVerdict::Proceed,
1516 Trit::Tend => GateVerdict::Hold,
1517 Trit::Reject => GateVerdict::Block,
1518 };
1519
1520 let explanation = match &verdict {
1521 GateVerdict::Proceed => format!(
1522 "PROCEED — all dimensions pass (aggregate confidence {:.0}%)",
1523 aggregate.confidence() * 100.0
1524 ),
1525 GateVerdict::Hold => format!(
1526 "HOLD — insufficient evidence (aggregate {:.3} within deliberation zone)",
1527 aggregate.raw()
1528 ),
1529 GateVerdict::Block => format!(
1530 "BLOCK — weighted aggregate {:.3} below threshold (confidence {:.0}%)",
1531 aggregate.raw(), aggregate.confidence() * 100.0
1532 ),
1533 };
1534
1535 GateResult { verdict, aggregate, hard_blocked_by, dim_results, explanation }
1536}
1537
/// Sampling guidance derived from a [`TritScalar`] by [`scalar_temperature`].
#[derive(Debug, Clone)]
pub struct ScalarTemperature {
    /// The scalar's verdict as an `i8` trit.
    pub trit: i8,
    /// The scalar's confidence.
    pub confidence: f32,
    /// Suggested sampling temperature, rounded to three decimals.
    pub temperature: f32,
    /// Short explanation of why this temperature was chosen.
    pub reasoning: String,
    /// Instruction text suitable for inclusion in a prompt.
    pub prompt_hint: String,
}
1561
1562pub fn scalar_temperature(scalar: &TritScalar) -> ScalarTemperature {
1563 let t = scalar.trit();
1564 let c = scalar.confidence(); let (temp, reasoning, prompt_hint) = match t {
1567 Trit::Affirm => {
1568 let temp = 0.3 - (c * 0.25); (
1571 temp.max(0.05),
1572 format!("Affirm (confidence {:.0}%) — execute precisely, minimal exploration", c * 100.0),
1573 "Be concise and direct. Evidence is clear. Do not hedge.".to_string(),
1574 )
1575 }
1576 Trit::Reject => {
1577 let temp = 0.15 - (c * 0.10); (
1580 temp.max(0.05),
1581 format!("Reject (confidence {:.0}%) — decline firmly, minimal hedging", c * 100.0),
1582 "Decline clearly. Do not offer alternatives unless explicitly asked. Evidence is against.".to_string(),
1583 )
1584 }
1585 Trit::Tend => {
1586 let temp = 0.7 + ((1.0 - c) * 0.3); (
1589 temp.min(1.0),
1590 format!("Tend (confidence {:.0}%) — evidence is conflicted, explore broadly", c * 100.0),
1591 "You are in deliberation. Present multiple perspectives. Ask clarifying questions. Do not commit.".to_string(),
1592 )
1593 }
1594 };
1595
1596 ScalarTemperature {
1597 trit: scalar.trit_i8(),
1598 confidence: c,
1599 temperature: (temp * 1000.0).round() / 1000.0,
1600 reasoning,
1601 prompt_hint,
1602 }
1603}
1604
/// Internal-consistency assessment produced by [`hallucination_score`].
#[derive(Debug, Clone)]
pub struct HallucinationScore {
    /// Trust verdict as an `i8` trit.
    pub trust_trit: i8,
    /// Text label of the trust verdict.
    pub trust_label: String,
    /// Arithmetic mean of the signals.
    pub mean: f32,
    /// Population variance of the signals.
    pub variance: f32,
    /// `1 − min(variance, 1)`.
    pub consistency: f32,
    /// Number of signals assessed.
    pub signal_count: usize,
    /// Human-readable summary.
    pub explanation: String,
}
1626
1627pub fn hallucination_score(signals: &[f32]) -> HallucinationScore {
1628 if signals.is_empty() {
1629 return HallucinationScore {
1630 trust_trit: 0, trust_label: "tend".into(), mean: 0.0,
1631 variance: 0.0, consistency: 0.0, signal_count: 0,
1632 explanation: "No signals provided — cannot assess consistency.".into(),
1633 };
1634 }
1635
1636 let n = signals.len() as f32;
1637 let mean = signals.iter().sum::<f32>() / n;
1638 let variance = signals.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / n;
1639
1640 let norm_variance = variance.min(1.0);
1642 let consistency = 1.0 - norm_variance;
1643
1644 let trust_evidence = (consistency * 2.0 - 1.0) * mean.abs(); let trust = TritScalar::new(trust_evidence);
1649
1650 let explanation = if trust.trit() == Trit::Affirm {
1651 format!(
1652 "Consistent signals (variance {:.3}, consistency {:.0}%) — evidence coheres around {:.3}",
1653 variance, consistency * 100.0, mean
1654 )
1655 } else if trust.trit() == Trit::Reject {
1656 format!(
1657 "HIGH VARIANCE (variance {:.3}) — signals are internally contradictory. Possible hallucination or conflated sources.",
1658 variance
1659 )
1660 } else {
1661 format!(
1662 "Mixed consistency (variance {:.3}, mean {:.3}) — gather more evidence before relying on this claim.",
1663 variance, mean
1664 )
1665 };
1666
1667 HallucinationScore {
1668 trust_trit: trust.trit_i8(),
1669 trust_label: trust.label().to_string(),
1670 mean,
1671 variance,
1672 consistency,
1673 signal_count: signals.len(),
1674 explanation,
1675 }
1676}
1677
/// Unit tests for the deliberation, coalition, gate, temperature and
/// hallucination helpers.
#[cfg(test)]
mod reasoning_tests {
    use super::*;

    // ── Deliberation ────────────────────────────────────────────────────────

    #[test]
    fn test_deliberation_converges_on_strong_evidence() {
        // A high alpha lets the EMA track the strongly positive rounds quickly.
        let engine = DeliberationEngine::new(0.7, 10).with_alpha(0.7);
        let rounds = vec![
            vec![0.85, 0.9],
            vec![0.9, 0.95],
            vec![0.92, 0.95, 0.98],
        ];
        let result = engine.run(rounds);
        assert!(
            result.converged,
            "should converge on strong positive evidence (got confidence {:.2})",
            result.final_confidence
        );
        assert_eq!(result.final_trit, 1, "should be +1 (affirm)");
        assert!(result.rounds_used <= 3);
    }

    #[test]
    fn test_deliberation_holds_on_weak_evidence() {
        let engine = DeliberationEngine::new(0.95, 3);
        let rounds = vec![vec![0.1f32], vec![-0.05], vec![0.15]];
        let result = engine.run(rounds);
        assert!(!result.converged, "should not converge on weak conflicting evidence");
        assert_eq!(result.final_trit, 0, "should stay at hold/tend");
        assert_eq!(result.rounds_used, 3);
    }

    #[test]
    fn test_deliberation_negative_convergence() {
        let engine = DeliberationEngine::new(0.8, 10);
        let rounds = vec![vec![-0.9f32, -0.85], vec![-0.95, -0.99]];
        let result = engine.run(rounds);
        assert!(result.converged);
        assert_eq!(result.final_trit, -1);
    }

    // ── Coalition vote ──────────────────────────────────────────────────────

    #[test]
    fn test_coalition_unanimous_affirm() {
        let members = vec![
            CoalitionMember::new("safety", 1, 0.9, 3.0),
            CoalitionMember::new("utility", 1, 0.8, 1.0),
            CoalitionMember::new("alignment", 1, 0.95, 2.0),
        ];
        let result = coalition_vote(&members);
        assert_eq!(result.trit, 1);
        assert_eq!(result.label, "affirm");
        assert!(result.quorum > 0.99, "all voted");
        assert!(result.dissent_rate < 0.01);
    }

    #[test]
    fn test_coalition_split_vote_tends_to_hold() {
        let members = vec![
            CoalitionMember::new("agent_a", 1, 0.8, 1.0),
            CoalitionMember::new("agent_b", -1, 0.8, 1.0),
            CoalitionMember::new("agent_c", 0, 0.5, 1.0),
        ];
        let result = coalition_vote(&members);
        // +0.8 and −0.8 cancel exactly, leaving a neutral aggregate.
        assert_eq!(result.trit, 0);
        assert!(result.dissent_rate > 0.0, "there is dissent");
    }

    #[test]
    fn test_coalition_high_weight_overrides() {
        let members = vec![
            CoalitionMember::new("expert", 1, 0.95, 10.0),
            CoalitionMember::new("novice_a", -1, 0.5, 1.0),
            CoalitionMember::new("novice_b", -1, 0.5, 1.0),
        ];
        let result = coalition_vote(&members);
        assert_eq!(result.trit, 1, "high-weight expert should dominate");
    }

    #[test]
    fn test_coalition_empty_abstains() {
        let result = coalition_vote(&[]);
        assert_eq!(result.trit, 0);
        assert_eq!(result.label, "tend");
        assert_eq!(result.member_count, 0);
        assert_eq!(result.abstain_rate, 1.0);
        assert!(result.breakdown.is_empty());
    }

    // ── Action gate ─────────────────────────────────────────────────────────

    #[test]
    fn test_gate_all_positive_proceeds() {
        let dims = vec![
            GateDimension::new("safety", 0.8, 3.0),
            GateDimension::new("utility", 0.7, 1.0),
            GateDimension::new("legality", 0.9, 2.0),
        ];
        let result = action_gate(&dims);
        assert_eq!(result.verdict, GateVerdict::Proceed);
    }

    #[test]
    fn test_gate_hard_block_fires() {
        let dims = vec![
            GateDimension::new("utility", 0.9, 1.0),
            GateDimension::new("safety", -0.8, 3.0).hard(),
            GateDimension::new("legality", 0.7, 1.0),
        ];
        let result = action_gate(&dims);
        assert_eq!(result.verdict, GateVerdict::Block);
        assert!(result.hard_blocked_by.contains(&"safety".to_string()));
    }

    #[test]
    fn test_gate_mixed_soft_dims_holds() {
        let dims = vec![
            GateDimension::new("utility", 0.8, 1.0),
            GateDimension::new("risk", -0.7, 1.0),
        ];
        let result = action_gate(&dims);
        // Weighted mean is ≈0.05, well inside the Tend band (see the 0.05
        // temperature test below), so the gate must hold rather than merely
        // "not block".
        assert_eq!(result.verdict, GateVerdict::Hold);
        assert!(result.hard_blocked_by.is_empty());
    }

    #[test]
    fn test_gate_no_dimensions_holds() {
        let result = action_gate(&[]);
        assert_eq!(result.verdict, GateVerdict::Hold);
        assert!(result.hard_blocked_by.is_empty());
        assert!(result.dim_results.is_empty());
    }

    // ── Scalar temperature ──────────────────────────────────────────────────

    #[test]
    fn test_temperature_affirm_is_low() {
        let sc = TritScalar::new(0.9);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, 1);
        assert!(temp.temperature < 0.3, "affirm → low temperature");
    }

    #[test]
    fn test_temperature_tend_is_high() {
        let sc = TritScalar::new(0.05);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, 0);
        assert!(temp.temperature >= 0.7, "tend → high temperature for exploration");
    }

    #[test]
    fn test_temperature_reject_is_low() {
        let sc = TritScalar::new(-0.9);
        let temp = scalar_temperature(&sc);
        assert_eq!(temp.trit, -1);
        assert!(temp.temperature < 0.15, "reject → low temperature, firm");
    }

    // ── Hallucination score ─────────────────────────────────────────────────

    #[test]
    fn test_hallucination_consistent_signals_trusted() {
        let signals = vec![0.8, 0.82, 0.79, 0.81, 0.83];
        let score = hallucination_score(&signals);
        assert_eq!(score.trust_trit, 1, "consistent signals should be trusted");
        assert!(score.variance < 0.01);
        assert!(score.consistency > 0.99);
    }

    #[test]
    fn test_hallucination_chaotic_signals_flagged() {
        let signals = vec![0.9, -0.9, 0.8, -0.8, 0.95, -0.7];
        let score = hallucination_score(&signals);
        assert!(score.variance > 0.5, "should have high variance");
        assert!(score.trust_trit <= 0, "chaotic signals should not be trusted");
    }

    #[test]
    fn test_hallucination_empty_returns_hold() {
        let score = hallucination_score(&[]);
        assert_eq!(score.trust_trit, 0);
        assert_eq!(score.signal_count, 0);
    }
}
1860
1861use std::collections::HashMap;
1876use crate::coherence::ModelCoherence;
1877
/// Hyper-parameters for [`TritTransformer`].
pub struct TritTransformerConfig {
    /// Model (residual stream) width.
    pub dim: usize,
    /// Number of transformer blocks loaded by `from_coherence`.
    pub n_layers: usize,
    /// Number of query heads; `dim / n_heads` is the per-head width.
    pub n_heads: usize,
    /// Number of key/value heads.
    pub n_kv_heads: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// NOTE(review): not read by `TritTransformer` in this section; presumably
    /// LLaMA-style FFN width rounding — confirm.
    pub multiple_of: usize,
    /// NOTE(review): not read by `TritTransformer` in this section — confirm usage.
    pub ffn_dim_multiplier: Option<f64>,
    /// Epsilon added inside every RMSNorm.
    pub norm_eps: f32,
    /// Number of positions the RoPE table is precomputed for.
    pub max_seq_len: usize,
}

impl Default for TritTransformerConfig {
    fn default() -> Self {
        Self {
            // Geometry.
            dim: 2048,
            n_layers: 16,
            n_heads: 32,
            n_kv_heads: 8,
            // Vocabulary and FFN sizing.
            vocab_size: 128256,
            multiple_of: 256,
            ffn_dim_multiplier: None,
            // Numerics and context.
            norm_eps: 1e-5,
            max_seq_len: 2048,
        }
    }
}
1905
1906pub struct TritBlock {
1908 pub wq: TritMatrix,
1909 pub wk: TritMatrix,
1910 pub wv: TritMatrix,
1911 pub wo: TritMatrix,
1912 pub w1: TritMatrix,
1913 pub w2: TritMatrix,
1914 pub w3: TritMatrix,
1915 pub attention_norm: Vec<f32>, pub ffn_norm: Vec<f32>,
1917}
1918
1919pub struct TritTransformer {
1921 pub config: TritTransformerConfig,
1922 pub tok_embeddings: TritMatrix,
1923 pub layers: Vec<TritBlock>,
1924 pub norm: Vec<f32>,
1925 pub output: TritMatrix,
1926 pub freq_cis: Vec<(f32, f32)>, }
1928
1929impl TritTransformer {
1930 pub fn from_coherence(coherence: ModelCoherence, config: TritTransformerConfig) -> Self {
1932 println!("ternlang-ml: Building TritTransformer (Layers: {})...", config.n_layers);
1933
1934 let mut layers = Vec::with_capacity(config.n_layers);
1935 let mut layer_map: HashMap<String, TritMatrix> = HashMap::new();
1936
1937 for layer in coherence.layers {
1938 layer_map.insert(layer.name.clone(), layer.to_trit_matrix());
1939 }
1940
1941 let mut get = |name: &str| {
1943 layer_map.remove(name).unwrap_or_else(|| panic!("Missing layer: {}", name))
1944 };
1945
1946 let tok_embeddings = get("token_embd.weight");
1947 let output = get("output.weight");
1948
1949 let norm = vec![1.0; config.dim];
1954
1955 for i in 0..config.n_layers {
1956 layers.push(TritBlock {
1957 wq: get(&format!("layers.{}.attention.wq.weight", i)),
1958 wk: get(&format!("layers.{}.attention.wk.weight", i)),
1959 wv: get(&format!("layers.{}.attention.wv.weight", i)),
1960 wo: get(&format!("layers.{}.attention.wo.weight", i)),
1961 w1: get(&format!("layers.{}.feed_forward.w1.weight", i)),
1962 w2: get(&format!("layers.{}.feed_forward.w2.weight", i)),
1963 w3: get(&format!("layers.{}.feed_forward.w3.weight", i)),
1964 attention_norm: vec![1.0; config.dim],
1965 ffn_norm: vec![1.0; config.dim],
1966 });
1967 }
1968
1969 let freq_cis = precompute_freqs_cis(config.dim / config.n_heads, config.max_seq_len);
1971
1972 Self {
1973 config,
1974 tok_embeddings,
1975 layers,
1976 norm,
1977 output,
1978 freq_cis,
1979 }
1980 }
1981
1982 pub fn forward(&self, token: usize, pos: usize) -> Vec<f32> {
1985 let mut h = self.get_embedding(token);
1986
1987 for layer in &self.layers {
1988 let h_norm = rms_norm(&h, &layer.attention_norm, self.config.norm_eps);
1990 let attn_out = self.attention(layer, &h_norm, pos);
1991 for i in 0..h.len() { h[i] += attn_out[i]; }
1992
1993 let h_norm = rms_norm(&h, &layer.ffn_norm, self.config.norm_eps);
1995 let ffn_out = self.feed_forward(layer, &h_norm);
1996 for i in 0..h.len() { h[i] += ffn_out[i]; }
1997 }
1998
1999 let h = rms_norm(&h, &self.norm, self.config.norm_eps);
2000 self.project_output(&h)
2001 }
2002
2003 fn get_embedding(&self, token: usize) -> Vec<f32> {
2004 let start = token * self.config.dim;
2005 let mut embd = Vec::with_capacity(self.config.dim);
2006 for i in 0..self.config.dim {
2007 embd.push(trit_to_f32(self.tok_embeddings.data[start + i]));
2008 }
2009 embd
2010 }
2011
2012 fn attention(&self, layer: &TritBlock, x: &[f32], pos: usize) -> Vec<f32> {
2013 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2016
2017 let (q_trit, _) = sparse_matmul(&x_trit, &layer.wq);
2018 let (k_trit, _) = sparse_matmul(&x_trit, &layer.wk);
2019 let (v_trit, _) = sparse_matmul(&x_trit, &layer.wv);
2020
2021 let mut q = q_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2022 let mut k = k_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2023 let v = v_trit.data.iter().map(|&t| trit_to_f32(t)).collect::<Vec<_>>();
2024
2025 apply_rope(&mut q, pos, &self.freq_cis, self.config.n_heads);
2027 apply_rope(&mut k, pos, &self.freq_cis, self.config.n_heads);
2028
2029 let v_trit = TritMatrix::from_trits(1, v.len(), v.iter().map(|&val| trit_from_f32_approx(val)).collect());
2034 let (out, _) = sparse_matmul(&v_trit, &layer.wo);
2035 out.data.iter().map(|&t| trit_to_f32(t)).collect()
2036 }
2037
2038 fn feed_forward(&self, layer: &TritBlock, x: &[f32]) -> Vec<f32> {
2039 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2040
2041 let (w1_x, _) = sparse_matmul(&x_trit, &layer.w1);
2043 let (w3_x, _) = sparse_matmul(&x_trit, &layer.w3);
2044
2045 let mut hidden = Vec::with_capacity(w1_x.data.len());
2046 for i in 0..w1_x.data.len() {
2047 let v1 = trit_to_f32(w1_x.data[i]);
2048 let v3 = trit_to_f32(w3_x.data[i]);
2049 let silu_v3 = v3 / (1.0 + (-v3).exp());
2051 hidden.push(v1 * silu_v3);
2052 }
2053
2054 let hidden_trit = TritMatrix::from_trits(1, hidden.len(), hidden.iter().map(|&v| trit_from_f32_approx(v)).collect());
2055 let (out, _) = sparse_matmul(&hidden_trit, &layer.w2);
2056 out.data.iter().map(|&t| trit_to_f32(t)).collect()
2057 }
2058
2059 fn project_output(&self, x: &[f32]) -> Vec<f32> {
2060 let x_trit = TritMatrix::from_trits(1, x.len(), x.iter().map(|&v| trit_from_f32_approx(v)).collect());
2061 let (logits, _) = sparse_matmul(&x_trit, &self.output);
2062 logits.data.iter().map(|&t| trit_to_f32(t)).collect()
2063 }
2064}
2065
/// RMSNorm: scales `x` by `1 / sqrt(mean(x²) + eps)`, then by the per-element
/// gain `weight`.
///
/// The output length is `min(x.len(), weight.len())`.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_square = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_square + eps).sqrt();

    let mut normed = Vec::with_capacity(x.len().min(weight.len()));
    for (&v, &gain) in x.iter().zip(weight) {
        normed.push(v * inv_rms * gain);
    }
    normed
}
2073
/// Builds the RoPE `(cos, sin)` table for head width `dim` and positions `0..end`.
///
/// Entry `pos * (dim / 2) + i` holds the rotation for pair `i` at position
/// `pos`. Its angle is `pos / 10000^(2i / dim)`.
fn precompute_freqs_cis(dim: usize, end: usize) -> Vec<(f32, f32)> {
    let half = dim / 2;
    // The per-pair frequency is position independent, so compute it once.
    let inv_freq: Vec<f32> = (0..half)
        .map(|i| 1.0 / 10000.0f32.powf((i * 2) as f32 / dim as f32))
        .collect();

    let mut table = Vec::with_capacity(end * half);
    for pos in 0..end {
        let p = pos as f32;
        table.extend(inv_freq.iter().map(|&freq| {
            let angle = p * freq;
            (angle.cos(), angle.sin())
        }));
    }
    table
}
2085
/// Applies rotary position embedding to `x` in place, treating it as `n_heads`
/// equal-width heads.
///
/// Within each head, element `i` is paired with `i + head_dim / 2` (split-half
/// layout) and rotated by `freq_cis[pos * (head_dim / 2) + i]`. Trailing
/// elements beyond `n_heads * head_dim` are left untouched.
fn apply_rope(x: &mut [f32], pos: usize, freq_cis: &[(f32, f32)], n_heads: usize) {
    let head_dim = x.len() / n_heads;
    if head_dim == 0 {
        // Heads of width zero have nothing to rotate.
        return;
    }
    let half = head_dim / 2;
    let rotations = &freq_cis[pos * half..pos * half + half];

    for head in x.chunks_exact_mut(head_dim).take(n_heads) {
        let (first, second) = head.split_at_mut(half);
        for ((a, b), &(cos, sin)) in first.iter_mut().zip(second.iter_mut()).zip(rotations) {
            let (x0, x1) = (*a, *b);
            *a = x0 * cos - x1 * sin;
            *b = x0 * sin + x1 * cos;
        }
    }
}
2099
2100pub fn trit_to_f32(t: Trit) -> f32 {
2101 match t {
2102 Trit::Affirm => 1.0,
2103 Trit::Reject => -1.0,
2104 Trit::Tend => 0.0,
2105 }
2106}
2107
2108pub fn trit_from_f32_approx(v: f32) -> Trit {
2109 if v > 0.5 { Trit::Affirm }
2110 else if v < -0.5 { Trit::Reject }
2111 else { Trit::Tend }
2112}