// zeph_memory/admission_rl.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Lightweight logistic regression model for RL-based admission control (#2416).
//!
//! This module provides a pure-Rust binary classifier trained on `(features, was_recalled)`
//! pairs from `admission_training_data`. It replaces the LLM-based `future_utility` factor
//! when enough training data is available.
//!
//! # Design
//!
//! - No external ML crate dependencies — pure f32 arithmetic
//! - 5 input features matching the A-MAC factor vector (see [`AdmissionFeatures`])
//! - Mini-batch gradient descent on log-loss
//! - Persisted as JSON via the `admission_rl_weights` `SQLite` table
//! - Falls back to heuristic when `sample_count < rl_min_samples`

use serde::{Deserialize, Serialize};

/// Gradient-descent step size used by `RlAdmissionModel::train`. Fixed (no decay);
/// per the module docs this is tuned for the expected training-data scale.
const LEARNING_RATE: f32 = 0.01;

/// Input feature vector for the RL admission model.
///
/// Matches the factor order in [`crate::admission::AdmissionFactors`].
#[derive(Debug, Clone, Copy)]
pub struct AdmissionFeatures {
    pub factual_confidence: f32,
    pub semantic_novelty: f32,
    pub content_type_prior: f32,
    /// Encoded content length bucket: `0.0` = short (<100 chars), `0.5` = medium, `1.0` = long.
    pub content_length_bucket: f32,
    /// Encoded role: user=0.7, assistant=0.6, tool=0.8, system=0.3, other=0.5.
    pub role_encoding: f32,
}

impl AdmissionFeatures {
    /// Encode the role as a float matching `compute_content_type_prior` values.
    ///
    /// Unknown roles map to the neutral value `0.5`.
    #[must_use]
    pub fn encode_role(role: &str) -> f32 {
        match role {
            "system" => 0.3,
            "assistant" => 0.6,
            "user" => 0.7,
            "tool" | "tool_result" => 0.8,
            _ => 0.5,
        }
    }

    /// Encode content length into a 3-bucket float:
    /// `< 100` chars → 0.0, `100..1000` → 0.5, `>= 1000` → 1.0.
    #[must_use]
    pub fn encode_length(content_len: usize) -> f32 {
        match content_len {
            0..=99 => 0.0,
            100..=999 => 0.5,
            _ => 1.0,
        }
    }

    /// Convert to a fixed-length array for model arithmetic.
    ///
    /// Element order matches the factor order documented on the struct.
    #[must_use]
    pub fn to_vec(&self) -> [f32; 5] {
        // Destructure so a field added later causes a compile error here
        // instead of silently being left out of the feature array.
        let Self {
            factual_confidence,
            semantic_novelty,
            content_type_prior,
            content_length_bucket,
            role_encoding,
        } = *self;
        [
            factual_confidence,
            semantic_novelty,
            content_type_prior,
            content_length_bucket,
            role_encoding,
        ]
    }
}

/// Persisted model state (JSON-serializable).
///
/// Stored as a JSON string in the `admission_rl_weights` `SQLite` table and
/// restored via `RlAdmissionModel::from_weights`. Field order is part of the
/// serialized layout, so do not reorder fields casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RlModelWeights {
    /// Weight vector (one per feature; 5 entries for the A-MAC factor vector).
    pub weights: Vec<f32>,
    /// Bias term.
    pub bias: f32,
    /// Number of training samples used. Compared against `rl_min_samples`
    /// elsewhere to decide whether the model is trusted over the heuristic.
    pub sample_count: u64,
}

85impl Default for RlModelWeights {
86    fn default() -> Self {
87        // Initialize to small random-ish values to break symmetry.
88        // Deterministic for reproducibility.
89        Self {
90            weights: vec![0.1, 0.1, 0.1, 0.05, 0.1],
91            bias: 0.0,
92            sample_count: 0,
93        }
94    }
95}
96
/// Logistic regression admission model.
///
/// Thin wrapper around [`RlModelWeights`]; all state lives in the weights
/// struct so the model can be persisted/restored losslessly.
pub struct RlAdmissionModel {
    // Current weight/bias state plus cumulative training-sample count.
    weights: RlModelWeights,
}

102impl RlAdmissionModel {
103    /// Create a new model with default (untrained) weights.
104    #[must_use]
105    pub fn new() -> Self {
106        Self {
107            weights: RlModelWeights::default(),
108        }
109    }
110
111    /// Create a model from saved weights.
112    #[must_use]
113    pub fn from_weights(weights: RlModelWeights) -> Self {
114        Self { weights }
115    }
116
117    /// Serialize current weights to JSON.
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if serialization fails.
122    pub fn serialize(&self) -> Result<String, serde_json::Error> {
123        serde_json::to_string(&self.weights)
124    }
125
126    /// Return the current sample count.
127    #[must_use]
128    pub fn sample_count(&self) -> u64 {
129        self.weights.sample_count
130    }
131
132    /// Predict recall probability for the given features. Range: `[0.0, 1.0]`.
133    ///
134    /// Uses the logistic (sigmoid) function: `sigma(w^T x + b)`.
135    #[must_use]
136    pub fn predict(&self, features: &AdmissionFeatures) -> f32 {
137        let x = features.to_vec();
138        let dot: f32 = x
139            .iter()
140            .zip(self.weights.weights.iter())
141            .map(|(xi, wi)| xi * wi)
142            .sum();
143        sigmoid(dot + self.weights.bias)
144    }
145
146    /// Train the model on a batch of (features, label) pairs using mini-batch gradient descent.
147    ///
148    /// `label = 1.0` when the message was recalled (positive), `0.0` when not recalled (negative).
149    /// Runs `EPOCHS` gradient steps over the full batch so the model converges meaningfully.
150    /// `sample_count` is incremented once per call (not per epoch) to avoid over-counting.
151    /// Learning rate is fixed at `0.01` — suitable for the expected data scale.
152    pub fn train(&mut self, samples: &[(AdmissionFeatures, f32)]) {
153        const EPOCHS: usize = 50;
154
155        if samples.is_empty() {
156            return;
157        }
158
159        let n_features = self.weights.weights.len();
160        #[allow(clippy::cast_precision_loss)]
161        let n = samples.len() as f32;
162
163        for _ in 0..EPOCHS {
164            let mut grad_w = vec![0.0f32; n_features];
165            let mut grad_b = 0.0f32;
166
167            for (features, label) in samples {
168                let pred = self.predict(features);
169                let error = pred - label;
170                let x = features.to_vec();
171                for (i, xi) in x.iter().enumerate() {
172                    grad_w[i] += error * xi;
173                }
174                grad_b += error;
175            }
176
177            for (wi, gi) in self.weights.weights.iter_mut().zip(grad_w.iter()) {
178                *wi -= LEARNING_RATE * gi / n;
179            }
180            self.weights.bias -= LEARNING_RATE * grad_b / n;
181        }
182
183        // Increment once per train() call, not per epoch.
184        self.weights.sample_count += samples.len() as u64;
185    }
186}
187
188impl Default for RlAdmissionModel {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
/// Logistic (sigmoid) activation function.
///
/// Maps any real input into `(0, 1)` with `sigmoid(0) == 0.5`. The
/// `1 / (1 + e^-x)` form is used so large positive inputs cannot overflow
/// (the exponential underflows to 0 instead).
#[inline]
fn sigmoid(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}

200/// Parse training samples from `SQLite` records for model training.
201///
202/// Returns `(features, label)` pairs where `label = 1.0` if `was_recalled`.
203#[must_use]
204pub fn parse_training_samples(
205    records: &[crate::store::admission_training::AdmissionTrainingRecord],
206) -> Vec<(AdmissionFeatures, f32)> {
207    records
208        .iter()
209        .filter_map(|r| {
210            // Parse features_json: ["factual_confidence", "semantic_novelty",
211            //                        "content_type_prior", "length_bucket", "role_encoding"]
212            let arr: Vec<f32> = serde_json::from_str(&r.features_json).ok()?;
213            if arr.len() < 5 {
214                return None;
215            }
216            let features = AdmissionFeatures {
217                factual_confidence: arr[0],
218                semantic_novelty: arr[1],
219                content_type_prior: arr[2],
220                content_length_bucket: arr[3],
221                role_encoding: arr[4],
222            };
223            let label = if r.was_recalled { 1.0f32 } else { 0.0f32 };
224            Some((features, label))
225        })
226        .collect()
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    // --- sigmoid ---------------------------------------------------------

    #[test]
    fn sigmoid_at_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn sigmoid_large_positive_approaches_one() {
        assert!(sigmoid(10.0) > 0.99);
    }

    #[test]
    fn sigmoid_large_negative_approaches_zero() {
        assert!(sigmoid(-10.0) < 0.01);
    }

    // --- prediction / training -------------------------------------------

    #[test]
    fn predict_with_default_weights_is_near_half() {
        let model = RlAdmissionModel::new();
        let features = AdmissionFeatures {
            factual_confidence: 0.5,
            semantic_novelty: 0.5,
            content_type_prior: 0.5,
            content_length_bucket: 0.5,
            role_encoding: 0.5,
        };
        let p = model.predict(&features);
        // Only checks the open interval: default weights give sigmoid(0.225) ~ 0.56,
        // not exactly 0.5, so an exact comparison would be wrong.
        assert!(p > 0.0 && p < 1.0, "prediction must be in (0, 1): {p}");
    }

    #[test]
    fn train_updates_weights() {
        let mut model = RlAdmissionModel::new();
        let features = AdmissionFeatures {
            factual_confidence: 1.0,
            semantic_novelty: 1.0,
            content_type_prior: 1.0,
            content_length_bucket: 1.0,
            role_encoding: 1.0,
        };
        let initial_pred = model.predict(&features);
        // Train with all positive labels — model should increase weights over time.
        let samples: Vec<_> = (0..100).map(|_| (features, 1.0f32)).collect();
        model.train(&samples);
        let trained_pred = model.predict(&features);
        assert!(
            trained_pred > initial_pred,
            "training on positive labels must increase prediction: {initial_pred} -> {trained_pred}"
        );
    }

    #[test]
    fn sample_count_increments_after_training() {
        let mut model = RlAdmissionModel::new();
        assert_eq!(model.sample_count(), 0);
        let features = AdmissionFeatures {
            factual_confidence: 0.5,
            semantic_novelty: 0.5,
            content_type_prior: 0.5,
            content_length_bucket: 0.5,
            role_encoding: 0.5,
        };
        // Count must grow by batch size per call (not per epoch): 1, then +2.
        model.train(&[(features, 1.0)]);
        assert_eq!(model.sample_count(), 1);
        model.train(&[(features, 0.0), (features, 1.0)]);
        assert_eq!(model.sample_count(), 3);
    }

    // --- persistence ------------------------------------------------------

    #[test]
    fn serialize_and_deserialize_roundtrip() {
        let model = RlAdmissionModel::new();
        let json = model.serialize().expect("serialize");
        let weights: RlModelWeights = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(weights.weights.len(), model.weights.weights.len());
        assert!((weights.bias - model.weights.bias).abs() < 1e-6);
    }

    // --- feature encoding -------------------------------------------------

    #[test]
    fn encode_role_matches_content_type_prior() {
        assert!((AdmissionFeatures::encode_role("user") - 0.7).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("tool") - 0.8).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("assistant") - 0.6).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("system") - 0.3).abs() < 0.01);
    }

    #[test]
    fn encode_length_buckets() {
        assert!((AdmissionFeatures::encode_length(50) - 0.0).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(500) - 0.5).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(5000) - 1.0).abs() < 0.01);
    }

    // --- training-record parsing ------------------------------------------

    #[test]
    fn parse_training_samples_skips_short_arrays() {
        use crate::store::admission_training::AdmissionTrainingRecord;
        use crate::types::ConversationId;
        let records = vec![AdmissionTrainingRecord {
            id: 1,
            message_id: None,
            conversation_id: ConversationId(1),
            content_hash: "abc".into(),
            role: "user".into(),
            composite_score: 0.5,
            was_admitted: false,
            was_recalled: false,
            features_json: "[0.5, 0.5]".into(), // too short
            created_at: "2026-01-01".into(),
        }];
        let samples = parse_training_samples(&records);
        assert!(samples.is_empty(), "short array must be skipped");
    }

    #[test]
    fn parse_training_samples_valid_record() {
        use crate::store::admission_training::AdmissionTrainingRecord;
        use crate::types::ConversationId;
        let records = vec![AdmissionTrainingRecord {
            id: 1,
            message_id: Some(42),
            conversation_id: ConversationId(1),
            content_hash: "abc".into(),
            role: "user".into(),
            composite_score: 0.7,
            was_admitted: true,
            was_recalled: true,
            features_json: "[0.9, 0.8, 0.7, 0.5, 0.7]".into(),
            created_at: "2026-01-01".into(),
        }];
        let samples = parse_training_samples(&records);
        assert_eq!(samples.len(), 1);
        let (_, label) = samples[0];
        assert!((label - 1.0).abs() < 1e-6, "recalled → label 1.0");
    }
}