1use heapless::Vec as HVec;
7
/// Size limit for binary (1-bit packed) data.
///
/// NOTE(review): not referenced by any code in this file. Presumably a byte count
/// (64 bytes = 512 packed bits) meant to be used as the `N` of `BinaryVector`;
/// confirm the unit at the call sites before relying on it.
pub const MAX_BINARY_SIZE: usize = 64;
10
/// A vector binarized to one bit per element, packed into at most `N` bytes.
#[derive(Debug, Clone)]
pub struct BinaryVector<const N: usize> {
    /// Packed bits, LSB first: element `i` lives in bit `i % 8` of byte `i / 8`.
    /// A set bit means the source element was `>= threshold`.
    pub data: HVec<u8, N>,
    /// Number of source elements (bits) encoded; the last byte may be partially used.
    pub dim: usize,
    /// Threshold the source elements were compared against (`>=` sets the bit).
    pub threshold: i8,
}
21
22impl<const N: usize> BinaryVector<N> {
23 pub fn from_i8(values: &[i8], threshold: i8) -> crate::Result<Self> {
26 let dim = values.len();
27 let num_bytes = (dim + 7) / 8;
28
29 if num_bytes > N {
30 return Err(crate::Error::BufferOverflow);
31 }
32
33 let mut data = HVec::new();
34
35 for chunk_idx in 0..(num_bytes) {
36 let mut byte = 0u8;
37 for bit_idx in 0..8 {
38 let val_idx = chunk_idx * 8 + bit_idx;
39 if val_idx < dim && values[val_idx] >= threshold {
40 byte |= 1 << bit_idx;
41 }
42 }
43 data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
44 }
45
46 Ok(Self { data, dim, threshold })
47 }
48
49 #[cfg(feature = "host-test")]
51 pub fn from_f32(values: &[f32], threshold: f32) -> crate::Result<Self> {
52 let i8_threshold = (threshold * 127.0) as i8;
53 let i8_values: heapless::Vec<i8, 512> = values
54 .iter()
55 .map(|&v| (v * 127.0).clamp(-128.0, 127.0) as i8)
56 .collect();
57 Self::from_i8(&i8_values, i8_threshold)
58 }
59
60 pub fn num_bytes(&self) -> usize {
62 self.data.len()
63 }
64
65 pub fn compression_ratio(&self) -> f32 {
67 self.dim as f32 / self.data.len() as f32
68 }
69}
70
71pub struct BinaryEmbedding<const VOCAB: usize, const DIM_BYTES: usize> {
73 data: HVec<u8, { 32 * 1024 }>, vocab_size: usize,
77 dim: usize,
79 bytes_per_embed: usize,
81}
82
83impl<const VOCAB: usize, const DIM_BYTES: usize> BinaryEmbedding<VOCAB, DIM_BYTES> {
84 pub fn random(vocab_size: usize, dim: usize, seed: u32) -> crate::Result<Self> {
86 let bytes_per_embed = (dim + 7) / 8;
87 let total_bytes = vocab_size * bytes_per_embed;
88
89 let mut data = HVec::new();
90 let mut rng_state = seed;
91
92 for _ in 0..total_bytes {
93 rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
94 let byte = ((rng_state >> 16) & 0xFF) as u8;
95 data.push(byte).map_err(|_| crate::Error::BufferOverflow)?;
96 }
97
98 Ok(Self {
99 data,
100 vocab_size,
101 dim,
102 bytes_per_embed,
103 })
104 }
105
106 pub fn lookup(&self, token_id: u16, output: &mut [u8]) -> crate::Result<()> {
108 let id = token_id as usize;
109 if id >= self.vocab_size {
110 return Err(crate::Error::InvalidModel("Token ID out of range"));
111 }
112
113 let start = id * self.bytes_per_embed;
114 let end = start + self.bytes_per_embed;
115
116 if output.len() < self.bytes_per_embed {
117 return Err(crate::Error::BufferOverflow);
118 }
119
120 output[..self.bytes_per_embed].copy_from_slice(&self.data[start..end]);
121 Ok(())
122 }
123
124 pub fn memory_size(&self) -> usize {
126 self.data.len()
127 }
128
129 pub fn compression_vs_int8(&self) -> f32 {
131 8.0 }
133}
134
135#[inline]
140pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
141 debug_assert_eq!(a.len(), b.len());
142
143 let mut distance: u32 = 0;
144
145 let chunks = a.len() / 4;
147 for i in 0..chunks {
148 let idx = i * 4;
149 let xor0 = a[idx] ^ b[idx];
150 let xor1 = a[idx + 1] ^ b[idx + 1];
151 let xor2 = a[idx + 2] ^ b[idx + 2];
152 let xor3 = a[idx + 3] ^ b[idx + 3];
153
154 distance += popcount8(xor0) + popcount8(xor1) + popcount8(xor2) + popcount8(xor3);
155 }
156
157 for i in (chunks * 4)..a.len() {
159 distance += popcount8(a[i] ^ b[i]);
160 }
161
162 distance
163}
164
165#[inline]
167pub fn hamming_similarity(a: &[u8], b: &[u8]) -> f32 {
168 let total_bits = (a.len() * 8) as f32;
169 let distance = hamming_distance(a, b) as f32;
170 1.0 - (distance / total_bits)
171}
172
173#[inline]
175pub fn hamming_similarity_fixed(a: &[u8], b: &[u8]) -> u8 {
176 let total_bits = (a.len() * 8) as u32;
177 let matching_bits = total_bits - hamming_distance(a, b);
178 ((matching_bits * 255) / total_bits) as u8
179}
180
/// Number of set bits in `x` (0..=8), via a 256-entry lookup table.
#[inline]
pub fn popcount8(x: u8) -> u32 {
    // Built at compile time. popcount(i) = popcount(i >> 1) + (lowest bit of i), so each
    // entry depends only on an entry already filled in.
    const POPCOUNT_TABLE: [u8; 256] = {
        let mut table = [0u8; 256];
        let mut i = 1;
        while i < 256 {
            table[i] = table[i >> 1] + (i & 1) as u8;
            i += 1;
        }
        table
    };

    u32::from(POPCOUNT_TABLE[usize::from(x)])
}
206
207#[inline]
210pub fn xnor_popcount(a: &[u8], b: &[u8]) -> i32 {
211 debug_assert_eq!(a.len(), b.len());
212
213 let total_bits = (a.len() * 8) as i32;
214 let mut matching: i32 = 0;
215
216 for (&x, &y) in a.iter().zip(b.iter()) {
217 let xnor = !(x ^ y);
219 matching += popcount8(xnor) as i32;
220 }
221
222 2 * matching - total_bits
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_quantization() {
        // With threshold 0, bit i (LSB = element 0) is set exactly when input[i] >= 0.
        let input: [i8; 8] = [10, -5, 20, -10, 0, 15, -8, 30];
        let packed = BinaryVector::<8>::from_i8(&input, 0).unwrap();

        assert_eq!(packed.dim, 8);
        assert_eq!(packed.num_bytes(), 1);
        assert_eq!(packed.data[0], 0b1011_0101);
    }

    #[test]
    fn test_hamming_distance() {
        let base = [0b1111_0000u8, 0b1010_1010];
        let copy = [0b1111_0000u8, 0b1010_1010];
        let complement = [0b0000_1111u8, 0b0101_0101];

        // Identical inputs differ nowhere; bitwise complements differ in all 16 bits.
        assert_eq!(hamming_distance(&base, &copy), 0);
        assert_eq!(hamming_distance(&base, &complement), 16);
    }

    #[test]
    fn test_xnor_popcount() {
        let ones = [0b1111_1111u8];

        // +1 per agreeing bit, -1 per disagreeing bit.
        assert_eq!(xnor_popcount(&ones, &[0b1111_1111u8]), 8);
        assert_eq!(xnor_popcount(&ones, &[0b0000_0000u8]), -8);
    }

    #[test]
    fn test_compression_ratio() {
        // 64 single-byte inputs pack into 8 bytes.
        let packed = BinaryVector::<8>::from_i8(&[0i8; 64], 0).unwrap();
        assert_eq!(packed.compression_ratio(), 8.0);
    }
}