ruvector_sparse_inference/
memory.rs

1//! Memory management for sparse inference.
2//!
3//! This module provides weight quantization and neuron caching for efficient
4//! memory usage during inference.
5
6use serde::{Deserialize, Serialize};
7use crate::config::CacheConfig;
8use crate::error::Result;
9
/// Quantized weight storage for reduced memory usage.
///
/// Stores neural network weights in a compressed affine-quantized format to
/// reduce memory footprint while maintaining accuracy. Values are quantized
/// per group of `group_size` consecutive elements, each group carrying its
/// own scale and zero point (the group minimum).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedWeights {
    /// Quantized weight data. For 8-bit, one value per byte; for 4-bit,
    /// two values are packed per byte, low nibble first.
    data: Vec<u8>,
    /// Scale factors, one per quantization group.
    scales: Vec<f32>,
    /// Zero points (per-group minimum values), one per quantization group.
    zero_points: Vec<f32>,
    /// Number of consecutive values sharing one scale/zero-point pair.
    group_size: usize,
    /// Original (rows, cols) dimensions of the weight matrix.
    shape: (usize, usize),
    /// Quantization bit width (4 or 8).
    bits: u8,
}
29
30impl QuantizedWeights {
31    /// Create new quantized weights from f32 data.
32    pub fn from_f32(
33        data: &[f32],
34        rows: usize,
35        cols: usize,
36        bits: u8,
37        group_size: usize,
38    ) -> Result<Self> {
39        assert!(bits == 4 || bits == 8, "Only 4-bit and 8-bit quantization supported");
40
41        let num_groups = (data.len() + group_size - 1) / group_size;
42        let mut scales = Vec::with_capacity(num_groups);
43        let mut zero_points = Vec::with_capacity(num_groups);
44
45        // Calculate per-group scales and zero points
46        for group in data.chunks(group_size) {
47            let min = group.iter().fold(f32::INFINITY, |a, &b| a.min(b));
48            let max = group.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
49
50            let range = max - min;
51            let max_quant = ((1 << bits) - 1) as f32;
52
53            let scale = if range > 0.0 { range / max_quant } else { 1.0 };
54            scales.push(scale);
55            zero_points.push(min);
56        }
57
58        // Quantize the data
59        let quantized_data = if bits == 8 {
60            data.chunks(group_size)
61                .zip(scales.iter().zip(zero_points.iter()))
62                .flat_map(|(group, (&scale, &zp))| {
63                    group.iter().map(move |&v| {
64                        ((v - zp) / scale).round().clamp(0.0, 255.0) as u8
65                    })
66                })
67                .collect()
68        } else {
69            // 4-bit: pack two values per byte
70            let mut packed = Vec::with_capacity((data.len() + 1) / 2);
71            let quantized: Vec<u8> = data.chunks(group_size)
72                .zip(scales.iter().zip(zero_points.iter()))
73                .flat_map(|(group, (&scale, &zp))| {
74                    group.iter().map(move |&v| {
75                        ((v - zp) / scale).round().clamp(0.0, 15.0) as u8
76                    })
77                })
78                .collect();
79
80            for pair in quantized.chunks(2) {
81                let byte = pair[0] | (pair.get(1).unwrap_or(&0) << 4);
82                packed.push(byte);
83            }
84            packed
85        };
86
87        Ok(Self {
88            data: quantized_data,
89            scales,
90            zero_points,
91            group_size,
92            shape: (rows, cols),
93            bits,
94        })
95    }
96
97    /// Dequantize to f32.
98    pub fn to_f32(&self) -> Vec<f32> {
99        let total = self.shape.0 * self.shape.1;
100        let mut result = Vec::with_capacity(total);
101
102        if self.bits == 8 {
103            for (i, &q) in self.data.iter().take(total).enumerate() {
104                let group_idx = i / self.group_size;
105                let scale = self.scales[group_idx];
106                let zp = self.zero_points[group_idx];
107                result.push(q as f32 * scale + zp);
108            }
109        } else {
110            // 4-bit unpacking
111            for (i, &byte) in self.data.iter().enumerate() {
112                let idx = i * 2;
113                if idx < total {
114                    let group_idx = idx / self.group_size;
115                    let scale = self.scales[group_idx];
116                    let zp = self.zero_points[group_idx];
117                    result.push((byte & 0x0F) as f32 * scale + zp);
118                }
119                if idx + 1 < total {
120                    let group_idx = (idx + 1) / self.group_size;
121                    let scale = self.scales[group_idx];
122                    let zp = self.zero_points[group_idx];
123                    result.push((byte >> 4) as f32 * scale + zp);
124                }
125            }
126        }
127
128        result
129    }
130
131    /// Get shape.
132    pub fn shape(&self) -> (usize, usize) {
133        self.shape
134    }
135
136    /// Get memory size in bytes.
137    pub fn memory_size(&self) -> usize {
138        self.data.len() + self.scales.len() * 4 + self.zero_points.len() * 4
139    }
140}
141
/// Neuron activation cache for hot/cold management.
///
/// Tracks neuron activation frequencies and maintains a partition of neuron
/// indices into "hot" (frequently activated) and "cold" (rarely activated)
/// sets, refreshed by [`NeuronCache::reclassify`].
#[derive(Debug, Clone)]
pub struct NeuronCache {
    /// Per-neuron activation counts, indexed by neuron id.
    activation_counts: Vec<u64>,
    /// Indices of neurons currently classified as hot.
    hot_neurons: Vec<usize>,
    /// Indices of neurons currently classified as cold.
    cold_neurons: Vec<usize>,
    /// Fraction of total events a neuron must reach to be classified hot.
    hot_threshold: f64,
    /// Total activation events recorded (one per `record_activations` call).
    total_activations: u64,
    /// Number of neurons tracked by this cache.
    num_neurons: usize,
}
161
162impl NeuronCache {
163    /// Create a new neuron cache from config.
164    pub fn new(num_neurons: usize, config: CacheConfig) -> Self {
165        Self {
166            activation_counts: vec![0; num_neurons],
167            hot_neurons: Vec::new(),
168            cold_neurons: (0..num_neurons).collect(),
169            hot_threshold: config.hot_neuron_fraction as f64,
170            total_activations: 0,
171            num_neurons,
172        }
173    }
174
175    /// Create a new neuron cache with explicit threshold.
176    pub fn with_threshold(num_neurons: usize, hot_threshold: f64) -> Self {
177        Self {
178            activation_counts: vec![0; num_neurons],
179            hot_neurons: Vec::new(),
180            cold_neurons: (0..num_neurons).collect(),
181            hot_threshold,
182            total_activations: 0,
183            num_neurons,
184        }
185    }
186
187    /// Clear all cache state and reset counters.
188    pub fn clear(&mut self) {
189        self.activation_counts.fill(0);
190        self.hot_neurons.clear();
191        self.cold_neurons = (0..self.num_neurons).collect();
192        self.total_activations = 0;
193    }
194
195    /// Record neuron activations.
196    pub fn record_activations(&mut self, active_neurons: &[usize]) {
197        for &neuron in active_neurons {
198            if neuron < self.activation_counts.len() {
199                self.activation_counts[neuron] += 1;
200            }
201        }
202        self.total_activations += 1;
203
204        // Periodically reclassify
205        if self.total_activations % 1000 == 0 {
206            self.reclassify();
207        }
208    }
209
210    /// Reclassify neurons as hot or cold.
211    pub fn reclassify(&mut self) {
212        if self.total_activations == 0 {
213            return;
214        }
215
216        let threshold = (self.total_activations as f64 * self.hot_threshold) as u64;
217
218        self.hot_neurons.clear();
219        self.cold_neurons.clear();
220
221        for (i, &count) in self.activation_counts.iter().enumerate() {
222            if count >= threshold {
223                self.hot_neurons.push(i);
224            } else {
225                self.cold_neurons.push(i);
226            }
227        }
228    }
229
230    /// Get hot neurons.
231    pub fn hot_neurons(&self) -> &[usize] {
232        &self.hot_neurons
233    }
234
235    /// Get cold neurons.
236    pub fn cold_neurons(&self) -> &[usize] {
237        &self.cold_neurons
238    }
239
240    /// Get activation frequency for a neuron.
241    pub fn activation_frequency(&self, neuron: usize) -> f64 {
242        if self.total_activations == 0 || neuron >= self.activation_counts.len() {
243            return 0.0;
244        }
245        self.activation_counts[neuron] as f64 / self.total_activations as f64
246    }
247
248    /// Get cache statistics.
249    pub fn stats(&self) -> CacheStats {
250        CacheStats {
251            num_hot: self.hot_neurons.len(),
252            num_cold: self.cold_neurons.len(),
253            total_activations: self.total_activations,
254            hot_ratio: self.hot_neurons.len() as f64 / self.activation_counts.len() as f64,
255        }
256    }
257}
258
/// Cache statistics snapshot produced by `NeuronCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of neurons currently classified as hot.
    pub num_hot: usize,
    /// Number of neurons currently classified as cold.
    pub num_cold: usize,
    /// Total activation events recorded so far.
    pub total_activations: u64,
    /// Hot neurons as a fraction of all tracked neurons.
    pub hot_ratio: f64,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two value sequences.
    fn max_abs_error(original: &[f32], restored: &[f32]) -> f32 {
        original
            .iter()
            .zip(restored)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max)
    }

    #[test]
    fn test_quantized_weights_8bit() {
        let original: Vec<f32> = (0..256).map(|i| i as f32 / 256.0).collect();
        let quantized = QuantizedWeights::from_f32(&original, 16, 16, 8, 32).unwrap();

        let roundtrip = quantized.to_f32();
        assert_eq!(roundtrip.len(), 256);

        // 8-bit reconstruction should be nearly exact.
        let worst = max_abs_error(&original, &roundtrip);
        assert!(worst < 0.01, "Max error: {}", worst);
    }

    #[test]
    fn test_quantized_weights_4bit() {
        let original: Vec<f32> = (0..256).map(|i| i as f32 / 256.0).collect();
        let quantized = QuantizedWeights::from_f32(&original, 16, 16, 4, 32).unwrap();

        let roundtrip = quantized.to_f32();
        assert_eq!(roundtrip.len(), 256);

        // 4-bit quantization tolerates a larger reconstruction error.
        let worst = max_abs_error(&original, &roundtrip);
        assert!(worst < 0.1, "Max error: {}", worst);
    }

    #[test]
    fn test_neuron_cache() {
        let mut cache = NeuronCache::with_threshold(100, 0.1);

        // Drive the same five neurons in every event.
        for _ in 0..1000 {
            cache.record_activations(&[0, 1, 2, 3, 4]);
        }
        cache.reclassify();

        let hot = cache.hot_neurons();
        assert!(hot.contains(&0));
        assert!(hot.contains(&1));
        assert!(!hot.contains(&50));
    }
}