ruvllm_esp32/optimizations/
binary_quant.rs

//! Binary Quantization - 32x Memory Compression
//!
//! Adapted from ruvector-postgres/src/quantization/binary.rs
//! Converts f32/i8 vectors to 1 bit per dimension (32x smaller than f32,
//! 8x smaller than i8) and compares them using Hamming distance.
5
6use heapless::Vec as HVec;
7
/// Maximum binary vector size in bytes.
///
/// 64 bytes at 8 dimensions per byte supports vectors of up to 512 dimensions.
pub const MAX_BINARY_SIZE: usize = 64;
10
/// Binary quantized vector - 1 bit per dimension.
///
/// `N` is the capacity of the packed buffer in bytes, so a vector can hold at
/// most `N * 8` dimensions.
#[derive(Debug, Clone)]
pub struct BinaryVector<const N: usize> {
    /// Packed binary data (8 dimensions per byte).
    ///
    /// `from_i8` stores dimension `i` in bit `i % 8` of byte `i / 8`, starting
    /// from the least significant bit.
    pub data: HVec<u8, N>,
    /// Original dimension count (number of valid bits in `data`)
    pub dim: usize,
    /// Threshold used for binarization
    pub threshold: i8,
}
21
22impl<const N: usize> BinaryVector<N> {
23    /// Create binary vector from INT8 values
24    /// Values >= threshold become 1, values < threshold become 0
25    pub fn from_i8(values: &[i8], threshold: i8) -> crate::Result<Self> {
26        let dim = values.len();
27        let num_bytes = (dim + 7) / 8;
28
29        if num_bytes > N {
30            return Err(crate::Error::BufferOverflow);
31        }
32
33        let mut data = HVec::new();
34
35        for chunk_idx in 0..(num_bytes) {
36            let mut byte = 0u8;
37            for bit_idx in 0..8 {
38                let val_idx = chunk_idx * 8 + bit_idx;
39                if val_idx < dim && values[val_idx] >= threshold {
40                    byte |= 1 << bit_idx;
41                }
42            }
43            data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
44        }
45
46        Ok(Self { data, dim, threshold })
47    }
48
49    /// Create binary vector from f32 values (for host-side quantization)
50    #[cfg(feature = "host-test")]
51    pub fn from_f32(values: &[f32], threshold: f32) -> crate::Result<Self> {
52        let i8_threshold = (threshold * 127.0) as i8;
53        let i8_values: heapless::Vec<i8, 512> = values
54            .iter()
55            .map(|&v| (v * 127.0).clamp(-128.0, 127.0) as i8)
56            .collect();
57        Self::from_i8(&i8_values, i8_threshold)
58    }
59
60    /// Get number of packed bytes
61    pub fn num_bytes(&self) -> usize {
62        self.data.len()
63    }
64
65    /// Memory savings compared to INT8
66    pub fn compression_ratio(&self) -> f32 {
67        self.dim as f32 / self.data.len() as f32
68    }
69}
70
/// Binary embedding table for vocabulary.
///
/// Stores 1 bit per dimension, i.e. 8x smaller than an INT8 table and 32x
/// smaller than an f32 table of the same shape.
///
/// NOTE(review): the `VOCAB` and `DIM_BYTES` const parameters are not used by
/// any field; the backing buffer is a fixed 32 KiB regardless of their values.
pub struct BinaryEmbedding<const VOCAB: usize, const DIM_BYTES: usize> {
    /// Packed binary embeddings: `vocab_size * bytes_per_embed` bytes, one
    /// embedding after another in token-id order
    data: HVec<u8, { 32 * 1024 }>, // Max 32KB
    /// Vocabulary size
    vocab_size: usize,
    /// Dimensions (in bits)
    dim: usize,
    /// Bytes per embedding (`dim` rounded up to whole bytes)
    bytes_per_embed: usize,
}
82
83impl<const VOCAB: usize, const DIM_BYTES: usize> BinaryEmbedding<VOCAB, DIM_BYTES> {
84    /// Create random binary embeddings for testing
85    pub fn random(vocab_size: usize, dim: usize, seed: u32) -> crate::Result<Self> {
86        let bytes_per_embed = (dim + 7) / 8;
87        let total_bytes = vocab_size * bytes_per_embed;
88
89        let mut data = HVec::new();
90        let mut rng_state = seed;
91
92        for _ in 0..total_bytes {
93            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
94            let byte = ((rng_state >> 16) & 0xFF) as u8;
95            data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
96        }
97
98        Ok(Self {
99            data,
100            vocab_size,
101            dim,
102            bytes_per_embed,
103        })
104    }
105
106    /// Look up binary embedding for a token
107    pub fn lookup(&self, token_id: u16, output: &mut [u8]) -> crate::Result<()> {
108        let id = token_id as usize;
109        if id >= self.vocab_size {
110            return Err(crate::Error::InvalidModel("Token ID out of range"));
111        }
112
113        let start = id * self.bytes_per_embed;
114        let end = start + self.bytes_per_embed;
115
116        if output.len() < self.bytes_per_embed {
117            return Err(crate::Error::BufferOverflow);
118        }
119
120        output[..self.bytes_per_embed].copy_from_slice(&self.data[start..end]);
121        Ok(())
122    }
123
124    /// Memory size in bytes
125    pub fn memory_size(&self) -> usize {
126        self.data.len()
127    }
128
129    /// Compression vs INT8 embedding of same dimensions
130    pub fn compression_vs_int8(&self) -> f32 {
131        8.0 // 8 bits per dimension -> 1 bit per dimension = 8x
132    }
133}
134
135/// Hamming distance between two binary vectors
136///
137/// Counts the number of differing bits. Uses POPCNT-like operations.
138/// On ESP32, this is extremely fast as it uses simple bitwise operations.
139#[inline]
140pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
141    debug_assert_eq!(a.len(), b.len());
142
143    let mut distance: u32 = 0;
144
145    // Process 4 bytes at a time for better performance
146    let chunks = a.len() / 4;
147    for i in 0..chunks {
148        let idx = i * 4;
149        let xor0 = a[idx] ^ b[idx];
150        let xor1 = a[idx + 1] ^ b[idx + 1];
151        let xor2 = a[idx + 2] ^ b[idx + 2];
152        let xor3 = a[idx + 3] ^ b[idx + 3];
153
154        distance += popcount8(xor0) + popcount8(xor1) + popcount8(xor2) + popcount8(xor3);
155    }
156
157    // Handle remainder
158    for i in (chunks * 4)..a.len() {
159        distance += popcount8(a[i] ^ b[i]);
160    }
161
162    distance
163}
164
165/// Hamming similarity (inverted distance, normalized to 0-1 range)
166#[inline]
167pub fn hamming_similarity(a: &[u8], b: &[u8]) -> f32 {
168    let total_bits = (a.len() * 8) as f32;
169    let distance = hamming_distance(a, b) as f32;
170    1.0 - (distance / total_bits)
171}
172
173/// Hamming similarity as fixed-point (0-255 range)
174#[inline]
175pub fn hamming_similarity_fixed(a: &[u8], b: &[u8]) -> u8 {
176    let total_bits = (a.len() * 8) as u32;
177    let matching_bits = total_bits - hamming_distance(a, b);
178    ((matching_bits * 255) / total_bits) as u8
179}
180
/// Population count for a single byte (count of 1 bits).
///
/// Implemented as a lookup in a 256-entry table that is generated at compile
/// time, so each call is a single indexed load.
#[inline]
pub fn popcount8(x: u8) -> u32 {
    /// Builds the table with the recurrence `bits(i) = bits(i >> 1) + (i & 1)`.
    const fn build_table() -> [u8; 256] {
        let mut table = [0u8; 256];
        let mut value = 1;
        while value < 256 {
            table[value] = table[value >> 1] + (value & 1) as u8;
            value += 1;
        }
        table
    }

    // Lookup table for byte population count
    const POPCOUNT_TABLE: [u8; 256] = build_table();

    u32::from(POPCOUNT_TABLE[usize::from(x)])
}
206
207/// XNOR-popcount for binary neural network inference
208/// Equivalent to computing dot product of {-1, +1} vectors
209#[inline]
210pub fn xnor_popcount(a: &[u8], b: &[u8]) -> i32 {
211    debug_assert_eq!(a.len(), b.len());
212
213    let total_bits = (a.len() * 8) as i32;
214    let mut matching: i32 = 0;
215
216    for (&x, &y) in a.iter().zip(b.iter()) {
217        // XNOR: same bits = 1, different bits = 0
218        let xnor = !(x ^ y);
219        matching += popcount8(xnor) as i32;
220    }
221
222    // Convert to {-1, +1} dot product equivalent
223    // matching bits contribute +1, non-matching contribute -1
224    // result = 2 * matching - total_bits
225    2 * matching - total_bits
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_quantization() {
        let values = [10i8, -5, 20, -10, 0, 15, -8, 30];
        let binary = BinaryVector::<8>::from_i8(&values, 0).unwrap();

        assert_eq!(binary.dim, 8);
        assert_eq!(binary.num_bytes(), 1);

        // Expected: bits where value >= 0: positions 0, 2, 4, 5, 7
        // Binary: 10110101 = 0xB5
        assert_eq!(binary.data[0], 0b10110101);
    }

    #[test]
    fn test_partial_byte_quantization() {
        // 10 dimensions need 2 bytes; only the low 2 bits of the second are used.
        let values = [1i8, 1, 1, 1, 1, 1, 1, 1, -1, 1];
        let binary = BinaryVector::<2>::from_i8(&values, 0).unwrap();

        assert_eq!(binary.dim, 10);
        assert_eq!(binary.num_bytes(), 2);
        assert_eq!(binary.data[0], 0xFF);
        assert_eq!(binary.data[1], 0b0000_0010);
    }

    #[test]
    fn test_quantization_rejects_oversized_input() {
        // 9 dimensions need 2 bytes, which does not fit a 1-byte vector.
        let values = [0i8; 9];
        assert!(BinaryVector::<1>::from_i8(&values, 0).is_err());
    }

    #[test]
    fn test_popcount8_matches_count_ones() {
        for value in 0u8..=255 {
            assert_eq!(popcount8(value), value.count_ones());
        }
    }

    #[test]
    fn test_hamming_distance() {
        let a = [0b11110000u8, 0b10101010];
        let b = [0b11110000u8, 0b10101010];
        assert_eq!(hamming_distance(&a, &b), 0);

        let c = [0b00001111u8, 0b01010101];
        assert_eq!(hamming_distance(&a, &c), 16); // All bits different
    }

    #[test]
    fn test_hamming_distance_non_multiple_of_four() {
        // 5 bytes: one full group of four plus a one-byte remainder.
        let a = [0xFFu8, 0x00, 0xFF, 0x00, 0x0F];
        let b = [0u8; 5];
        assert_eq!(hamming_distance(&a, &b), 20);
    }

    #[test]
    fn test_hamming_similarity() {
        let a = [0xFFu8, 0x00];
        let b = [0xFFu8, 0xFF];
        // 8 of 16 bits match.
        assert_eq!(hamming_similarity(&a, &b), 0.5);
        assert_eq!(hamming_similarity_fixed(&a, &b), 127);

        assert_eq!(hamming_similarity(&a, &a), 1.0);
        assert_eq!(hamming_similarity_fixed(&a, &a), 255);
    }

    #[test]
    fn test_xnor_popcount() {
        let a = [0b11111111u8];
        let b = [0b11111111u8];
        // Perfect match: 8 matching bits -> 2*8 - 8 = 8
        assert_eq!(xnor_popcount(&a, &b), 8);

        let c = [0b00000000u8];
        // Complete mismatch: 0 matching bits -> 2*0 - 8 = -8
        assert_eq!(xnor_popcount(&a, &c), -8);
    }

    #[test]
    fn test_compression_ratio() {
        let values = [0i8; 64];
        let binary = BinaryVector::<8>::from_i8(&values, 0).unwrap();
        assert_eq!(binary.compression_ratio(), 8.0);
    }

    #[test]
    fn test_embedding_lookup() {
        let table = BinaryEmbedding::<4, 2>::random(4, 16, 42).unwrap();
        assert_eq!(table.memory_size(), 8);
        assert_eq!(table.compression_vs_int8(), 8.0);

        // The generator is deterministic: the same seed gives the same table.
        let again = BinaryEmbedding::<4, 2>::random(4, 16, 42).unwrap();
        let mut first = [0u8; 2];
        let mut repeat = [0u8; 2];
        table.lookup(3, &mut first).unwrap();
        again.lookup(3, &mut repeat).unwrap();
        assert_eq!(first, repeat);

        // Token id equal to the vocabulary size is out of range.
        assert!(table.lookup(4, &mut first).is_err());

        // Output buffer shorter than one embedding.
        let mut short = [0u8; 1];
        assert!(table.lookup(0, &mut short).is_err());
    }
}