Skip to main content

ruvector_temporal_tensor/
compressor.rs

//! TemporalTensorCompressor: the main entry point.
//!
//! Manages temporal segments, drift detection, and tier transitions.
//! Caches f32-converted scales to avoid repeated f16 conversion in hot paths.

6use crate::quantizer;
7use crate::segment;
8use crate::tier_policy::TierPolicy;
9
10pub struct TemporalTensorCompressor {
11    policy: TierPolicy,
12    len: u32,
13
14    access_count: u32,
15    last_access_ts: u32,
16
17    active_bits: u8,
18    active_group_len: usize,
19    active_scales_f16: Vec<u16>,
20    active_scales_f32: Vec<f32>, // Cached f32 conversion of scales
21    active_frames: u32,
22    active_data: Vec<u8>,
23}
24
25impl TemporalTensorCompressor {
26    /// Create a new compressor for tensors of the given length.
27    pub fn new(policy: TierPolicy, len: u32, now_ts: u32) -> Self {
28        let bits = policy.select_bits(0, now_ts, now_ts);
29        Self {
30            policy,
31            len,
32            access_count: 0,
33            last_access_ts: now_ts,
34            active_bits: bits,
35            active_group_len: policy.group_len.max(1) as usize,
36            active_scales_f16: Vec::new(),
37            active_scales_f32: Vec::new(),
38            active_frames: 0,
39            active_data: Vec::new(),
40        }
41    }
42
43    /// Record an access (increments count, updates timestamp).
44    pub fn touch(&mut self, now_ts: u32) {
45        self.access_count = self.access_count.wrapping_add(1);
46        self.last_access_ts = now_ts;
47    }
48
49    /// Set access stats directly (for restoring state).
50    pub fn set_access(&mut self, access_count: u32, last_access_ts: u32) {
51        self.access_count = access_count;
52        self.last_access_ts = last_access_ts;
53    }
54
55    /// Current tier bits.
56    pub fn active_bits(&self) -> u8 {
57        self.active_bits
58    }
59
60    /// Number of frames in the current segment.
61    pub fn active_frame_count(&self) -> u32 {
62        self.active_frames
63    }
64
65    /// Current policy.
66    pub fn policy(&self) -> &TierPolicy {
67        &self.policy
68    }
69
70    /// Tensor length.
71    pub fn len(&self) -> u32 {
72        self.len
73    }
74
75    /// Returns `true` if the tensor length is zero.
76    pub fn is_empty(&self) -> bool {
77        self.len == 0
78    }
79
80    /// Bytes currently buffered in the active segment data.
81    pub fn active_data_bytes(&self) -> usize {
82        self.active_data.len()
83    }
84
85    /// Push a frame. If a segment boundary is crossed, the completed segment
86    /// bytes are written to `out_segment`. Otherwise `out_segment` is cleared.
87    pub fn push_frame(&mut self, frame: &[f32], now_ts: u32, out_segment: &mut Vec<u8>) {
88        out_segment.clear();
89
90        if frame.len() != self.len as usize {
91            return;
92        }
93
94        let desired_bits = self.policy.select_bits(
95            self.access_count,
96            self.last_access_ts,
97            now_ts,
98        );
99        let drift_factor = self.policy.drift_factor();
100
101        // Use cached f32 scales for drift check (avoids f16 conversion per group)
102        let need_new_segment = self.active_frames == 0
103            || desired_bits != self.active_bits
104            || !quantizer::frame_fits_scales_f32(
105                frame,
106                &self.active_scales_f32,
107                self.active_group_len,
108                self.active_bits,
109                drift_factor,
110            );
111
112        if need_new_segment {
113            self.flush(out_segment);
114            self.active_bits = desired_bits;
115            self.active_group_len = self.policy.group_len.max(1) as usize;
116            self.active_scales_f16 = quantizer::compute_scales(
117                frame,
118                self.active_group_len,
119                self.active_bits,
120            );
121            self.active_scales_f32 = quantizer::scales_to_f32(&self.active_scales_f16);
122        }
123
124        // Use cached f32 scales for quantization (avoids f16 conversion per group)
125        quantizer::quantize_and_pack_f32(
126            frame,
127            &self.active_scales_f32,
128            self.active_group_len,
129            self.active_bits,
130            &mut self.active_data,
131        );
132        self.active_frames = self.active_frames.wrapping_add(1);
133    }
134
135    /// Flush the current segment. Writes segment bytes to `out_segment`.
136    /// Resets internal state for the next segment.
137    pub fn flush(&mut self, out_segment: &mut Vec<u8>) {
138        if self.active_frames == 0 {
139            return;
140        }
141
142        segment::encode(
143            self.active_bits,
144            self.active_group_len as u32,
145            self.len,
146            self.active_frames,
147            &self.active_scales_f16,
148            &self.active_data,
149            out_segment,
150        );
151
152        self.active_frames = 0;
153        self.active_scales_f16.clear();
154        self.active_scales_f32.clear();
155        self.active_data.clear();
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    fn default_policy() -> TierPolicy {
        TierPolicy::default()
    }

    /// Assert that `decoded` is `expected` repeated a whole number of times,
    /// with every element within `tol` of the corresponding original value.
    fn assert_frames_close(decoded: &[f32], expected: &[f32], tol: f32) {
        assert!(!expected.is_empty(), "expected frame must be non-empty");
        assert_eq!(
            decoded.len() % expected.len(),
            0,
            "decoded length {} is not a multiple of frame length {}",
            decoded.len(),
            expected.len()
        );
        for (f, chunk) in decoded.chunks_exact(expected.len()).enumerate() {
            for (i, (&orig, &dec)) in expected.iter().zip(chunk).enumerate() {
                let err = (orig - dec).abs();
                assert!(err < tol, "frame={f} i={i} orig={orig} dec={dec} err={err}");
            }
        }
    }

    #[test]
    fn test_create_and_push() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        let frame = vec![1.0f32; 64];
        let mut seg = Vec::new();

        comp.push_frame(&frame, 0, &mut seg);
        assert!(seg.is_empty()); // First frame, no completed segment
        assert_eq!(comp.active_frame_count(), 1);
    }

    #[test]
    fn test_flush_produces_segment() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        let frame = vec![1.0f32; 64];
        let mut seg = Vec::new();

        comp.push_frame(&frame, 0, &mut seg);
        comp.flush(&mut seg);

        assert!(!seg.is_empty());
        // Flushing resets the open segment.
        assert_eq!(comp.active_frame_count(), 0);
        assert_eq!(comp.active_data_bytes(), 0);

        let mut decoded = Vec::new();
        segment::decode(&seg, &mut decoded);
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_tier_transition_flushes() {
        let policy = TierPolicy {
            hot_min_score: 512,
            warm_min_score: 64,
            warm_bits: 7,
            drift_pct_q8: 26,
            group_len: 64,
        };

        let mut comp = TemporalTensorCompressor::new(policy, 64, 0);
        comp.set_access(100, 0); // Hot
        let frame = vec![1.0f32; 64];
        let mut seg = Vec::new();

        comp.push_frame(&frame, 1, &mut seg);
        assert!(seg.is_empty());
        assert_eq!(comp.active_bits(), 8);

        // Make it cold
        comp.set_access(1, 0);
        comp.push_frame(&frame, 10000, &mut seg);
        assert!(!seg.is_empty());
        assert_eq!(comp.active_bits(), 3);

        // The emitted segment holds the single hot frame pushed before the switch.
        let header = segment::parse_header(&seg).unwrap();
        assert_eq!(header.bits, 8);
        assert_eq!(header.frame_count, 1);
        assert_eq!(comp.active_frame_count(), 1);
    }

    #[test]
    fn test_drift_triggers_new_segment() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        let mut seg = Vec::new();

        let frame1 = vec![1.0f32; 64];
        comp.push_frame(&frame1, 0, &mut seg);

        let frame2 = vec![5.0f32; 64];
        comp.push_frame(&frame2, 0, &mut seg);

        assert!(!seg.is_empty());
        // The closed segment contains only frame1; frame2 opened a new one.
        let header = segment::parse_header(&seg).unwrap();
        assert_eq!(header.frame_count, 1);
        assert_eq!(comp.active_frame_count(), 1);
    }

    #[test]
    fn test_multi_frame_same_segment() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        let mut seg = Vec::new();

        let frame = vec![1.0f32; 64];
        comp.push_frame(&frame, 0, &mut seg);
        assert!(seg.is_empty());

        let frame2 = vec![1.05f32; 64];
        comp.push_frame(&frame2, 0, &mut seg);
        assert!(seg.is_empty());
        assert_eq!(comp.active_frame_count(), 2);
    }

    #[test]
    fn test_full_roundtrip_hot() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 128, 0);
        comp.set_access(100, 0);
        let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let mut seg = Vec::new();

        for _ in 0..10 {
            comp.push_frame(&frame, 1, &mut seg);
        }
        comp.flush(&mut seg);

        let mut decoded = Vec::new();
        segment::decode(&seg, &mut decoded);
        assert_eq!(decoded.len(), 128 * 10);

        // Check every decoded frame, not just the first.
        let max_abs = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert_frames_close(&decoded, &frame, max_abs * 0.02);
    }

    #[test]
    fn test_full_roundtrip_cold() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        // Default: access_count=0, cold -> 3-bit
        let frame: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        let mut seg = Vec::new();

        comp.push_frame(&frame, 0, &mut seg);
        comp.flush(&mut seg);

        let header = segment::parse_header(&seg).unwrap();
        assert_eq!(header.bits, 3);

        let mut decoded = Vec::new();
        segment::decode(&seg, &mut decoded);
        assert_eq!(decoded.len(), 64);

        let max_abs = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        // 3-bit: qmax=3, max relative error ~33%
        assert_frames_close(&decoded, &frame, max_abs * 0.4);
    }

    #[test]
    fn test_wrong_length_frame_rejected() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 64, 0);
        // Pre-filled buffer: a rejected frame must still leave it cleared.
        let mut seg = vec![0xAAu8; 4];

        for len in [32usize, 96] {
            let frame = vec![1.0f32; len];
            comp.push_frame(&frame, 0, &mut seg);
            assert!(seg.is_empty(), "len={len}");
            assert_eq!(comp.active_frame_count(), 0, "len={len}");
            assert_eq!(comp.active_data_bytes(), 0, "len={len}");
        }
    }

    #[test]
    fn test_accessor_methods() {
        let policy = TierPolicy::default();
        let comp = TemporalTensorCompressor::new(policy, 256, 42);
        assert_eq!(comp.len(), 256);
        assert!(!comp.is_empty());
        assert_eq!(comp.active_frame_count(), 0);
        assert_eq!(comp.active_data_bytes(), 0);
        assert_eq!(comp.policy().group_len, 64);
        // `new` selects bits for zero accesses at the creation timestamp.
        assert_eq!(comp.active_bits(), policy.select_bits(0, 42, 42));
    }

    #[test]
    fn test_zero_length_is_empty() {
        let comp = TemporalTensorCompressor::new(default_policy(), 0, 0);
        assert!(comp.is_empty());
        assert_eq!(comp.len(), 0);
    }

    #[test]
    fn test_large_tensor_multi_group() {
        let mut comp = TemporalTensorCompressor::new(default_policy(), 512, 0);
        comp.set_access(100, 0); // hot -> 8-bit
        let frame: Vec<f32> = (0..512).map(|i| ((i as f32) * 0.731).sin()).collect();
        let mut seg = Vec::new();

        for _ in 0..50 {
            comp.push_frame(&frame, 1, &mut seg);
        }
        comp.flush(&mut seg);

        let header = segment::parse_header(&seg).unwrap();
        assert_eq!(header.bits, 8);
        assert_eq!(header.tensor_len, 512);
        assert_eq!(header.frame_count, 50);
        assert_eq!(header.scale_count, 8); // 512/64 = 8 groups

        let mut decoded = Vec::new();
        segment::decode(&seg, &mut decoded);
        assert_eq!(decoded.len(), 512 * 50);

        // Per-group scales are never larger than the global max, so the 8-bit
        // tolerance used in the hot roundtrip test also holds here.
        let max_abs = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert_frames_close(&decoded, &frame, max_abs * 0.02);

        // Verify compression ratio
        let raw = 512 * 4 * 50;
        let compressed = seg.len();
        let ratio = raw as f32 / compressed as f32;
        assert!(ratio > 3.5, "ratio={ratio:.2}x, expected >3.5x");
    }
}