// zeph_memory/compression_predictor.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Performance-floor compression ratio predictor (#2460).
//!
//! A lightweight linear regression model that predicts compaction probe quality
//! at a given compression ratio. Used to select the most aggressive compression
//! ratio that keeps the predicted probe score above `hard_fail_threshold`.
//!
//! # Design
//!
//! - No external ML crate dependencies — pure f32 arithmetic
//! - 4 input features: `compression_ratio`, `message_count`, `avg_message_length`, `tool_output_fraction`
//! - MSE loss with mini-batch gradient descent (continuous score target, not binary)
//! - Sigmoid output activation to bound predictions in [0.0, 1.0]
//! - Persisted as JSON via the `compression_predictor_weights` `SQLite` table
//! - Falls back to `None` (use default behavior) during cold start
//! - Training data sliding window: callers are expected to pass only the most
//!   recent N samples to `train` — retention is not enforced by this module
20use serde::{Deserialize, Serialize};
21
/// Gradient-descent step size used by `CompressionPredictor::train`.
const LEARNING_RATE: f32 = 0.01;
/// Number of gradient passes over the batch per `train()` call.
const EPOCHS: usize = 50;
24
25// ── Features ──────────────────────────────────────────────────────────────────
26
/// Input features for the compression quality predictor.
///
/// All features are pre-normalized floats; the order of the fields matches
/// the order of the array produced by [`CompressionFeatures::to_vec`].
#[derive(Debug, Clone, Copy)]
pub struct CompressionFeatures {
    /// Fraction of tokens retained after compression. Range: [0.0, 1.0].
    pub compression_ratio: f32,
    /// Normalized message count (divide by a reference scale, e.g. 100).
    pub message_count: f32,
    /// Normalized average token count per message.
    pub avg_message_length: f32,
    /// Fraction of messages that are tool outputs. Range: [0.0, 1.0].
    pub tool_output_fraction: f32,
}

impl CompressionFeatures {
    /// Convert to a fixed-length array for model arithmetic.
    ///
    /// The array order defines the weight order used by the model.
    #[must_use]
    pub fn to_vec(&self) -> [f32; 4] {
        // Destructure once so the feature-to-slot mapping is explicit.
        let Self {
            compression_ratio,
            message_count,
            avg_message_length,
            tool_output_fraction,
        } = *self;
        [
            compression_ratio,
            message_count,
            avg_message_length,
            tool_output_fraction,
        ]
    }
}
52
53// ── Weights ───────────────────────────────────────────────────────────────────
54
/// Persisted model state (JSON-serializable).
///
/// Serialized via `CompressionPredictor::serialize` and restored via
/// `CompressionPredictor::from_weights`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionModelWeights {
    /// Weight vector (one per feature, in `CompressionFeatures::to_vec` order).
    pub weights: Vec<f32>,
    /// Bias term added to the dot product before the sigmoid.
    pub bias: f32,
    /// Number of training samples used.
    pub sample_count: u64,
}

impl Default for CompressionModelWeights {
    fn default() -> Self {
        // Cold-start prior: compression ratio is positively correlated with
        // quality (higher ratio = less compression = better score), so its
        // weight starts clearly positive. The tool-output-fraction weight
        // starts slightly negative; the remaining weights start near zero.
        // These priors help the model converge faster from cold start.
        Self {
            weights: vec![0.3, 0.05, 0.05, -0.1],
            bias: 0.0,
            sample_count: 0,
        }
    }
}
78
79// ── Model ─────────────────────────────────────────────────────────────────────
80
/// Compression quality predictor using linear regression with sigmoid output.
pub struct CompressionPredictor {
    // Current model parameters (weight vector, bias, training-sample counter).
    weights: CompressionModelWeights,
}
85
86impl CompressionPredictor {
87    /// Create a new model with default (untrained) weights.
88    #[must_use]
89    pub fn new() -> Self {
90        Self {
91            weights: CompressionModelWeights::default(),
92        }
93    }
94
95    /// Create a model from saved weights.
96    #[must_use]
97    pub fn from_weights(weights: CompressionModelWeights) -> Self {
98        Self { weights }
99    }
100
101    /// Serialize current weights to JSON.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if serialization fails.
106    pub fn serialize(&self) -> Result<String, serde_json::Error> {
107        serde_json::to_string(&self.weights)
108    }
109
110    /// Return the number of training samples used.
111    #[must_use]
112    pub fn sample_count(&self) -> u64 {
113        self.weights.sample_count
114    }
115
116    /// Return `true` if the model has fewer than `min_samples` training samples.
117    ///
118    /// During cold start, `select_ratio` returns `None` and the caller falls back
119    /// to default compression behavior.
120    #[must_use]
121    pub fn is_cold_start(&self, min_samples: u64) -> bool {
122        self.weights.sample_count < min_samples
123    }
124
125    /// Predict probe score for the given features. Range: `[0.0, 1.0]`.
126    ///
127    /// Uses `sigmoid(w^T x + b)` to bound predictions.
128    #[must_use]
129    pub fn predict(&self, features: &CompressionFeatures) -> f32 {
130        let x = features.to_vec();
131        let dot: f32 = x
132            .iter()
133            .zip(self.weights.weights.iter())
134            .map(|(xi, wi)| xi * wi)
135            .sum();
136        sigmoid(dot + self.weights.bias)
137    }
138
139    /// Train the model on a batch of `(features, probe_score)` pairs using MSE loss.
140    ///
141    /// Runs `EPOCHS` gradient steps. `sample_count` is incremented by the batch size
142    /// once per `train()` call (not per epoch) to avoid over-counting.
143    pub fn train(&mut self, samples: &[(CompressionFeatures, f32)]) {
144        if samples.is_empty() {
145            return;
146        }
147
148        let n_features = self.weights.weights.len();
149        #[allow(clippy::cast_precision_loss)]
150        let n = samples.len() as f32;
151
152        for _ in 0..EPOCHS {
153            let mut grad_w = vec![0.0f32; n_features];
154            let mut grad_b = 0.0f32;
155
156            for (features, target) in samples {
157                let pred = self.predict(features);
158                // MSE gradient: 2 * (pred - target) * sigmoid_derivative(pred)
159                // sigmoid_derivative(y) = y * (1 - y)
160                let error = (pred - target) * pred * (1.0 - pred);
161                let x = features.to_vec();
162                for (i, xi) in x.iter().enumerate() {
163                    grad_w[i] += error * xi;
164                }
165                grad_b += error;
166            }
167
168            for (wi, gi) in self.weights.weights.iter_mut().zip(grad_w.iter()) {
169                *wi -= LEARNING_RATE * gi / n;
170            }
171            self.weights.bias -= LEARNING_RATE * grad_b / n;
172        }
173
174        self.weights.sample_count += samples.len() as u64;
175    }
176
177    /// Find the most aggressive compression ratio that keeps predicted score >= `floor`.
178    ///
179    /// Iterates `candidate_ratios` from lowest (most aggressive) to highest (least aggressive).
180    /// Returns the first ratio whose predicted quality clears `floor`, or `None` if no
181    /// candidate passes (caller should fall back to default behavior).
182    ///
183    /// # Panics
184    ///
185    /// Does not panic; returns `None` on empty candidate list.
186    #[must_use]
187    pub fn select_ratio(
188        &self,
189        floor: f32,
190        candidate_ratios: &[f32],
191        message_count: f32,
192        avg_message_length: f32,
193        tool_output_fraction: f32,
194    ) -> Option<f32> {
195        // Iterate from most aggressive to least aggressive (ascending ratio order).
196        let mut sorted = candidate_ratios.to_vec();
197        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
198
199        for &ratio in &sorted {
200            let features = CompressionFeatures {
201                compression_ratio: ratio,
202                message_count,
203                avg_message_length,
204                tool_output_fraction,
205            };
206            if self.predict(&features) >= floor {
207                return Some(ratio);
208            }
209        }
210        None
211    }
212}
213
214impl Default for CompressionPredictor {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
/// Logistic (sigmoid) activation function.
///
/// Maps any real input into `(0.0, 1.0)`, with `sigmoid(0) == 0.5`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
225
226// ── Tests ─────────────────────────────────────────────────────────────────────
227
#[cfg(test)]
mod tests {
    use super::*;

    // The sigmoid midpoint anchors the output range: sigmoid(0) must be 0.5.
    #[test]
    fn sigmoid_at_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    // Untrained predictions must still land strictly inside (0, 1).
    #[test]
    fn predict_with_default_weights_returns_valid_range() {
        let model = CompressionPredictor::new();
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let p = model.predict(&features);
        assert!(p > 0.0 && p < 1.0, "prediction must be in (0, 1): {p}");
    }

    #[test]
    fn is_cold_start_true_when_no_training() {
        let model = CompressionPredictor::new();
        assert!(model.is_cold_start(10));
    }

    // Training on exactly `min_samples` samples must end cold start.
    #[test]
    fn is_cold_start_false_after_training() {
        let mut model = CompressionPredictor::new();
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let samples: Vec<_> = (0..10).map(|_| (features, 0.7f32)).collect();
        model.train(&samples);
        assert!(!model.is_cold_start(10));
    }

    // Gradient descent on a high target must move the prediction upward.
    #[test]
    fn training_on_high_ratio_improves_high_ratio_prediction() {
        let mut model = CompressionPredictor::new();
        let high_ratio = CompressionFeatures {
            compression_ratio: 0.8,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let initial_pred = model.predict(&high_ratio);
        let samples: Vec<_> = (0..50).map(|_| (high_ratio, 0.9f32)).collect();
        model.train(&samples);
        let trained_pred = model.predict(&high_ratio);
        assert!(
            trained_pred > initial_pred,
            "training on high target must increase prediction: {initial_pred} -> {trained_pred}"
        );
    }

    #[test]
    fn select_ratio_returns_most_aggressive_passing_ratio() {
        let mut model = CompressionPredictor::new();
        // Train so that 0.8 predicts well but 0.2 predicts poorly.
        let good = CompressionFeatures {
            compression_ratio: 0.8,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let bad = CompressionFeatures {
            compression_ratio: 0.2,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        let samples: Vec<(CompressionFeatures, f32)> = (0..100)
            .map(|i| if i % 2 == 0 { (good, 0.95) } else { (bad, 0.1) })
            .collect();
        model.train(&samples);

        let ratios = vec![0.2, 0.4, 0.6, 0.8, 0.9];
        let selected = model.select_ratio(0.5, &ratios, 0.5, 0.5, 0.2);
        // Must pick a ratio that the model predicts >= 0.5.
        // NOTE(review): this check is vacuous when `selected` is `None`;
        // consider also asserting `selected.is_some()` if training is
        // guaranteed to push at least one candidate above the floor.
        if let Some(r) = selected {
            let features = CompressionFeatures {
                compression_ratio: r,
                message_count: 0.5,
                avg_message_length: 0.5,
                tool_output_fraction: 0.2,
            };
            assert!(
                model.predict(&features) >= 0.5,
                "selected ratio must clear floor"
            );
        }
    }

    #[test]
    fn select_ratio_returns_none_when_nothing_passes() {
        let model = CompressionPredictor::new();
        // Untrained default weights ([0.3, 0.05, 0.05, -0.1], bias 0.0) give a
        // prediction near 0.5 for these features — far below the 0.99 floor.
        // All candidates must fail the floor check, so None is returned.
        let ratios = vec![0.1, 0.2];
        let selected = model.select_ratio(0.99, &ratios, 0.5, 0.5, 0.2);
        assert!(
            selected.is_none(),
            "expected None when no ratio clears the floor"
        );
    }

    // sample_count accumulates per train() call by batch size, not per epoch.
    #[test]
    fn sample_count_increments_after_training() {
        let mut model = CompressionPredictor::new();
        assert_eq!(model.sample_count(), 0);
        let features = CompressionFeatures {
            compression_ratio: 0.5,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        };
        model.train(&[(features, 0.7)]);
        assert_eq!(model.sample_count(), 1);
        model.train(&[(features, 0.6), (features, 0.8)]);
        assert_eq!(model.sample_count(), 3);
    }

    // Weights must survive a JSON serialize/deserialize round trip.
    #[test]
    fn serialize_roundtrip() {
        let model = CompressionPredictor::new();
        let json = model.serialize().expect("serialize");
        let weights: CompressionModelWeights = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(weights.weights.len(), model.weights.weights.len());
        assert!((weights.bias - model.weights.bias).abs() < 1e-6);
    }
}
364}