// zeph_memory/store/admission_training.rs
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! SQLite-backed store for RL admission control training data (#2416).
//!
//! Records ALL messages seen by A-MAC (admitted and rejected) to avoid survivorship
//! bias in the logistic regression model (critic fix C3). `was_recalled` is set to 1
//! when `SemanticMemory::recall()` returns the message, providing positive training signal.
#[allow(unused_imports)]
use zeph_db::sql;

use crate::error::MemoryError;
use crate::store::SqliteStore;
use crate::types::{ConversationId, MessageId};

/// Input for recording a single RL admission training sample.
pub struct AdmissionTrainingInput<'a> {
    /// `Some` for admitted messages; `None` for rejected ones, which are never
    /// persisted to the `messages` table.
    pub message_id: Option<MessageId>,
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// Raw message text; only its stable hash is stored (see `content_hash`).
    pub content: &'a str,
    /// Message role (e.g. "user" / "assistant" as used in the tests below).
    pub role: &'a str,
    /// A-MAC composite admission score for this message.
    pub composite_score: f32,
    /// Whether A-MAC admitted the message into memory.
    pub was_admitted: bool,
    /// JSON-serialized feature vector used for model training.
    pub features_json: &'a str,
}

/// A single training record for the RL admission model.
#[derive(Debug, Clone)]
pub struct AdmissionTrainingRecord {
    /// Primary key of the `admission_training_data` row.
    pub id: i64,
    /// `Some` only for admitted messages; rejected ones have no `messages` row.
    pub message_id: Option<i64>,
    /// Conversation the message belonged to.
    pub conversation_id: ConversationId,
    /// 16-char hex hash of the original content (see `content_hash`).
    pub content_hash: String,
    /// Message role as stored at admission time.
    pub role: String,
    /// A-MAC composite admission score (stored as REAL, truncated back to f32 on read).
    pub composite_score: f32,
    /// Whether A-MAC admitted the message (stored as 0/1 in SQLite).
    pub was_admitted: bool,
    /// Set to true once `SemanticMemory::recall()` returns the message.
    pub was_recalled: bool,
    /// JSON-serialized feature vector captured at admission time.
    pub features_json: String,
    /// Row creation timestamp string from the `created_at` column.
    pub created_at: String,
}

43/// Compute a stable 16-char hex hash of `content` for deduplication.
44///
45/// Uses the first 8 bytes of SHA-256 truncated to a 16-char hex string.
46/// SHA-256 output is stable across Rust toolchain versions, unlike `DefaultHasher`.
47#[must_use]
48pub fn content_hash(content: &str) -> String {
49    use sha2::{Digest, Sha256};
50    let digest = Sha256::digest(content.as_bytes());
51    let mut bytes = [0u8; 8];
52    bytes.copy_from_slice(&digest[..8]);
53    format!("{:016x}", u64::from_be_bytes(bytes))
54}
55
56impl SqliteStore {
57    /// Record a message in the RL admission training data.
58    ///
59    /// Called for BOTH admitted and rejected messages so the model sees both classes.
60    /// `message_id` is `None` for rejected messages (never persisted to `messages` table).
61    /// `features_json` is the JSON-serialized feature vector used for training.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the database insert fails.
66    pub async fn record_admission_training(
67        &self,
68        input: AdmissionTrainingInput<'_>,
69    ) -> Result<i64, MemoryError> {
70        let hash = content_hash(input.content);
71        let admitted_i = i64::from(input.was_admitted);
72        let msg_id = input.message_id.map(|m| m.0);
73        let (conversation_id, role, composite_score, features_json) = (
74            input.conversation_id,
75            input.role,
76            input.composite_score,
77            input.features_json,
78        );
79        let id = zeph_db::query_scalar(sql!(
80            "INSERT INTO admission_training_data \
81             (message_id, conversation_id, content_hash, role, composite_score, \
82              was_admitted, was_recalled, features_json) \
83             VALUES (?, ?, ?, ?, ?, ?, 0, ?) \
84             RETURNING id"
85        ))
86        .bind(msg_id)
87        .bind(conversation_id.0)
88        .bind(hash)
89        .bind(role)
90        .bind(f64::from(composite_score))
91        .bind(admitted_i)
92        .bind(features_json)
93        .fetch_one(&self.pool)
94        .await?;
95        Ok(id)
96    }
97
98    /// Mark training records as recalled for the given message IDs.
99    ///
100    /// Called after `batch_increment_access_count()` in `SemanticMemory::recall()`.
101    /// Sets `was_recalled = 1` and updates `updated_at` for all matching records.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the database update fails.
106    pub async fn mark_training_recalled(
107        &self,
108        message_ids: &[MessageId],
109    ) -> Result<(), MemoryError> {
110        if message_ids.is_empty() {
111            return Ok(());
112        }
113        let placeholders: String = message_ids
114            .iter()
115            .map(|_| "?")
116            .collect::<Vec<_>>()
117            .join(",");
118        let query = format!(
119            "UPDATE admission_training_data \
120             SET was_recalled = 1, updated_at = datetime('now') \
121             WHERE message_id IN ({placeholders})"
122        );
123        let mut q = zeph_db::query(&query);
124        for id in message_ids {
125            q = q.bind(id.0);
126        }
127        q.execute(&self.pool).await?;
128        Ok(())
129    }
130
131    /// Count total training records (admitted + rejected).
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if the database query fails.
136    pub async fn count_training_records(&self) -> Result<i64, MemoryError> {
137        let count = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM admission_training_data"))
138            .fetch_one(&self.pool)
139            .await?;
140        Ok(count)
141    }
142
143    /// Get a batch of training records for model training.
144    ///
145    /// Returns up to `limit` records ordered by creation time (oldest first).
146    ///
147    /// # Errors
148    ///
149    /// Returns an error if the database query fails.
150    pub async fn get_training_batch(
151        &self,
152        limit: usize,
153    ) -> Result<Vec<AdmissionTrainingRecord>, MemoryError> {
154        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
155        let rows = zeph_db::query_as::<
156            _,
157            (
158                i64,
159                Option<i64>,
160                i64,
161                String,
162                String,
163                f64,
164                i64,
165                i64,
166                String,
167                String,
168            ),
169        >(sql!(
170            "SELECT id, message_id, conversation_id, content_hash, role, \
171                    composite_score, was_admitted, was_recalled, features_json, created_at \
172             FROM admission_training_data \
173             ORDER BY created_at ASC \
174             LIMIT ?"
175        ))
176        .bind(limit)
177        .fetch_all(&self.pool)
178        .await?;
179
180        Ok(rows
181            .into_iter()
182            .map(
183                |(id, msg_id, cid, hash, role, score, admitted, recalled, features, created_at)| {
184                    AdmissionTrainingRecord {
185                        id,
186                        message_id: msg_id,
187                        conversation_id: ConversationId(cid),
188                        content_hash: hash,
189                        role,
190                        #[expect(clippy::cast_possible_truncation)]
191                        composite_score: score as f32,
192                        was_admitted: admitted != 0,
193                        was_recalled: recalled != 0,
194                        features_json: features,
195                        created_at,
196                    }
197                },
198            )
199            .collect())
200    }
201
202    /// Delete old training records, keeping the most recent `keep_recent`.
203    ///
204    /// Called after each retraining cycle to prevent unbounded table growth.
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if the database delete fails.
209    // TODO(#2416): call cleanup_old_training_data() in the RL retrain loop scheduled in
210    // bootstrap/mod.rs once the retrain loop is wired.
211    pub async fn cleanup_old_training_data(&self, keep_recent: usize) -> Result<(), MemoryError> {
212        let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
213        zeph_db::query(sql!(
214            "DELETE FROM admission_training_data \
215             WHERE id NOT IN ( \
216                 SELECT id FROM admission_training_data \
217                 ORDER BY created_at DESC \
218                 LIMIT ? \
219             )"
220        ))
221        .bind(keep)
222        .execute(&self.pool)
223        .await?;
224        Ok(())
225    }
226
227    /// Save trained RL model weights to `SQLite` for persistence across restarts.
228    ///
229    /// Uses a fixed `id = 1` row (INSERT OR REPLACE) so the table never grows beyond
230    /// one row — avoiding unbounded growth from repeated retrain cycles.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the database upsert fails.
235    pub async fn save_rl_weights(
236        &self,
237        weights_json: &str,
238        sample_count: i64,
239    ) -> Result<(), MemoryError> {
240        zeph_db::query(sql!(
241            "INSERT OR REPLACE INTO admission_rl_weights (id, weights_json, sample_count) \
242             VALUES (1, ?, ?)"
243        ))
244        .bind(weights_json)
245        .bind(sample_count)
246        .execute(&self.pool)
247        .await?;
248        Ok(())
249    }
250
251    /// Load the latest RL model weights from `SQLite`.
252    ///
253    /// Returns `None` if no weights have been saved yet.
254    ///
255    /// # Errors
256    ///
257    /// Returns an error if the database query fails.
258    pub async fn load_rl_weights(&self) -> Result<Option<(String, i64)>, MemoryError> {
259        let row: Option<(String, i64)> = zeph_db::query_as(sql!(
260            "SELECT weights_json, sample_count FROM admission_rl_weights \
261             ORDER BY id DESC LIMIT 1"
262        ))
263        .fetch_optional(&self.pool)
264        .await?;
265        Ok(row)
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Open an in-memory store and create one conversation to attach samples to.
    async fn make_store() -> (SqliteStore, i64) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let conversation = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, conversation.0)
    }

    /// Build a training-sample input with an empty feature vector.
    fn sample<'a>(
        conversation_id: ConversationId,
        message_id: Option<MessageId>,
        content: &'a str,
        role: &'a str,
        composite_score: f32,
        was_admitted: bool,
    ) -> AdmissionTrainingInput<'a> {
        AdmissionTrainingInput {
            message_id,
            conversation_id,
            content,
            role,
            composite_score,
            was_admitted,
            features_json: "[]",
        }
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, raw_cid) = make_store().await;
        let cid = ConversationId(raw_cid);
        store
            .record_admission_training(sample(cid, None, "content", "user", 0.5, false))
            .await
            .expect("record rejected");
        store
            .record_admission_training(sample(
                cid,
                Some(MessageId(1)),
                "content2",
                "assistant",
                0.8,
                true,
            ))
            .await
            .expect("record admitted");
        assert_eq!(store.count_training_records().await.expect("count"), 2);
    }

    #[tokio::test]
    async fn mark_recalled_sets_flag() {
        let (store, raw_cid) = make_store().await;
        let cid = ConversationId(raw_cid);
        store
            .record_admission_training(sample(
                cid,
                Some(MessageId(42)),
                "recalled content",
                "user",
                0.7,
                true,
            ))
            .await
            .expect("record");
        store
            .mark_training_recalled(&[MessageId(42)])
            .await
            .expect("mark recalled");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(
            batch[0].was_recalled,
            "was_recalled must be true after marking"
        );
    }

    #[tokio::test]
    async fn rejected_message_has_no_message_id() {
        let (store, raw_cid) = make_store().await;
        let cid = ConversationId(raw_cid);
        store
            .record_admission_training(sample(cid, None, "rejected", "user", 0.2, false))
            .await
            .expect("record");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(!batch[0].was_admitted);
        assert!(batch[0].message_id.is_none());
    }

    #[tokio::test]
    async fn cleanup_trims_old_records() {
        let (store, raw_cid) = make_store().await;
        let cid = ConversationId(raw_cid);
        for i in 0..5_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(sample(
                    cid,
                    Some(MessageId(i)),
                    &content,
                    "user",
                    0.5,
                    true,
                ))
                .await
                .expect("record");
        }
        // Keep only the 2 most recent rows.
        store.cleanup_old_training_data(2).await.expect("cleanup");
        assert_eq!(store.count_training_records().await.expect("count"), 2);
    }

    #[tokio::test]
    async fn save_and_load_rl_weights() {
        let (store, _) = make_store().await;
        store
            .save_rl_weights(r#"{"weights":[0.1,0.2],"bias":0.0}"#, 100)
            .await
            .expect("save");
        let stored = store.load_rl_weights().await.expect("load");
        assert!(stored.is_some());
        let (json, count) = stored.unwrap();
        assert!(json.contains("weights"));
        assert_eq!(count, 100);
    }

    #[tokio::test]
    async fn load_rl_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        assert!(store.load_rl_weights().await.expect("load").is_none());
    }

    #[test]
    fn content_hash_is_deterministic() {
        assert_eq!(content_hash("hello world"), content_hash("hello world"));
    }

    #[test]
    fn content_hash_differs_for_different_content() {
        assert_ne!(content_hash("hello"), content_hash("world"));
    }
}