ternlang_compress/sparse.rs
1// Sparse zero-index (CSR — Compressed Sparse Row) for ternary weight matrices.
2//
3// A ternary matrix has three states: {-1, 0, +1}. The majority of weights
4// in a quantized LLM are 0 (typically 60-99% depending on sparsity).
5//
6// Storing the FULL matrix wastes memory on zeros. Instead we store:
7// - `values`: only the non-zero trit values (+1 or -1), packed as i8
8// - `col_idx`: column index for each non-zero value (u32)
9// - `row_ptr`: pointer into `values` for the start of each row (u64)
10//
11// At inference the sparse_matmul kernel iterates `row_ptr[row]..row_ptr[row+1]`
12// and accumulates only the non-zero contributions — zeros are never touched.
13//
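// Memory breakdown for a 4096×4096 weight matrix (≈16.8M trits):
//   Dense ternary (2-bit packed): 4096×4096×2b = 4 MB
//   CSR at 60% sparsity (i8 values + u32 col_idx):
//     6.7M non-zeros × 5 bytes + 32 KB row_ptr ≈ 32 MB — dense is ~8× smaller.
//   CSR at 80% sparsity: ≈3.4M nnz × 5 bytes ≈ 16 MB — still ~4× larger than dense.
//   On memory alone, CSR (5 bytes per non-zero) only beats 2-bit packing
//   (0.25 bytes per element) when fewer than ~5% of the weights are non-zero,
//   i.e. at sparsity above ~95%. That comparison is what `is_efficient` reports.
//
// Practical recommendation: use CSR for sparsity > 75%, packed ternary otherwise.
// NOTE(review): by the figures above the 75% cutoff is not a memory win — confirm
// it is chosen for skipping zero work in the kernel rather than for footprint.
// The pipeline chooses automatically based on measured sparsity.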
22
23use ternlang_core::trit::Trit;
24use serde::{Deserialize, Serialize};
25
26/// Compressed Sparse Row index for one weight matrix.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SparseIndex {
29 pub rows: usize,
30 pub cols: usize,
31 /// Non-zero trit values, stored as i8 (+1 or -1).
32 pub values: Vec<i8>,
33 /// Column of each non-zero value. Same length as `values`.
34 pub col_idx: Vec<u32>,
35 /// `row_ptr[r]` = index into `values` where row `r` starts.
36 /// Length = rows + 1. `row_ptr[rows]` = values.len().
37 pub row_ptr: Vec<u64>,
38 /// Number of non-zero elements.
39 pub nnz: usize,
40 /// Sparsity = 1 - nnz / (rows * cols).
41 pub sparsity: f64,
42}
43
44impl SparseIndex {
45 /// Build a CSR index from a flat row-major slice of trits.
46 pub fn from_trits(rows: usize, cols: usize, data: &[Trit]) -> Self {
47 assert_eq!(data.len(), rows * cols);
48
49 let mut values = Vec::new();
50 let mut col_idx = Vec::new();
51 let mut row_ptr = vec![0u64; rows + 1];
52
53 for r in 0..rows {
54 row_ptr[r] = values.len() as u64;
55 for c in 0..cols {
56 let t = data[r * cols + c];
57 if t != Trit::Tend {
58 values.push(match t {
59 Trit::Affirm => 1i8,
60 Trit::Reject => -1i8,
61 Trit::Tend => 0i8,
62 });
63 col_idx.push(c as u32);
64 }
65 }
66 }
67 row_ptr[rows] = values.len() as u64;
68
69 let nnz = values.len();
70 let total = rows * cols;
71 let sparsity = 1.0 - nnz as f64 / total as f64;
72
73 Self { rows, cols, values, col_idx, row_ptr, nnz, sparsity }
74 }
75
76 /// Reconstruct the dense trit slice (for testing / correctness checks).
77 pub fn to_dense(&self) -> Vec<Trit> {
78 let mut out = vec![Trit::Tend; self.rows * self.cols];
79 for r in 0..self.rows {
80 let start = self.row_ptr[r] as usize;
81 let end = self.row_ptr[r + 1] as usize;
82 for idx in start..end {
83 let c = self.col_idx[idx] as usize;
84 out[r * self.cols + c] = if self.values[idx] > 0 {
85 Trit::Affirm
86 } else {
87 Trit::Reject
88 };
89 }
90 }
91 out
92 }
93
94 /// Memory footprint in bytes.
95 pub fn memory_bytes(&self) -> usize {
96 self.values.len() // i8 = 1 byte each
97 + self.col_idx.len() * 4 // u32 = 4 bytes each
98 + self.row_ptr.len() * 8 // u64 = 8 bytes each
99 }
100
101 /// Whether CSR is more memory-efficient than a 2-bit packed dense matrix.
102 pub fn is_efficient(&self) -> bool {
103 let dense_packed = (self.rows * self.cols + 3) / 4; // 2 bits per trit
104 self.memory_bytes() < dense_packed
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense -> CSR -> dense must reproduce the input exactly; the 3×3
    /// fixture contains five non-zero trits.
    #[test]
    fn roundtrip_sparse() {
        const SIDE: usize = 3;
        const EXPECTED_NNZ: usize = 5;

        let matrix = vec![
            Trit::Affirm, Trit::Tend, Trit::Reject, // row 0
            Trit::Tend, Trit::Tend, Trit::Affirm,   // row 1
            Trit::Reject, Trit::Affirm, Trit::Tend, // row 2
        ];

        let index = SparseIndex::from_trits(SIDE, SIDE, &matrix);
        assert_eq!(index.nnz, EXPECTED_NNZ);

        let rebuilt = index.to_dense();
        assert_eq!(rebuilt, matrix);
    }
}
124}