1use serde::{Deserialize, Serialize};
19
/// Step size applied to the sample-averaged gradient on every training epoch.
const LEARNING_RATE: f32 = 0.01;
21
/// The five numeric signals the admission model scores, one `f32` per signal.
///
/// The field order here is the order used by [`AdmissionFeatures::to_vec`].
// NOTE(review): the encoders below only ever emit values in [0, 1]; the other
// fields are presumably on the same scale — confirm with the producers.
#[derive(Debug, Clone, Copy)]
pub struct AdmissionFeatures {
    /// Factual-confidence signal.
    pub factual_confidence: f32,
    /// Semantic-novelty signal.
    pub semantic_novelty: f32,
    /// Prior attached to the content type.
    pub content_type_prior: f32,
    /// Length bucket; [`AdmissionFeatures::encode_length`] yields `0.0`, `0.5` or `1.0`.
    pub content_length_bucket: f32,
    /// Numeric encoding of the message role; see [`AdmissionFeatures::encode_role`].
    pub role_encoding: f32,
}

impl AdmissionFeatures {
    /// Maps a role name to its fixed numeric encoding.
    ///
    /// Known roles: `"user"` → 0.7, `"assistant"` → 0.6, `"tool"` and
    /// `"tool_result"` → 0.8, `"system"` → 0.3. Any other string (the match is
    /// exact and case-sensitive) falls back to 0.5.
    #[must_use]
    pub fn encode_role(role: &str) -> f32 {
        const KNOWN_ROLES: [(&str, f32); 5] = [
            ("user", 0.7),
            ("assistant", 0.6),
            ("tool", 0.8),
            ("tool_result", 0.8),
            ("system", 0.3),
        ];
        const FALLBACK: f32 = 0.5;

        KNOWN_ROLES
            .iter()
            .find(|entry| entry.0 == role)
            .map_or(FALLBACK, |entry| entry.1)
    }

    /// Buckets a content length into one of three levels.
    ///
    /// Lengths below 100 map to `0.0`, lengths from 100 up to 999 map to `0.5`,
    /// and anything from 1000 upwards maps to `1.0`.
    #[must_use]
    pub fn encode_length(content_len: usize) -> f32 {
        match content_len {
            0..=99 => 0.0,
            100..=999 => 0.5,
            _ => 1.0,
        }
    }

    /// Returns the features as a fixed-size array in declaration order.
    ///
    /// Despite the name this yields a `[f32; 5]`, not a `Vec`.
    #[must_use]
    pub fn to_vec(&self) -> [f32; 5] {
        let Self {
            factual_confidence,
            semantic_novelty,
            content_type_prior,
            content_length_bucket,
            role_encoding,
        } = *self;

        [
            factual_confidence,
            semantic_novelty,
            content_type_prior,
            content_length_bucket,
            role_encoding,
        ]
    }
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RlModelWeights {
77 pub weights: Vec<f32>,
79 pub bias: f32,
81 pub sample_count: u64,
83}
84
85impl Default for RlModelWeights {
86 fn default() -> Self {
87 Self {
90 weights: vec![0.1, 0.1, 0.1, 0.05, 0.1],
91 bias: 0.0,
92 sample_count: 0,
93 }
94 }
95}
96
97pub struct RlAdmissionModel {
99 weights: RlModelWeights,
100}
101
102impl RlAdmissionModel {
103 #[must_use]
105 pub fn new() -> Self {
106 Self {
107 weights: RlModelWeights::default(),
108 }
109 }
110
111 #[must_use]
113 pub fn from_weights(weights: RlModelWeights) -> Self {
114 Self { weights }
115 }
116
117 pub fn serialize(&self) -> Result<String, serde_json::Error> {
123 serde_json::to_string(&self.weights)
124 }
125
126 #[must_use]
128 pub fn sample_count(&self) -> u64 {
129 self.weights.sample_count
130 }
131
132 #[must_use]
136 pub fn predict(&self, features: &AdmissionFeatures) -> f32 {
137 let x = features.to_vec();
138 let dot: f32 = x
139 .iter()
140 .zip(self.weights.weights.iter())
141 .map(|(xi, wi)| xi * wi)
142 .sum();
143 sigmoid(dot + self.weights.bias)
144 }
145
146 pub fn train(&mut self, samples: &[(AdmissionFeatures, f32)]) {
153 const EPOCHS: usize = 50;
154
155 if samples.is_empty() {
156 return;
157 }
158
159 let n_features = self.weights.weights.len();
160 #[allow(clippy::cast_precision_loss)]
161 let n = samples.len() as f32;
162
163 for _ in 0..EPOCHS {
164 let mut grad_w = vec![0.0f32; n_features];
165 let mut grad_b = 0.0f32;
166
167 for (features, label) in samples {
168 let pred = self.predict(features);
169 let error = pred - label;
170 let x = features.to_vec();
171 for (i, xi) in x.iter().enumerate() {
172 grad_w[i] += error * xi;
173 }
174 grad_b += error;
175 }
176
177 for (wi, gi) in self.weights.weights.iter_mut().zip(grad_w.iter()) {
178 *wi -= LEARNING_RATE * gi / n;
179 }
180 self.weights.bias -= LEARNING_RATE * grad_b / n;
181 }
182
183 self.weights.sample_count += samples.len() as u64;
185 }
186}
187
188impl Default for RlAdmissionModel {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Saturates cleanly: very negative inputs yield `0.0` (the exponential
/// overflows to infinity) and very positive inputs yield `1.0`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = f32::exp(-x);
    let denominator = 1.0 + decay;
    1.0 / denominator
}
199
200#[must_use]
204pub fn parse_training_samples(
205 records: &[crate::store::admission_training::AdmissionTrainingRecord],
206) -> Vec<(AdmissionFeatures, f32)> {
207 records
208 .iter()
209 .filter_map(|r| {
210 let arr: Vec<f32> = serde_json::from_str(&r.features_json).ok()?;
213 if arr.len() < 5 {
214 return None;
215 }
216 let features = AdmissionFeatures {
217 factual_confidence: arr[0],
218 semantic_novelty: arr[1],
219 content_type_prior: arr[2],
220 content_length_bucket: arr[3],
221 role_encoding: arr[4],
222 };
223 let label = if r.was_recalled { 1.0f32 } else { 0.0f32 };
224 Some((features, label))
225 })
226 .collect()
227}
228
/// Unit tests for the feature encoders, the sigmoid, the model and record parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature set whose five components all equal `value`.
    fn uniform_features(value: f32) -> AdmissionFeatures {
        AdmissionFeatures {
            factual_confidence: value,
            semantic_novelty: value,
            content_type_prior: value,
            content_length_bucket: value,
            role_encoding: value,
        }
    }

    #[test]
    fn sigmoid_at_zero_is_half() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn sigmoid_large_positive_approaches_one() {
        assert!(sigmoid(10.0) > 0.99);
    }

    #[test]
    fn sigmoid_large_negative_approaches_zero() {
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn predict_with_default_weights_is_near_half() {
        let model = RlAdmissionModel::new();
        let p = model.predict(&uniform_features(0.5));
        assert!(p > 0.0 && p < 1.0, "prediction must be in (0, 1): {p}");
        // Default weights give a pre-sigmoid score of 0.225 for these inputs,
        // i.e. a prediction of roughly 0.556.
        assert!(
            (p - 0.5).abs() < 0.1,
            "default weights must score mid-range features near 0.5: {p}"
        );
    }

    #[test]
    fn train_updates_weights() {
        let mut model = RlAdmissionModel::new();
        let features = uniform_features(1.0);
        let initial_pred = model.predict(&features);
        let samples: Vec<_> = (0..100).map(|_| (features, 1.0f32)).collect();
        model.train(&samples);
        let trained_pred = model.predict(&features);
        assert!(
            trained_pred > initial_pred,
            "training on positive labels must increase prediction: {initial_pred} -> {trained_pred}"
        );
    }

    #[test]
    fn sample_count_increments_after_training() {
        let mut model = RlAdmissionModel::new();
        assert_eq!(model.sample_count(), 0);
        let features = uniform_features(0.5);
        model.train(&[(features, 1.0)]);
        assert_eq!(model.sample_count(), 1);
        model.train(&[(features, 0.0), (features, 1.0)]);
        assert_eq!(model.sample_count(), 3);
        // An empty batch is a no-op and must not touch the counter.
        model.train(&[]);
        assert_eq!(model.sample_count(), 3);
    }

    #[test]
    fn serialize_and_deserialize_roundtrip() {
        // Train once so the round-trip covers non-default values.
        let mut model = RlAdmissionModel::new();
        model.train(&[(uniform_features(0.5), 1.0)]);

        let json = model.serialize().expect("serialize");
        let weights: RlModelWeights = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(weights.weights.len(), model.weights.weights.len());
        for (restored, original) in weights.weights.iter().zip(model.weights.weights.iter()) {
            assert!(
                (restored - original).abs() < 1e-6,
                "weight must survive the round-trip"
            );
        }
        assert!((weights.bias - model.weights.bias).abs() < 1e-6);
        assert_eq!(weights.sample_count, model.sample_count());
    }

    #[test]
    fn encode_role_matches_content_type_prior() {
        assert!((AdmissionFeatures::encode_role("user") - 0.7).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("tool") - 0.8).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("tool_result") - 0.8).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("assistant") - 0.6).abs() < 0.01);
        assert!((AdmissionFeatures::encode_role("system") - 0.3).abs() < 0.01);
        // Unknown roles take the fallback encoding.
        assert!((AdmissionFeatures::encode_role("something_else") - 0.5).abs() < 0.01);
    }

    #[test]
    fn encode_length_buckets() {
        assert!((AdmissionFeatures::encode_length(50) - 0.0).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(500) - 0.5).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(5000) - 1.0).abs() < 0.01);
        // Exact bucket boundaries.
        assert!((AdmissionFeatures::encode_length(99) - 0.0).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(100) - 0.5).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(999) - 0.5).abs() < 0.01);
        assert!((AdmissionFeatures::encode_length(1000) - 1.0).abs() < 0.01);
    }

    #[test]
    fn parse_training_samples_skips_short_arrays() {
        use crate::store::admission_training::AdmissionTrainingRecord;
        use crate::types::ConversationId;
        let records = vec![AdmissionTrainingRecord {
            id: 1,
            message_id: None,
            conversation_id: ConversationId(1),
            content_hash: "abc".into(),
            role: "user".into(),
            composite_score: 0.5,
            was_admitted: false,
            was_recalled: false,
            features_json: "[0.5, 0.5]".into(),
            created_at: "2026-01-01".into(),
        }];
        let samples = parse_training_samples(&records);
        assert!(samples.is_empty(), "short array must be skipped");
    }

    #[test]
    fn parse_training_samples_valid_record() {
        use crate::store::admission_training::AdmissionTrainingRecord;
        use crate::types::ConversationId;
        let records = vec![AdmissionTrainingRecord {
            id: 1,
            message_id: Some(42),
            conversation_id: ConversationId(1),
            content_hash: "abc".into(),
            role: "user".into(),
            composite_score: 0.7,
            was_admitted: true,
            was_recalled: true,
            features_json: "[0.9, 0.8, 0.7, 0.5, 0.7]".into(),
            created_at: "2026-01-01".into(),
        }];
        let samples = parse_training_samples(&records);
        assert_eq!(samples.len(), 1);
        let (features, label) = samples[0];
        assert!((label - 1.0).abs() < 1e-6, "recalled → label 1.0");

        // The parsed features must keep the order of the JSON array.
        let expected: [f32; 5] = [0.9, 0.8, 0.7, 0.5, 0.7];
        for (got, want) in features.to_vec().iter().zip(expected.iter()) {
            assert!(
                (got - want).abs() < 1e-6,
                "feature must match the stored JSON value"
            );
        }
    }
}