1use serde::{Deserialize, Serialize};
21
/// Step size applied to the batch-averaged gradient in every epoch of
/// [`CompressionPredictor::train`].
const LEARNING_RATE: f32 = 1e-2;
/// Number of full gradient-descent passes over the batch performed by each
/// call to [`CompressionPredictor::train`].
const EPOCHS: usize = 50;
24
/// Input features scored by [`CompressionPredictor::predict`].
///
/// The model consumes these in the order produced by [`CompressionFeatures::to_vec`].
/// NOTE(review): the tests use values in `[0, 1]` for every field, which suggests
/// callers normalise `message_count` and `avg_message_length`; confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CompressionFeatures {
    /// Candidate compression ratio being evaluated (the value `select_ratio` varies).
    pub compression_ratio: f32,
    /// Message-count feature (presumably normalised — confirm).
    pub message_count: f32,
    /// Average-message-length feature (presumably normalised — confirm).
    pub avg_message_length: f32,
    /// Fraction of the context made up of tool output (presumably in `[0, 1]`).
    pub tool_output_fraction: f32,
}
39
40impl CompressionFeatures {
41 #[must_use]
43 pub fn to_vec(&self) -> [f32; 4] {
44 [
45 self.compression_ratio,
46 self.message_count,
47 self.avg_message_length,
48 self.tool_output_fraction,
49 ]
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct CompressionModelWeights {
58 pub weights: Vec<f32>,
60 pub bias: f32,
62 pub sample_count: u64,
64}
65
66impl Default for CompressionModelWeights {
67 fn default() -> Self {
68 Self {
72 weights: vec![0.3, 0.05, 0.05, -0.1],
73 bias: 0.0,
74 sample_count: 0,
75 }
76 }
77}
78
79pub struct CompressionPredictor {
83 weights: CompressionModelWeights,
84}
85
86impl CompressionPredictor {
87 #[must_use]
89 pub fn new() -> Self {
90 Self {
91 weights: CompressionModelWeights::default(),
92 }
93 }
94
95 #[must_use]
97 pub fn from_weights(weights: CompressionModelWeights) -> Self {
98 Self { weights }
99 }
100
101 pub fn serialize(&self) -> Result<String, serde_json::Error> {
107 serde_json::to_string(&self.weights)
108 }
109
110 #[must_use]
112 pub fn sample_count(&self) -> u64 {
113 self.weights.sample_count
114 }
115
116 #[must_use]
121 pub fn is_cold_start(&self, min_samples: u64) -> bool {
122 self.weights.sample_count < min_samples
123 }
124
125 #[must_use]
129 pub fn predict(&self, features: &CompressionFeatures) -> f32 {
130 let x = features.to_vec();
131 let dot: f32 = x
132 .iter()
133 .zip(self.weights.weights.iter())
134 .map(|(xi, wi)| xi * wi)
135 .sum();
136 sigmoid(dot + self.weights.bias)
137 }
138
139 pub fn train(&mut self, samples: &[(CompressionFeatures, f32)]) {
144 if samples.is_empty() {
145 return;
146 }
147
148 let n_features = self.weights.weights.len();
149 #[allow(clippy::cast_precision_loss)]
150 let n = samples.len() as f32;
151
152 for _ in 0..EPOCHS {
153 let mut grad_w = vec![0.0f32; n_features];
154 let mut grad_b = 0.0f32;
155
156 for (features, target) in samples {
157 let pred = self.predict(features);
158 let error = (pred - target) * pred * (1.0 - pred);
161 let x = features.to_vec();
162 for (i, xi) in x.iter().enumerate() {
163 grad_w[i] += error * xi;
164 }
165 grad_b += error;
166 }
167
168 for (wi, gi) in self.weights.weights.iter_mut().zip(grad_w.iter()) {
169 *wi -= LEARNING_RATE * gi / n;
170 }
171 self.weights.bias -= LEARNING_RATE * grad_b / n;
172 }
173
174 self.weights.sample_count += samples.len() as u64;
175 }
176
177 #[must_use]
187 pub fn select_ratio(
188 &self,
189 floor: f32,
190 candidate_ratios: &[f32],
191 message_count: f32,
192 avg_message_length: f32,
193 tool_output_fraction: f32,
194 ) -> Option<f32> {
195 let mut sorted = candidate_ratios.to_vec();
197 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
198
199 for &ratio in &sorted {
200 let features = CompressionFeatures {
201 compression_ratio: ratio,
202 message_count,
203 avg_message_length,
204 tool_output_fraction,
205 };
206 if self.predict(&features) >= floor {
207 return Some(ratio);
208 }
209 }
210 None
211 }
212}
213
214impl Default for CompressionPredictor {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
/// Logistic function `1 / (1 + e^-x)`.
///
/// Saturates to exactly `0.0` for very negative inputs, because `exp` overflows to
/// infinity, and to `1.0` for very positive inputs. NaN input yields NaN.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds features with the fixed context values shared by these tests.
    fn features_with_ratio(compression_ratio: f32) -> CompressionFeatures {
        CompressionFeatures {
            compression_ratio,
            message_count: 0.5,
            avg_message_length: 0.5,
            tool_output_fraction: 0.2,
        }
    }

    #[test]
    fn sigmoid_at_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn predict_with_default_weights_returns_valid_range() {
        let model = CompressionPredictor::new();
        let p = model.predict(&features_with_ratio(0.5));
        assert!(p > 0.0 && p < 1.0, "prediction must be in (0, 1): {p}");
    }

    #[test]
    fn is_cold_start_true_when_no_training() {
        let model = CompressionPredictor::new();
        assert!(model.is_cold_start(10));
    }

    #[test]
    fn is_cold_start_false_after_training() {
        let mut model = CompressionPredictor::new();
        let samples: Vec<_> = (0..10)
            .map(|_| (features_with_ratio(0.5), 0.7f32))
            .collect();
        model.train(&samples);
        assert!(!model.is_cold_start(10));
    }

    #[test]
    fn training_on_high_ratio_improves_high_ratio_prediction() {
        let mut model = CompressionPredictor::new();
        let high_ratio = features_with_ratio(0.8);
        let initial_pred = model.predict(&high_ratio);
        let samples: Vec<_> = (0..50).map(|_| (high_ratio, 0.9f32)).collect();
        model.train(&samples);
        let trained_pred = model.predict(&high_ratio);
        assert!(
            trained_pred > initial_pred,
            "training on high target must increase prediction: {initial_pred} -> {trained_pred}"
        );
    }

    #[test]
    fn select_ratio_returns_most_aggressive_passing_ratio() {
        let mut model = CompressionPredictor::new();
        // High ratios are labelled as preserving quality; low ratios as degrading it.
        let good = features_with_ratio(0.8);
        let bad = features_with_ratio(0.2);
        let samples: Vec<(CompressionFeatures, f32)> = (0..100)
            .map(|i| if i % 2 == 0 { (good, 0.95) } else { (bad, 0.1) })
            .collect();
        model.train(&samples);

        let floor: f32 = 0.5;
        let ratios: [f32; 5] = [0.2, 0.4, 0.6, 0.8, 0.9];
        let selected = model
            .select_ratio(floor, &ratios, 0.5, 0.5, 0.2)
            .expect("at least one candidate ratio should clear the floor");
        assert!(
            model.predict(&features_with_ratio(selected)) >= floor,
            "selected ratio must clear floor"
        );
        // "Most aggressive" means no smaller candidate also clears the floor.
        for &r in ratios.iter().filter(|&&r| r < selected) {
            assert!(
                model.predict(&features_with_ratio(r)) < floor,
                "smaller ratio {r} also clears the floor, so {selected} is not the most aggressive"
            );
        }
    }

    #[test]
    fn select_ratio_returns_none_when_nothing_passes() {
        let model = CompressionPredictor::new();
        let ratios: [f32; 2] = [0.1, 0.2];
        // With default weights no candidate can reach such a high floor.
        let selected = model.select_ratio(0.99, &ratios, 0.5, 0.5, 0.2);
        assert!(
            selected.is_none(),
            "expected None when no ratio clears the floor"
        );
    }

    #[test]
    fn sample_count_increments_after_training() {
        let mut model = CompressionPredictor::new();
        assert_eq!(model.sample_count(), 0);
        let features = features_with_ratio(0.5);
        model.train(&[(features, 0.7)]);
        assert_eq!(model.sample_count(), 1);
        model.train(&[(features, 0.6), (features, 0.8)]);
        assert_eq!(model.sample_count(), 3);
    }

    #[test]
    fn serialize_roundtrip() {
        let model = CompressionPredictor::new();
        let json = model.serialize().expect("serialize");
        let weights: CompressionModelWeights = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(weights.weights.len(), model.weights.weights.len());
        assert!((weights.bias - model.weights.bias).abs() < 1e-6);
    }
}