use std::fs::File;
use std::io::BufReader;
use std::io::BufWriter;
use std::io::Write;
use std::path::Path;

use serde::{Deserialize, Serialize};
use ternlang_core::trit::Trit;

use crate::{TritMatrix, sparse_matmul};
7
/// A `rows x cols` trit matrix stored with every element present, packed four
/// 2-bit trit codes per byte.
///
/// `unpack_layer` decodes it: within each byte, bits 0-1 hold the first code and
/// bits 6-7 the fourth.
///
/// NOTE: this type is written and read with bincode (`save_bin` / `load_bin`), which
/// encodes fields in declaration order — do not reorder or insert fields.
#[derive(Serialize, Deserialize, Debug)]
pub struct PackedDense {
    /// Number of matrix rows.
    pub rows: usize,
    /// Number of matrix columns.
    pub cols: usize,
    /// Packed 2-bit trit codes; only the first `rows * cols` codes are decoded.
    pub packed: Vec<u8>,
}
14
/// Storage layout of a layer's weights. Only dense 2-bit packing exists today.
///
/// NOTE: bincode encodes the variant by its position, so append new variants at the
/// end to stay compatible with previously saved `.bin` files.
#[derive(Serialize, Deserialize, Debug)]
pub enum Storage {
    /// All `rows * cols` weights packed four per byte (see `PackedDense`).
    Dense(PackedDense),
}
19
/// One named weight tensor of a ternarized model.
///
/// NOTE: serialized with bincode in declaration order — do not reorder fields.
#[derive(Serialize, Deserialize, Debug)]
pub struct Layer {
    /// Tensor name (e.g. `blk.0.ffn_gate.weight`); `run_coherence_test` matches it exactly.
    pub name: String,
    /// Scale factor; not read anywhere in this file.
    /// NOTE(review): presumably the dequantization multiplier for the trits — confirm.
    pub scale: f32,
    /// Sparsity of the layer. `run_coherence_test` prints it multiplied by 100 as a
    /// percentage, so it looks like a 0..=1 fraction — confirm against the exporter.
    pub sparsity: f32,
    /// Packed weight data.
    pub storage: Storage,
}
27
/// A ternarized model: a source-model identifier plus its list of layers.
///
/// Read from JSON by `run_coherence_test`, or round-tripped through bincode with
/// `save_bin` / `load_bin` (bincode is order-sensitive — do not reorder fields).
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelCoherence {
    /// Identifier of the model these layers came from; only printed in this file.
    pub source_model: String,
    /// Layers in file order; looked up by `Layer::name`.
    pub layers: Vec<Layer>,
}
33
34impl ModelCoherence {
35 pub fn save_bin(&self, path: &Path) -> anyhow::Result<()> {
37 let file = File::create(path)?;
38 bincode::serialize_into(file, self)?;
39 Ok(())
40 }
41
42 pub fn load_bin(path: &Path) -> anyhow::Result<Self> {
44 let file = File::open(path)?;
45 let reader = BufReader::new(file);
46 let model: Self = bincode::deserialize_from(reader)?;
47 Ok(model)
48 }
49}
50
51impl Layer {
52 pub fn to_trit_matrix(&self) -> TritMatrix {
53 unpack_layer(self)
54 }
55}
56
57pub fn unpack_layer(layer: &Layer) -> TritMatrix {
58 match &layer.storage {
59 Storage::Dense(dense) => {
60 let mut trits = Vec::with_capacity(dense.rows * dense.cols);
61 for (_byte_idx, &byte) in dense.packed.iter().enumerate() {
62 for bit_idx in 0..4 {
63 if trits.len() >= dense.rows * dense.cols {
64 break;
65 }
66 let bits = (byte >> (bit_idx * 2)) & 0b11;
67 let trit = match bits {
68 0b01 => Trit::Reject,
69 0b11 => Trit::Tend,
70 0b10 => Trit::Affirm,
71 0b00 => Trit::Tend, _ => unreachable!(),
73 };
74 trits.push(trit);
75 }
76 }
77 TritMatrix::from_trits(dense.rows, dense.cols, trits)
78 }
79 }
80}
81
82pub fn run_coherence_test(json_path: &Path, target_layer: &str) -> anyhow::Result<()> {
83 println!("--- RFI-IRFOS TIS: Coherence Test [Phase 12A] ---");
84 println!("Loading model from: {:?}", json_path);
85
86 let file = File::open(json_path)?;
87 let reader = BufReader::new(file);
88 let model: ModelCoherence = serde_json::from_reader(reader)?;
89
90 println!("Model: {}", model.source_model);
91 println!("Total layers: {}", model.layers.len());
92
93 let layer = model.layers.iter().find(|l| l.name == target_layer)
94 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", target_layer))?;
95
96 println!("Testing layer: {} (Sparsity: {:.2}%)", layer.name, layer.sparsity * 100.0);
97
98 let w = unpack_layer(layer);
99
100 let mut input = TritMatrix::new(1, w.rows);
103 for i in 0..w.rows {
104 input.set(0, i, Trit::Affirm);
105 }
106
107 println!("Running sparse_matmul (Forward Pass POC)...");
108 let (output, skipped) = sparse_matmul(&input, &w);
109
110 println!("Done.");
111 println!("Output shape: {}x{}", output.rows, output.cols);
112 println!("Skipped ops: {} (Sparsity Advantage: {:.2}x)",
113 skipped, (skipped as f64 + output.rows as f64 * output.cols as f64 * w.rows as f64) / (output.rows as f64 * output.cols as f64 * w.rows as f64 - skipped as f64).max(1.0));
114
115 let non_zeros = output.data.iter().filter(|&&t| t != Trit::Tend).count();
117 let signal_ratio = non_zeros as f32 / output.data.len() as f32;
118
119 println!("Signal Ratio: {:.2}% ({} / {})", signal_ratio * 100.0, non_zeros, output.data.len());
120
121 if signal_ratio > 0.05 {
122 println!("[SUCCESS] Signal coherence detected. Model is not a void.");
123 } else {
124 println!("[WARNING] Low signal detected. Model might be overly sparse or collapsed.");
125 }
126
127 Ok(())
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Layer exercised by the smoke test.
    const TARGET_LAYER: &str = "blk.0.ffn_gate.weight";

    /// Candidate locations of the exported model, relative to the test's working
    /// directory, searched in order.
    const MODEL_CANDIDATES: [&str; 2] = ["../../llama32-1b.tern.json", "../llama32-1b.tern.json"];

    /// Runs the coherence smoke test against the exported Llama 3.2 1B model.
    ///
    /// The model file is a large external artifact that is not always present. This
    /// matches the test's original "Skipping" intent: it is skipped, not failed, when
    /// none of the candidate paths exist. When the file is found, any load or lookup
    /// error still fails the test.
    #[test]
    fn test_llama_coherence() {
        let found = MODEL_CANDIDATES.iter().copied().map(Path::new).find(|p| p.exists());
        match found {
            Some(path) => run_coherence_test(path, TARGET_LAYER).unwrap(),
            None => println!(
                "Skipping coherence test: llama32-1b.tern.json not found at any of {:?}",
                MODEL_CANDIDATES
            ),
        }
    }
}