Skip to main content

ternlang_compress/
sparse.rs

// Sparse zero-index (CSR — Compressed Sparse Row) for ternary weight matrices.
//
// A ternary matrix has three states: {-1, 0, +1}.  The majority of weights
// in a quantized LLM are 0 (typically 60-99% depending on sparsity).
//
// Storing the FULL matrix wastes memory on zeros.  Instead we store:
//   - `values`:  only the non-zero trit values (+1 or -1), one per i8
//   - `col_idx`: column index for each non-zero value (u32)
//   - `row_ptr`: offset into `values` where each row starts (u64, rows + 1 entries)
//
// At inference the sparse_matmul kernel iterates `row_ptr[row]..row_ptr[row+1]`
// and accumulates only the non-zero contributions — zeros are never touched.
//
// Memory breakdown for a 4096×4096 weight matrix (16.8M entries):
//   Dense ternary (2-bit packed): 4096×4096×2 bits ≈ 4.2 MB (4 MiB)
//   CSR costs 5 bytes per non-zero (i8 value + u32 col_idx) plus
//   4097 × 8 B ≈ 32 KiB of row_ptr:
//     60% sparse: 6.7M nnz × 5 bytes ≈ 33.6 MB — about 8× larger than dense.
//     80% sparse: 3.4M nnz × 5 bytes ≈ 16.8 MB — still about 4× larger than dense.
//   A non-zero costs 40 bits in CSR versus 2 bits per entry when packed, so
//   CSR only saves memory when fewer than 1 in 20 entries is non-zero,
//   i.e. above ~95% sparsity.
//
// Practical recommendation (memory only): use CSR for sparsity > ~95%,
// packed ternary otherwise.  `SparseIndex::is_efficient` performs the exact
// byte comparison.  The pipeline chooses automatically based on measured sparsity.

23use ternlang_core::trit::Trit;
24use serde::{Deserialize, Serialize};
25
26/// Compressed Sparse Row index for one weight matrix.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SparseIndex {
29    pub rows: usize,
30    pub cols: usize,
31    /// Non-zero trit values, stored as i8 (+1 or -1).
32    pub values: Vec<i8>,
33    /// Column of each non-zero value.  Same length as `values`.
34    pub col_idx: Vec<u32>,
35    /// `row_ptr[r]` = index into `values` where row `r` starts.
36    /// Length = rows + 1.  `row_ptr[rows]` = values.len().
37    pub row_ptr: Vec<u64>,
38    /// Number of non-zero elements.
39    pub nnz: usize,
40    /// Sparsity = 1 - nnz / (rows * cols).
41    pub sparsity: f64,
42}
43
44impl SparseIndex {
45    /// Build a CSR index from a flat row-major slice of trits.
46    pub fn from_trits(rows: usize, cols: usize, data: &[Trit]) -> Self {
47        assert_eq!(data.len(), rows * cols);
48
49        let mut values  = Vec::new();
50        let mut col_idx = Vec::new();
51        let mut row_ptr = vec![0u64; rows + 1];
52
53        for r in 0..rows {
54            row_ptr[r] = values.len() as u64;
55            for c in 0..cols {
56                let t = data[r * cols + c];
57                if t != Trit::Tend {
58                    values.push(match t {
59                        Trit::Affirm =>  1i8,
60                        Trit::Reject => -1i8,
61                        Trit::Tend   =>  0i8,
62                    });
63                    col_idx.push(c as u32);
64                }
65            }
66        }
67        row_ptr[rows] = values.len() as u64;
68
69        let nnz      = values.len();
70        let total    = rows * cols;
71        let sparsity = 1.0 - nnz as f64 / total as f64;
72
73        Self { rows, cols, values, col_idx, row_ptr, nnz, sparsity }
74    }
75
76    /// Reconstruct the dense trit slice (for testing / correctness checks).
77    pub fn to_dense(&self) -> Vec<Trit> {
78        let mut out = vec![Trit::Tend; self.rows * self.cols];
79        for r in 0..self.rows {
80            let start = self.row_ptr[r] as usize;
81            let end   = self.row_ptr[r + 1] as usize;
82            for idx in start..end {
83                let c = self.col_idx[idx] as usize;
84                out[r * self.cols + c] = if self.values[idx] > 0 {
85                    Trit::Affirm
86                } else {
87                    Trit::Reject
88                };
89            }
90        }
91        out
92    }
93
94    /// Memory footprint in bytes.
95    pub fn memory_bytes(&self) -> usize {
96        self.values.len()      // i8 = 1 byte each
97        + self.col_idx.len() * 4  // u32 = 4 bytes each
98        + self.row_ptr.len() * 8  // u64 = 8 bytes each
99    }
100
101    /// Whether CSR is more memory-efficient than a 2-bit packed dense matrix.
102    pub fn is_efficient(&self) -> bool {
103        let dense_packed = (self.rows * self.cols + 3) / 4; // 2 bits per trit
104        self.memory_bytes() < dense_packed
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// 3×3 fixture with 5 non-zeros (row-major):
    ///   [+1  0 -1]
    ///   [ 0  0 +1]
    ///   [-1 +1  0]
    fn sample() -> Vec<Trit> {
        vec![
            Trit::Affirm, Trit::Tend,   Trit::Reject,
            Trit::Tend,   Trit::Tend,   Trit::Affirm,
            Trit::Reject, Trit::Affirm, Trit::Tend,
        ]
    }

    #[test]
    fn roundtrip_sparse() {
        let trits = sample();
        let idx = SparseIndex::from_trits(3, 3, &trits);
        assert_eq!(idx.nnz, 5);
        let dense = idx.to_dense();
        assert_eq!(dense, trits);
    }

    #[test]
    fn csr_layout_is_row_major_and_column_sorted() {
        let idx = SparseIndex::from_trits(3, 3, &sample());
        assert_eq!(idx.values, vec![1, -1, 1, -1, 1]);
        assert_eq!(idx.col_idx, vec![0, 2, 2, 0, 1]);
        assert_eq!(idx.row_ptr, vec![0, 2, 3, 5]);
        assert!((idx.sparsity - 4.0 / 9.0).abs() < 1e-12);
    }

    #[test]
    fn memory_accounting() {
        let idx = SparseIndex::from_trits(3, 3, &sample());
        // 5 × i8 values + 5 × u32 col_idx + 4 × u64 row_ptr.
        assert_eq!(idx.memory_bytes(), 5 + 5 * 4 + 4 * 8);
        // Dense 2-bit packing needs only 3 bytes for 9 trits.
        assert!(!idx.is_efficient());
    }

    #[test]
    fn all_zero_matrix() {
        let data = vec![Trit::Tend; 64 * 64];
        let idx = SparseIndex::from_trits(64, 64, &data);
        assert_eq!(idx.nnz, 0);
        assert_eq!(idx.sparsity, 1.0);
        assert_eq!(idx.row_ptr, vec![0u64; 65]);
        // 65 × 8-byte row_ptr = 520 bytes < 1024 bytes of packed dense.
        assert!(idx.is_efficient());
        assert_eq!(idx.to_dense(), data);
    }

    #[test]
    #[should_panic]
    fn length_mismatch_panics() {
        let _ = SparseIndex::from_trits(2, 2, &[Trit::Tend; 3]);
    }
}
124}