// ternlang_compress/model.rs
use serde::{Deserialize, Serialize};
use crate::sparse::SparseIndex;

10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum LayerStorage {
13 Dense {
15 rows: usize,
16 cols: usize,
17 packed: Vec<u8>,
19 },
20 Sparse(SparseIndex),
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TernLayer {
27 pub name: String,
29 pub scale: f32,
31 pub storage: LayerStorage,
33 pub sparsity: f64,
35 pub original_dtype: String,
37}
38
39impl TernLayer {
40 pub fn num_params(&self) -> usize {
42 match &self.storage {
43 LayerStorage::Dense { rows, cols, .. } => rows * cols,
44 LayerStorage::Sparse(idx) => idx.rows * idx.cols,
45 }
46 }
47
48 pub fn memory_bytes(&self) -> usize {
50 match &self.storage {
51 LayerStorage::Dense { rows, cols, .. } => (rows * cols + 3) / 4,
52 LayerStorage::Sparse(idx) => idx.memory_bytes(),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct TernModel {
60 pub source_model: String,
62 pub format_version: u32,
64 pub layers: Vec<TernLayer>,
66 pub architecture: String,
68 pub vocab_size: usize,
70 pub hidden_size: usize,
72 pub num_layers: usize,
74}
75
76impl TernModel {
77 pub fn total_params(&self) -> usize {
79 self.layers.iter().map(|l| l.num_params()).sum()
80 }
81
82 pub fn total_memory_bytes(&self) -> usize {
84 self.layers.iter().map(|l| l.memory_bytes()).sum()
85 }
86
87 pub fn mean_sparsity(&self) -> f64 {
89 if self.layers.is_empty() { return 0.0; }
90 let sum: f64 = self.layers.iter().map(|l| l.sparsity).sum();
91 sum / self.layers.len() as f64
92 }
93
94 pub fn compression_ratio_vs_f16(&self) -> f64 {
96 let f16_bytes = self.total_params() * 2;
97 if f16_bytes == 0 { return 1.0; }
98 f16_bytes as f64 / self.total_memory_bytes() as f64
99 }
100
101 pub fn summary(&self) -> String {
103 format!(
104 "TernModel: {}\n Params: {:>12}\n Ternary mem: {:>12} MB\n Mean sparsity: {:.1}%\n vs f16: {:.1}× smaller",
105 self.source_model,
106 self.total_params(),
107 self.total_memory_bytes() / 1_048_576,
108 self.mean_sparsity() * 100.0,
109 self.compression_ratio_vs_f16(),
110 )
111 }
112}