1use crate::model::{LayerStorage, TernLayer, TernModel};
12use crate::quantize::PerLayerQuant;
13use crate::sparse::SparseIndex;
14use crate::FORMAT_VERSION;
15
16use ternlang_core::trit::Trit;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
20pub enum CompressError {
21 #[error("empty model: no layers provided")]
22 NoLayers,
23 #[error("layer `{0}`: shape product ({1}) does not match weights length ({2})")]
24 ShapeMismatch(String, usize, usize),
25 #[error("I/O error: {0}")]
26 Io(#[from] std::io::Error),
27}
28
/// Settings and descriptive metadata consumed by [`compress`].
///
/// The metadata fields are not interpreted by `compress`; they are copied
/// verbatim into the resulting model.
#[derive(Debug, Clone)]
pub struct CompressConfig {
    /// Identifier of the model the weights came from (copied into the output).
    pub source_model: String,
    /// Architecture label (copied into the output).
    pub architecture: String,
    /// Vocabulary size (copied into the output).
    pub vocab_size: usize,
    /// Hidden dimension (copied into the output).
    pub hidden_size: usize,
    /// Layer count recorded in the output; `compress` does not check it
    /// against the number of layers actually supplied.
    pub num_layers: usize,

    /// Sparsity cut-off: a quantized layer whose sparsity is greater than or
    /// equal to this value is stored through a sparse index, otherwise it is
    /// stored as densely packed 2-bit trits.
    pub csr_sparsity_threshold: f64,

    /// When `true`, `compress` emits one `tracing` info event per layer.
    pub verbose: bool,
}
47
48impl Default for CompressConfig {
49 fn default() -> Self {
50 Self {
51 source_model: String::from("unknown"),
52 architecture: String::from("unknown"),
53 vocab_size: 0,
54 hidden_size: 0,
55 num_layers: 0,
56 csr_sparsity_threshold: 0.75,
57 verbose: false,
58 }
59 }
60}
61
62pub fn compress(
68 layers: Vec<(String, Vec<f32>, Vec<usize>)>,
69 cfg: CompressConfig,
70) -> Result<TernModel, CompressError> {
71 if layers.is_empty() {
72 return Err(CompressError::NoLayers);
73 }
74
75 for (name, weights, shape) in &layers {
77 let expected: usize = shape.iter().product();
78 if expected != weights.len() {
79 return Err(CompressError::ShapeMismatch(
80 name.clone(), expected, weights.len(),
81 ));
82 }
83 }
84
85 #[cfg(feature = "parallel")]
87 let quants: Vec<PerLayerQuant> = {
88 use crate::quantize::quantize_layers_parallel;
89 quantize_layers_parallel(layers)
90 };
91 #[cfg(not(feature = "parallel"))]
92 let quants: Vec<PerLayerQuant> = quantize_layers(layers);
93
94 let mut tern_layers = Vec::with_capacity(quants.len());
96
97 for q in quants {
98 if cfg.verbose {
99 tracing::info!(
100 layer = %q.name,
101 sparsity = format!("{:.1}%", q.sparsity * 100.0),
102 scale = q.scale,
103 "quantized"
104 );
105 }
106
107 let storage = if q.sparsity >= cfg.csr_sparsity_threshold {
108 let (rows, cols) = layer_dims(&q.shape);
110 let idx = SparseIndex::from_trits(rows, cols, &q.trits);
111 LayerStorage::Sparse(idx)
112 } else {
113 let (rows, cols) = layer_dims(&q.shape);
115 let packed = pack_trits_2bit(&q.trits);
116 LayerStorage::Dense { rows, cols, packed }
117 };
118
119 tern_layers.push(TernLayer {
120 name: q.name,
121 scale: q.scale,
122 storage,
123 sparsity: q.sparsity,
124 original_dtype: "f32".into(),
125 });
126 }
127
128 Ok(TernModel {
129 source_model: cfg.source_model,
130 format_version: FORMAT_VERSION,
131 layers: tern_layers,
132 architecture: cfg.architecture,
133 vocab_size: cfg.vocab_size,
134 hidden_size: cfg.hidden_size,
135 num_layers: cfg.num_layers,
136 })
137}
138
/// Collapses an arbitrary-rank shape into a `(rows, cols)` pair.
///
/// * rank 0 becomes `(1, 1)`;
/// * rank 1 becomes a single row, `(1, n)`;
/// * rank 2 and above keep the first axis as rows and multiply all remaining
///   axes together to form the column count.
fn layer_dims(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (1, 1),
        [len] => (1, *len),
        // For rank 2 the tail holds one axis, so the product is that axis itself.
        [rows, trailing @ ..] => (*rows, trailing.iter().product()),
    }
}
152
153fn pack_trits_2bit(trits: &[Trit]) -> Vec<u8> {
157 let n_bytes = (trits.len() + 3) / 4;
158 let mut out = vec![0u8; n_bytes];
159 for (i, &t) in trits.iter().enumerate() {
160 let bits: u8 = match t {
161 Trit::Tend => 0b00,
162 Trit::Affirm => 0b01,
163 Trit::Reject => 0b10,
164 };
165 let byte_idx = i / 4;
166 let shift = (i % 4) * 2;
167 out[byte_idx] |= bits << shift;
168 }
169 out
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a single 4x4 layer compresses into a one-layer model that
    /// carries the metadata from the config.
    #[test]
    fn compress_tiny_model() {
        let weights: Vec<f32> = (0..16).map(|i| (i as f32) - 8.0).collect();
        let layers = vec![("layer0.weight".into(), weights, vec![4, 4])];
        let cfg = CompressConfig {
            source_model: "test-1M".into(),
            num_layers: 1,
            ..Default::default()
        };
        let model = compress(layers, cfg).unwrap();
        assert_eq!(model.layers.len(), 1);
        assert_eq!(model.source_model, "test-1M");
        assert_eq!(model.num_layers, 1);
        assert!(model.mean_sparsity() >= 0.0);
        println!("{}", model.summary());
    }

    /// An empty layer list is rejected before any quantization happens.
    #[test]
    fn compress_rejects_empty_model() {
        let result = compress(Vec::new(), CompressConfig::default());
        assert!(matches!(result, Err(CompressError::NoLayers)));
    }

    /// A shape whose product disagrees with the weight count is reported with
    /// the layer name, the shape product, and the actual length.
    #[test]
    fn compress_rejects_shape_mismatch() {
        let layers = vec![("bad.weight".to_string(), vec![0.0_f32; 3], vec![2, 2])];
        let result = compress(layers, CompressConfig::default());
        assert!(matches!(
            result,
            Err(CompressError::ShapeMismatch(ref name, 4, 3)) if name.as_str() == "bad.weight"
        ));
    }

    /// Checks the rank handling of `layer_dims`.
    #[test]
    fn layer_dims_flattens_trailing_axes() {
        assert_eq!(layer_dims(&[]), (1, 1));
        assert_eq!(layer_dims(&[5]), (1, 5));
        assert_eq!(layer_dims(&[3, 4]), (3, 4));
        assert_eq!(layer_dims(&[2, 3, 4]), (2, 12));
    }

    /// Pins the packed byte layout (this file has no unpacker, so the check
    /// is against hand-computed bytes rather than a true round trip).
    #[test]
    fn pack_unpack_roundtrip() {
        let trits = vec![
            Trit::Affirm, Trit::Tend, Trit::Reject, Trit::Affirm,
            Trit::Reject, Trit::Tend, Trit::Affirm, Trit::Tend,
        ];
        let packed = pack_trits_2bit(&trits);
        // Eight trits at four per byte occupy exactly two bytes.
        assert_eq!(packed.len(), 2);
        // The first trit lands in the lowest two bits of each byte.
        let expected_byte0: u8 = 0b01_10_00_01;
        let expected_byte1: u8 = 0b00_01_00_10;
        assert_eq!(packed[0], expected_byte0);
        assert_eq!(packed[1], expected_byte1);
    }

    /// A trailing partial group still gets its own byte with zeroed high
    /// bits, and empty input packs to nothing.
    #[test]
    fn pack_handles_partial_and_empty_input() {
        let trits = vec![
            Trit::Reject, Trit::Reject, Trit::Reject, Trit::Reject,
            Trit::Affirm,
        ];
        let packed = pack_trits_2bit(&trits);
        assert_eq!(packed, vec![0b10_10_10_10, 0b00_00_00_01]);
        assert!(pack_trits_2bit(&[]).is_empty());
    }
}