Skip to main content

ternlang_compress/
pipeline.rs

// Main compression pipeline.
//
// Usage:
//   let cfg = CompressConfig::default();
//   let model = compress(layers, cfg)?;
//
// `layers` is a `Vec` of `(name, f32_weights, shape)` triples produced by a
// model loader (see format.rs for the GGUF/safetensors loaders — currently
// stubbed, to be completed when integrating candle or llama.cpp bindings).
10
11use crate::model::{LayerStorage, TernLayer, TernModel};
12use crate::quantize::PerLayerQuant;
13use crate::sparse::SparseIndex;
14use crate::FORMAT_VERSION;
15
16use ternlang_core::trit::Trit;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
20pub enum CompressError {
21    #[error("empty model: no layers provided")]
22    NoLayers,
23    #[error("layer `{0}`: shape product ({1}) does not match weights length ({2})")]
24    ShapeMismatch(String, usize, usize),
25    #[error("I/O error: {0}")]
26    Io(#[from] std::io::Error),
27}
28
/// Settings that drive the compression pipeline.
///
/// The first group of fields is pure metadata describing the source model;
/// the remaining fields control how layers are stored and whether progress
/// is logged.
#[derive(Debug, Clone)]
pub struct CompressConfig {
    /// Identifier of the source model, kept as metadata.
    pub source_model: String,
    /// Architecture name (for example `"LlamaForCausalLM"`), kept as metadata.
    pub architecture: String,
    /// Vocabulary size, kept as metadata.
    pub vocab_size: usize,
    /// Hidden size, kept as metadata.
    pub hidden_size: usize,
    /// Layer count, kept as metadata.
    pub num_layers: usize,

    /// Layers whose sparsity is at or above this value are stored in CSR form;
    /// everything below it is stored as 2-bit packed dense.
    /// Defaults to `0.75` (CSR is more efficient above ~75% sparsity).
    pub csr_sparsity_threshold: f64,

    /// When `true`, per-layer statistics are logged during compression.
    pub verbose: bool,
}

impl Default for CompressConfig {
    /// Metadata strings default to `"unknown"` and counts to `0`; the CSR
    /// threshold is `0.75` and verbose logging is off.
    fn default() -> Self {
        let unknown = || "unknown".to_owned();
        CompressConfig {
            source_model: unknown(),
            architecture: unknown(),
            vocab_size: 0,
            hidden_size: 0,
            num_layers: 0,
            csr_sparsity_threshold: 0.75,
            verbose: false,
        }
    }
}
61
62/// Compress a list of float weight tensors to a TernModel.
63///
64/// # Arguments
65/// * `layers` — `(name, weights_f32, shape)` for each weight matrix.
66/// * `cfg`    — pipeline configuration.
67pub fn compress(
68    layers: Vec<(String, Vec<f32>, Vec<usize>)>,
69    cfg: CompressConfig,
70) -> Result<TernModel, CompressError> {
71    if layers.is_empty() {
72        return Err(CompressError::NoLayers);
73    }
74
75    // Validate shapes
76    for (name, weights, shape) in &layers {
77        let expected: usize = shape.iter().product();
78        if expected != weights.len() {
79            return Err(CompressError::ShapeMismatch(
80                name.clone(), expected, weights.len(),
81            ));
82        }
83    }
84
85    // Quantize all layers (parallel if feature enabled)
86    #[cfg(feature = "parallel")]
87    let quants: Vec<PerLayerQuant> = {
88        use crate::quantize::quantize_layers_parallel;
89        quantize_layers_parallel(layers)
90    };
91    #[cfg(not(feature = "parallel"))]
92    let quants: Vec<PerLayerQuant> = quantize_layers(layers);
93
94    // Build TernLayers
95    let mut tern_layers = Vec::with_capacity(quants.len());
96
97    for q in quants {
98        if cfg.verbose {
99            tracing::info!(
100                layer = %q.name,
101                sparsity = format!("{:.1}%", q.sparsity * 100.0),
102                scale = q.scale,
103                "quantized"
104            );
105        }
106
107        let storage = if q.sparsity >= cfg.csr_sparsity_threshold {
108            // High sparsity — use CSR
109            let (rows, cols) = layer_dims(&q.shape);
110            let idx = SparseIndex::from_trits(rows, cols, &q.trits);
111            LayerStorage::Sparse(idx)
112        } else {
113            // Lower sparsity — use 2-bit packed dense
114            let (rows, cols) = layer_dims(&q.shape);
115            let packed = pack_trits_2bit(&q.trits);
116            LayerStorage::Dense { rows, cols, packed }
117        };
118
119        tern_layers.push(TernLayer {
120            name:          q.name,
121            scale:         q.scale,
122            storage,
123            sparsity:      q.sparsity,
124            original_dtype: "f32".into(),
125        });
126    }
127
128    Ok(TernModel {
129        source_model:   cfg.source_model,
130        format_version: FORMAT_VERSION,
131        layers:         tern_layers,
132        architecture:   cfg.architecture,
133        vocab_size:     cfg.vocab_size,
134        hidden_size:    cfg.hidden_size,
135        num_layers:     cfg.num_layers,
136    })
137}
138
// ─── Helpers ─────────────────────────────────────────────────────────────────
140
/// View a tensor shape as a 2-D `(rows, cols)` pair.
///
/// * rank 0 (scalar) → `(1, 1)`
/// * rank 1 `[n]`    → `(1, n)`
/// * rank ≥ 2        → `(shape[0], product of all remaining dims)`
fn layer_dims(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (1, 1),
        [len] => (1, *len),
        [outer, inner @ ..] => (*outer, inner.iter().product()),
    }
}
152
153/// Pack trit slice into 2-bit-per-trit bytes.
154/// Encoding: 0b00 = Tend (0), 0b01 = Affirm (+1), 0b10 = Reject (-1).
155/// 4 trits per byte, LSB first.
156fn pack_trits_2bit(trits: &[Trit]) -> Vec<u8> {
157    let n_bytes = (trits.len() + 3) / 4;
158    let mut out = vec![0u8; n_bytes];
159    for (i, &t) in trits.iter().enumerate() {
160        let bits: u8 = match t {
161            Trit::Tend   => 0b00,
162            Trit::Affirm => 0b01,
163            Trit::Reject => 0b10,
164        };
165        let byte_idx = i / 4;
166        let shift    = (i % 4) * 2;
167        out[byte_idx] |= bits << shift;
168    }
169    out
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compress_tiny_model() {
        let weights: Vec<f32> = (0..16).map(|i| (i as f32) - 8.0).collect();
        let layers = vec![("layer0.weight".into(), weights, vec![4, 4])];
        let cfg = CompressConfig {
            source_model: "test-1M".into(),
            num_layers: 1,
            ..Default::default()
        };
        let model = compress(layers, cfg).unwrap();
        assert_eq!(model.layers.len(), 1);
        assert!(model.mean_sparsity() >= 0.0);
        println!("{}", model.summary());
    }

    #[test]
    fn compress_rejects_empty_input() {
        let result = compress(Vec::new(), CompressConfig::default());
        assert!(matches!(result, Err(CompressError::NoLayers)));
    }

    #[test]
    fn compress_rejects_shape_mismatch() {
        // Shape [2, 3] promises 6 weights but only 5 are supplied.
        let layers = vec![("w".to_string(), vec![0.0_f32; 5], vec![2, 3])];
        let result = compress(layers, CompressConfig::default());
        assert!(matches!(
            result,
            Err(CompressError::ShapeMismatch(ref name, 6, 5)) if name.as_str() == "w"
        ));
    }

    #[test]
    fn layer_dims_flattens_inner_dims() {
        assert_eq!(layer_dims(&[]), (1, 1));
        assert_eq!(layer_dims(&[7]), (1, 7));
        assert_eq!(layer_dims(&[3, 5]), (3, 5));
        assert_eq!(layer_dims(&[2, 3, 4]), (2, 12));
    }

    /// Checks the exact bit layout of the packed bytes (there is no unpacker
    /// in this module, so this is a layout test rather than a round-trip).
    #[test]
    fn pack_trits_2bit_layout() {
        let trits = vec![
            Trit::Affirm, Trit::Tend, Trit::Reject, Trit::Affirm,
            Trit::Reject, Trit::Tend, Trit::Affirm, Trit::Tend,
        ];
        let packed = pack_trits_2bit(&trits);
        assert_eq!(packed.len(), 2);
        // Byte 0: Affirm=01 | Tend=00<<2 | Reject=10<<4 | Affirm=01<<6
        assert_eq!(packed[0], 0b01_10_00_01);
        // Byte 1: Reject=10 | Tend=00<<2 | Affirm=01<<4 | Tend=00<<6
        assert_eq!(packed[1], 0b00_01_00_10);
    }

    #[test]
    fn pack_trits_2bit_pads_partial_byte() {
        assert!(pack_trits_2bit(&[]).is_empty());
        let packed = pack_trits_2bit(&[Trit::Reject; 5]);
        assert_eq!(packed, vec![0b10_10_10_10, 0b00_00_00_10]);
    }
}