// ruvector_sparse_inference/memory.rs
use serde::{Deserialize, Serialize};

use crate::config::CacheConfig;
use crate::error::Result;
10#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct QuantizedWeights {
16 data: Vec<u8>,
18 scales: Vec<f32>,
20 zero_points: Vec<f32>,
22 group_size: usize,
24 shape: (usize, usize),
26 bits: u8,
28}
29
30impl QuantizedWeights {
31 pub fn from_f32(
33 data: &[f32],
34 rows: usize,
35 cols: usize,
36 bits: u8,
37 group_size: usize,
38 ) -> Result<Self> {
39 assert!(bits == 4 || bits == 8, "Only 4-bit and 8-bit quantization supported");
40
41 let num_groups = (data.len() + group_size - 1) / group_size;
42 let mut scales = Vec::with_capacity(num_groups);
43 let mut zero_points = Vec::with_capacity(num_groups);
44
45 for group in data.chunks(group_size) {
47 let min = group.iter().fold(f32::INFINITY, |a, &b| a.min(b));
48 let max = group.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
49
50 let range = max - min;
51 let max_quant = ((1 << bits) - 1) as f32;
52
53 let scale = if range > 0.0 { range / max_quant } else { 1.0 };
54 scales.push(scale);
55 zero_points.push(min);
56 }
57
58 let quantized_data = if bits == 8 {
60 data.chunks(group_size)
61 .zip(scales.iter().zip(zero_points.iter()))
62 .flat_map(|(group, (&scale, &zp))| {
63 group.iter().map(move |&v| {
64 ((v - zp) / scale).round().clamp(0.0, 255.0) as u8
65 })
66 })
67 .collect()
68 } else {
69 let mut packed = Vec::with_capacity((data.len() + 1) / 2);
71 let quantized: Vec<u8> = data.chunks(group_size)
72 .zip(scales.iter().zip(zero_points.iter()))
73 .flat_map(|(group, (&scale, &zp))| {
74 group.iter().map(move |&v| {
75 ((v - zp) / scale).round().clamp(0.0, 15.0) as u8
76 })
77 })
78 .collect();
79
80 for pair in quantized.chunks(2) {
81 let byte = pair[0] | (pair.get(1).unwrap_or(&0) << 4);
82 packed.push(byte);
83 }
84 packed
85 };
86
87 Ok(Self {
88 data: quantized_data,
89 scales,
90 zero_points,
91 group_size,
92 shape: (rows, cols),
93 bits,
94 })
95 }
96
97 pub fn to_f32(&self) -> Vec<f32> {
99 let total = self.shape.0 * self.shape.1;
100 let mut result = Vec::with_capacity(total);
101
102 if self.bits == 8 {
103 for (i, &q) in self.data.iter().take(total).enumerate() {
104 let group_idx = i / self.group_size;
105 let scale = self.scales[group_idx];
106 let zp = self.zero_points[group_idx];
107 result.push(q as f32 * scale + zp);
108 }
109 } else {
110 for (i, &byte) in self.data.iter().enumerate() {
112 let idx = i * 2;
113 if idx < total {
114 let group_idx = idx / self.group_size;
115 let scale = self.scales[group_idx];
116 let zp = self.zero_points[group_idx];
117 result.push((byte & 0x0F) as f32 * scale + zp);
118 }
119 if idx + 1 < total {
120 let group_idx = (idx + 1) / self.group_size;
121 let scale = self.scales[group_idx];
122 let zp = self.zero_points[group_idx];
123 result.push((byte >> 4) as f32 * scale + zp);
124 }
125 }
126 }
127
128 result
129 }
130
131 pub fn shape(&self) -> (usize, usize) {
133 self.shape
134 }
135
136 pub fn memory_size(&self) -> usize {
138 self.data.len() + self.scales.len() * 4 + self.zero_points.len() * 4
139 }
140}
141
/// Tracks per-neuron activation frequency across inference steps and
/// partitions neurons into hot (frequently active) and cold sets.
#[derive(Debug, Clone)]
pub struct NeuronCache {
    // Number of recorded steps in which each neuron fired.
    activation_counts: Vec<u64>,
    // Indices currently classified as hot.
    hot_neurons: Vec<usize>,
    // Indices currently classified as cold.
    cold_neurons: Vec<usize>,
    // Fraction of total steps a neuron must be active to count as hot.
    hot_threshold: f64,
    // Total number of recorded inference steps.
    total_activations: u64,
    // Fixed neuron population size.
    num_neurons: usize,
}
161
162impl NeuronCache {
163 pub fn new(num_neurons: usize, config: CacheConfig) -> Self {
165 Self {
166 activation_counts: vec![0; num_neurons],
167 hot_neurons: Vec::new(),
168 cold_neurons: (0..num_neurons).collect(),
169 hot_threshold: config.hot_neuron_fraction as f64,
170 total_activations: 0,
171 num_neurons,
172 }
173 }
174
175 pub fn with_threshold(num_neurons: usize, hot_threshold: f64) -> Self {
177 Self {
178 activation_counts: vec![0; num_neurons],
179 hot_neurons: Vec::new(),
180 cold_neurons: (0..num_neurons).collect(),
181 hot_threshold,
182 total_activations: 0,
183 num_neurons,
184 }
185 }
186
187 pub fn clear(&mut self) {
189 self.activation_counts.fill(0);
190 self.hot_neurons.clear();
191 self.cold_neurons = (0..self.num_neurons).collect();
192 self.total_activations = 0;
193 }
194
195 pub fn record_activations(&mut self, active_neurons: &[usize]) {
197 for &neuron in active_neurons {
198 if neuron < self.activation_counts.len() {
199 self.activation_counts[neuron] += 1;
200 }
201 }
202 self.total_activations += 1;
203
204 if self.total_activations % 1000 == 0 {
206 self.reclassify();
207 }
208 }
209
210 pub fn reclassify(&mut self) {
212 if self.total_activations == 0 {
213 return;
214 }
215
216 let threshold = (self.total_activations as f64 * self.hot_threshold) as u64;
217
218 self.hot_neurons.clear();
219 self.cold_neurons.clear();
220
221 for (i, &count) in self.activation_counts.iter().enumerate() {
222 if count >= threshold {
223 self.hot_neurons.push(i);
224 } else {
225 self.cold_neurons.push(i);
226 }
227 }
228 }
229
230 pub fn hot_neurons(&self) -> &[usize] {
232 &self.hot_neurons
233 }
234
235 pub fn cold_neurons(&self) -> &[usize] {
237 &self.cold_neurons
238 }
239
240 pub fn activation_frequency(&self, neuron: usize) -> f64 {
242 if self.total_activations == 0 || neuron >= self.activation_counts.len() {
243 return 0.0;
244 }
245 self.activation_counts[neuron] as f64 / self.total_activations as f64
246 }
247
248 pub fn stats(&self) -> CacheStats {
250 CacheStats {
251 num_hot: self.hot_neurons.len(),
252 num_cold: self.cold_neurons.len(),
253 total_activations: self.total_activations,
254 hot_ratio: self.hot_neurons.len() as f64 / self.activation_counts.len() as f64,
255 }
256 }
257}
258
/// Point-in-time summary of a [`NeuronCache`]'s hot/cold partition.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of neurons currently classified hot.
    pub num_hot: usize,
    /// Number of neurons currently classified cold.
    pub num_cold: usize,
    /// Total inference steps recorded so far.
    pub total_activations: u64,
    /// Hot neurons as a fraction of the whole population.
    pub hot_ratio: f64,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a ramp of 256 values through 8-bit quantization and
    /// checks the worst-case reconstruction error stays small.
    #[test]
    fn test_quantized_weights_8bit() {
        let data: Vec<f32> = (0..256).map(|i| i as f32 / 256.0).collect();
        let qw = QuantizedWeights::from_f32(&data, 16, 16, 8, 32).unwrap();

        let restored = qw.to_f32();
        assert_eq!(restored.len(), 256);

        let max_error = data
            .iter()
            .zip(&restored)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_error < 0.01, "Max error: {}", max_error);
    }

    /// Same round-trip at 4 bits; the coarser levels get a looser bound.
    #[test]
    fn test_quantized_weights_4bit() {
        let data: Vec<f32> = (0..256).map(|i| i as f32 / 256.0).collect();
        let qw = QuantizedWeights::from_f32(&data, 16, 16, 4, 32).unwrap();

        let restored = qw.to_f32();
        assert_eq!(restored.len(), 256);

        let max_error = data
            .iter()
            .zip(&restored)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_error < 0.1, "Max error: {}", max_error);
    }

    /// Neurons activated on every step become hot; silent neurons stay cold.
    #[test]
    fn test_neuron_cache() {
        let mut cache = NeuronCache::with_threshold(100, 0.1);

        for _ in 0..1000 {
            cache.record_activations(&[0, 1, 2, 3, 4]);
        }

        cache.reclassify();

        assert!(cache.hot_neurons().contains(&0));
        assert!(cache.hot_neurons().contains(&1));
        assert!(!cache.hot_neurons().contains(&50));
    }
}