Skip to main content

ternlang_ml/
coherence.rs

1use crate::{TritMatrix, sparse_matmul};
2use ternlang_core::trit::Trit;
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5use std::fs::File;
6use std::io::BufReader;
7
8#[derive(Serialize, Deserialize, Debug)]
9pub struct PackedDense {
10    pub rows: usize,
11    pub cols: usize,
12    pub packed: Vec<u8>,
13}
14
15#[derive(Serialize, Deserialize, Debug)]
16pub enum Storage {
17    Dense(PackedDense),
18}
19
20#[derive(Serialize, Deserialize, Debug)]
21pub struct Layer {
22    pub name: String,
23    pub scale: f32,
24    pub sparsity: f32,
25    pub storage: Storage,
26}
27
28#[derive(Serialize, Deserialize, Debug)]
29pub struct ModelCoherence {
30    pub source_model: String,
31    pub layers: Vec<Layer>,
32}
33
34impl ModelCoherence {
35    /// Save model to a fast binary format.
36    pub fn save_bin(&self, path: &Path) -> anyhow::Result<()> {
37        let file = File::create(path)?;
38        bincode::serialize_into(file, self)?;
39        Ok(())
40    }
41
42    /// Load model from the fast binary format.
43    pub fn load_bin(path: &Path) -> anyhow::Result<Self> {
44        let file = File::open(path)?;
45        let reader = BufReader::new(file);
46        let model: Self = bincode::deserialize_from(reader)?;
47        Ok(model)
48    }
49}
50
51impl Layer {
52    pub fn to_trit_matrix(&self) -> TritMatrix {
53        unpack_layer(self)
54    }
55}
56
57pub fn unpack_layer(layer: &Layer) -> TritMatrix {
58    match &layer.storage {
59        Storage::Dense(dense) => {
60            let mut trits = Vec::with_capacity(dense.rows * dense.cols);
61            for (_byte_idx, &byte) in dense.packed.iter().enumerate() {
62                for bit_idx in 0..4 {
63                    if trits.len() >= dense.rows * dense.cols {
64                        break;
65                    }
66                    let bits = (byte >> (bit_idx * 2)) & 0b11;
67                    let trit = match bits {
68                        0b01 => Trit::Reject,
69                        0b11 => Trit::Tend,
70                        0b10 => Trit::Affirm,
71                        0b00 => Trit::Tend, // Should not happen in fixed script
72                        _ => unreachable!(),
73                    };
74                    trits.push(trit);
75                }
76            }
77            TritMatrix::from_trits(dense.rows, dense.cols, trits)
78        }
79    }
80}
81
82pub fn run_coherence_test(json_path: &Path, target_layer: &str) -> anyhow::Result<()> {
83    println!("--- RFI-IRFOS TIS: Coherence Test [Phase 12A] ---");
84    println!("Loading model from: {:?}", json_path);
85    
86    let file = File::open(json_path)?;
87    let reader = BufReader::new(file);
88    let model: ModelCoherence = serde_json::from_reader(reader)?;
89    
90    println!("Model: {}", model.source_model);
91    println!("Total layers: {}", model.layers.len());
92    
93    let layer = model.layers.iter().find(|l| l.name == target_layer)
94        .ok_or_else(|| anyhow::anyhow!("Layer {} not found", target_layer))?;
95        
96    println!("Testing layer: {} (Sparsity: {:.2}%)", layer.name, layer.sparsity * 100.0);
97    
98    let w = unpack_layer(layer);
99    
100    // Create a mock input: [1 x in_features]
101    // For a forward pass check, we use all Affirm (+1) inputs.
102    let mut input = TritMatrix::new(1, w.rows);
103    for i in 0..w.rows {
104        input.set(0, i, Trit::Affirm);
105    }
106    
107    println!("Running sparse_matmul (Forward Pass POC)...");
108    let (output, skipped) = sparse_matmul(&input, &w);
109    
110    println!("Done.");
111    println!("Output shape: {}x{}", output.rows, output.cols);
112    println!("Skipped ops:  {} (Sparsity Advantage: {:.2}x)", 
113        skipped, (skipped as f64 + output.rows as f64 * output.cols as f64 * w.rows as f64) / (output.rows as f64 * output.cols as f64 * w.rows as f64 - skipped as f64).max(1.0));
114    
115    // Check signal: how many non-zero trits in output?
116    let non_zeros = output.data.iter().filter(|&&t| t != Trit::Tend).count();
117    let signal_ratio = non_zeros as f32 / output.data.len() as f32;
118    
119    println!("Signal Ratio: {:.2}% ({} / {})", signal_ratio * 100.0, non_zeros, output.data.len());
120    
121    if signal_ratio > 0.05 {
122        println!("[SUCCESS] Signal coherence detected. Model is not a void.");
123    } else {
124        println!("[WARNING] Low signal detected. Model might be overly sparse or collapsed.");
125    }
126    
127    Ok(())
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke-tests `run_coherence_test` against the Llama 3.2 1B ternary export.
    ///
    /// The model file is not part of the crate. The test looks for it relative
    /// to the crate root (the working directory during `cargo test`) and skips
    /// itself, rather than failing, when it is absent.
    #[test]
    fn test_llama_coherence() -> anyhow::Result<()> {
        // llama32-1b.tern.json normally lives in the TIS/ root, two levels
        // above the crate; one level up is checked as a fallback.
        let candidates = [
            Path::new("../../llama32-1b.tern.json"),
            Path::new("../llama32-1b.tern.json"),
        ];
        match candidates.iter().find(|p| p.exists()) {
            Some(json_path) => run_coherence_test(json_path, "blk.0.ffn_gate.weight"),
            None => {
                println!(
                    "Skipping coherence test: llama32-1b.tern.json not found at {:?}",
                    candidates
                );
                Ok(())
            }
        }
    }
}
152}