Skip to main content

zeph_memory/store/
admission_training.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SQLite-backed store for RL admission control training data (#2416).
5//!
6//! Records ALL messages seen by A-MAC (admitted and rejected) to avoid survivorship
7//! bias in the logistic regression model (critic fix C3). `was_recalled` is set to 1
8//! when `SemanticMemory::recall()` returns the message, providing positive training signal.
9
10#[allow(unused_imports)]
11use zeph_db::sql;
12
13use crate::error::MemoryError;
14use crate::store::SqliteStore;
15use crate::types::{ConversationId, MessageId};
16
/// Input for recording a single RL admission training sample.
///
/// Passed to `SqliteStore::record_admission_training`. The raw `content` is never
/// persisted; only its [`content_hash`] is written to the training table.
pub struct AdmissionTrainingInput<'a> {
    /// Id of the persisted message, or `None` for rejected messages (which are never
    /// written to the `messages` table).
    pub message_id: Option<MessageId>,
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// Raw message content; hashed with [`content_hash`] before storage.
    pub content: &'a str,
    /// Message role, stored verbatim (tests use `"user"` and `"assistant"`).
    pub role: &'a str,
    /// Composite A-MAC admission score; widened to `f64` when stored.
    pub composite_score: f32,
    /// Whether the message was admitted; stored as `0`/`1`.
    pub was_admitted: bool,
    /// JSON-serialized feature vector used for training.
    pub features_json: &'a str,
}
27
/// A single training record for the RL admission model.
///
/// One row of `admission_training_data` as returned by `SqliteStore::get_training_batch`.
#[derive(Debug, Clone)]
pub struct AdmissionTrainingRecord {
    /// Primary key of the training row.
    pub id: i64,
    /// Raw message id; `None` for rejected messages that were never persisted.
    pub message_id: Option<i64>,
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// 16-char hex digest produced by [`content_hash`].
    pub content_hash: String,
    /// Message role as recorded.
    pub role: String,
    /// Composite admission score, narrowed back from the stored `f64`.
    pub composite_score: f32,
    /// Whether the message was admitted (stored column was non-zero).
    pub was_admitted: bool,
    /// Whether `mark_training_recalled` has matched this row's `message_id`.
    pub was_recalled: bool,
    /// JSON-serialized feature vector used for training.
    pub features_json: String,
    /// Creation timestamp exactly as stored by the database
    /// (NOTE(review): format comes from the schema, not this module — confirm).
    pub created_at: String,
}
42
43/// Compute a stable 16-char hex hash of `content` for deduplication.
44///
45/// Uses the first 8 bytes of SHA-256 truncated to a 16-char hex string.
46/// SHA-256 output is stable across Rust toolchain versions, unlike `DefaultHasher`.
47#[must_use]
48pub fn content_hash(content: &str) -> String {
49    use sha2::{Digest, Sha256};
50    let digest = Sha256::digest(content.as_bytes());
51    let mut bytes = [0u8; 8];
52    bytes.copy_from_slice(&digest[..8]);
53    format!("{:016x}", u64::from_be_bytes(bytes))
54}
55
56impl SqliteStore {
57    /// Record a message in the RL admission training data.
58    ///
59    /// Called for BOTH admitted and rejected messages so the model sees both classes.
60    /// `message_id` is `None` for rejected messages (never persisted to `messages` table).
61    /// `features_json` is the JSON-serialized feature vector used for training.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the database insert fails.
66    pub async fn record_admission_training(
67        &self,
68        input: AdmissionTrainingInput<'_>,
69    ) -> Result<i64, MemoryError> {
70        let hash = content_hash(input.content);
71        let admitted_i = i64::from(input.was_admitted);
72        let msg_id = input.message_id.map(|m| m.0);
73        let (conversation_id, role, composite_score, features_json) = (
74            input.conversation_id,
75            input.role,
76            input.composite_score,
77            input.features_json,
78        );
79        let id = zeph_db::query_scalar(sql!(
80            "INSERT INTO admission_training_data \
81             (message_id, conversation_id, content_hash, role, composite_score, \
82              was_admitted, was_recalled, features_json) \
83             VALUES (?, ?, ?, ?, ?, ?, 0, ?) \
84             RETURNING id"
85        ))
86        .bind(msg_id)
87        .bind(conversation_id.0)
88        .bind(hash)
89        .bind(role)
90        .bind(f64::from(composite_score))
91        .bind(admitted_i)
92        .bind(features_json)
93        .fetch_one(&self.pool)
94        .await?;
95        Ok(id)
96    }
97
98    /// Mark training records as recalled for the given message IDs.
99    ///
100    /// Called after `batch_increment_access_count()` in `SemanticMemory::recall()`.
101    /// Sets `was_recalled = 1` and updates `updated_at` for all matching records.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the database update fails.
106    pub async fn mark_training_recalled(
107        &self,
108        message_ids: &[MessageId],
109    ) -> Result<(), MemoryError> {
110        if message_ids.is_empty() {
111            return Ok(());
112        }
113        let placeholders: String = message_ids
114            .iter()
115            .map(|_| "?")
116            .collect::<Vec<_>>()
117            .join(",");
118        let query = format!(
119            "UPDATE admission_training_data \
120             SET was_recalled = 1, updated_at = datetime('now') \
121             WHERE message_id IN ({placeholders})"
122        );
123        let mut q = zeph_db::query(&query);
124        for id in message_ids {
125            q = q.bind(id.0);
126        }
127        q.execute(&self.pool).await?;
128        Ok(())
129    }
130
131    /// Count total training records (admitted + rejected).
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if the database query fails.
136    pub async fn count_training_records(&self) -> Result<i64, MemoryError> {
137        let count = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM admission_training_data"))
138            .fetch_one(&self.pool)
139            .await?;
140        Ok(count)
141    }
142
143    /// Get a batch of training records for model training.
144    ///
145    /// Returns up to `limit` records ordered by creation time (oldest first).
146    ///
147    /// # Errors
148    ///
149    /// Returns an error if the database query fails.
150    pub async fn get_training_batch(
151        &self,
152        limit: usize,
153    ) -> Result<Vec<AdmissionTrainingRecord>, MemoryError> {
154        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
155        let rows = zeph_db::query_as::<
156            _,
157            (
158                i64,
159                Option<i64>,
160                i64,
161                String,
162                String,
163                f64,
164                i64,
165                i64,
166                String,
167                String,
168            ),
169        >(sql!(
170            "SELECT id, message_id, conversation_id, content_hash, role, \
171                    composite_score, was_admitted, was_recalled, features_json, created_at \
172             FROM admission_training_data \
173             ORDER BY created_at ASC \
174             LIMIT ?"
175        ))
176        .bind(limit)
177        .fetch_all(&self.pool)
178        .await?;
179
180        Ok(rows
181            .into_iter()
182            .map(
183                |(id, msg_id, cid, hash, role, score, admitted, recalled, features, created_at)| {
184                    AdmissionTrainingRecord {
185                        id,
186                        message_id: msg_id,
187                        conversation_id: ConversationId(cid),
188                        content_hash: hash,
189                        role,
190                        #[expect(clippy::cast_possible_truncation)]
191                        composite_score: score as f32,
192                        was_admitted: admitted != 0,
193                        was_recalled: recalled != 0,
194                        features_json: features,
195                        created_at,
196                    }
197                },
198            )
199            .collect())
200    }
201
202    /// Delete old training records, keeping the most recent `keep_recent`.
203    ///
204    /// Called after each retraining cycle to prevent unbounded table growth.
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if the database delete fails.
209    // TODO(#2416): call cleanup_old_training_data() in the RL retrain loop scheduled in
210    // bootstrap/mod.rs once the retrain loop is wired.
211    pub async fn cleanup_old_training_data(&self, keep_recent: usize) -> Result<(), MemoryError> {
212        let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
213        zeph_db::query(sql!(
214            "DELETE FROM admission_training_data \
215             WHERE id NOT IN ( \
216                 SELECT id FROM admission_training_data \
217                 ORDER BY created_at DESC \
218                 LIMIT ? \
219             )"
220        ))
221        .bind(keep)
222        .execute(&self.pool)
223        .await?;
224        Ok(())
225    }
226
227    /// Save trained RL model weights to `SQLite` for persistence across restarts.
228    ///
229    /// Uses a fixed `id = 1` row (INSERT OR REPLACE) so the table never grows beyond
230    /// one row — avoiding unbounded growth from repeated retrain cycles.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the database upsert fails.
235    pub async fn save_rl_weights(
236        &self,
237        weights_json: &str,
238        sample_count: i64,
239    ) -> Result<(), MemoryError> {
240        zeph_db::query(sql!(
241            "INSERT INTO admission_rl_weights (id, weights_json, sample_count) \
242             VALUES (1, ?, ?) \
243             ON CONFLICT (id) DO UPDATE SET \
244               weights_json = EXCLUDED.weights_json, \
245               sample_count = EXCLUDED.sample_count"
246        ))
247        .bind(weights_json)
248        .bind(sample_count)
249        .execute(&self.pool)
250        .await?;
251        Ok(())
252    }
253
254    /// Load the latest RL model weights from `SQLite`.
255    ///
256    /// Returns `None` if no weights have been saved yet.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if the database query fails.
261    pub async fn load_rl_weights(&self) -> Result<Option<(String, i64)>, MemoryError> {
262        let row: Option<(String, i64)> = zeph_db::query_as(sql!(
263            "SELECT weights_json, sample_count FROM admission_rl_weights \
264             ORDER BY id DESC LIMIT 1"
265        ))
266        .fetch_optional(&self.pool)
267        .await?;
268        Ok(row)
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory store and create one conversation to attach records to.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let cid = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, ConversationId(cid.0))
    }

    /// Build a training input with a fixed role, score and (valid JSON) feature vector.
    fn sample(
        conversation_id: ConversationId,
        message_id: Option<MessageId>,
        content: &str,
        was_admitted: bool,
    ) -> AdmissionTrainingInput<'_> {
        AdmissionTrainingInput {
            message_id,
            conversation_id,
            content,
            role: "user",
            composite_score: 0.5,
            was_admitted,
            features_json: "[]",
        }
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: None,
                conversation_id: cid,
                content: "content",
                role: "user",
                composite_score: 0.5,
                was_admitted: false,
                features_json: "[]",
            })
            .await
            .expect("record rejected");
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: Some(MessageId(1)),
                conversation_id: cid,
                content: "content2",
                role: "assistant",
                composite_score: 0.8,
                was_admitted: true,
                features_json: "[]",
            })
            .await
            .expect("record admitted");
        let count = store.count_training_records().await.expect("count");
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn record_round_trips_all_fields() {
        let (store, cid) = make_store().await;
        let id = store
            .record_admission_training(AdmissionTrainingInput {
                message_id: Some(MessageId(9)),
                conversation_id: cid,
                content: "round trip",
                role: "assistant",
                composite_score: 0.75,
                was_admitted: true,
                features_json: "[0.1,0.2]",
            })
            .await
            .expect("record");
        let batch = store.get_training_batch(10).await.expect("batch");
        let record = batch.first().expect("one record");
        assert_eq!(record.id, id);
        assert_eq!(record.message_id, Some(9));
        assert_eq!(record.conversation_id.0, cid.0);
        assert_eq!(record.content_hash, content_hash("round trip"));
        assert_eq!(record.role, "assistant");
        assert!((record.composite_score - 0.75).abs() < f32::EPSILON);
        assert!(record.was_admitted);
        assert!(!record.was_recalled, "new records must start un-recalled");
        assert_eq!(record.features_json, "[0.1,0.2]");
    }

    #[tokio::test]
    async fn mark_recalled_sets_flag() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: Some(MessageId(42)),
                conversation_id: cid,
                content: "recalled content",
                role: "user",
                composite_score: 0.7,
                was_admitted: true,
                features_json: "[]",
            })
            .await
            .expect("record");
        store
            .mark_training_recalled(&[MessageId(42)])
            .await
            .expect("mark recalled");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(
            batch[0].was_recalled,
            "was_recalled must be true after marking"
        );
    }

    #[tokio::test]
    async fn mark_recalled_only_touches_matching_ids() {
        let (store, cid) = make_store().await;
        for (message_id, content) in [
            (Some(MessageId(1)), "first"),
            (Some(MessageId(2)), "second"),
            (None, "rejected"),
        ] {
            let was_admitted = message_id.is_some();
            store
                .record_admission_training(sample(cid, message_id, content, was_admitted))
                .await
                .expect("record");
        }
        store
            .mark_training_recalled(&[MessageId(2)])
            .await
            .expect("mark recalled");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 3);
        for record in &batch {
            assert_eq!(
                record.was_recalled,
                record.message_id == Some(2),
                "unexpected was_recalled for {record:?}"
            );
        }
    }

    #[tokio::test]
    async fn mark_recalled_with_empty_slice_is_noop() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(sample(cid, Some(MessageId(7)), "untouched", true))
            .await
            .expect("record");
        store
            .mark_training_recalled(&[])
            .await
            .expect("empty mark must succeed");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(!batch[0].was_recalled);
    }

    #[tokio::test]
    async fn rejected_message_has_no_message_id() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: None,
                conversation_id: cid,
                content: "rejected",
                role: "user",
                composite_score: 0.2,
                was_admitted: false,
                features_json: "[]",
            })
            .await
            .expect("record");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(!batch[0].was_admitted);
        assert!(batch[0].message_id.is_none());
    }

    #[tokio::test]
    async fn get_training_batch_respects_limit() {
        let (store, cid) = make_store().await;
        for i in 0..3_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(sample(cid, Some(MessageId(i)), &content, true))
                .await
                .expect("record");
        }
        assert_eq!(store.get_training_batch(2).await.expect("batch").len(), 2);
        assert!(store.get_training_batch(0).await.expect("batch").is_empty());
    }

    #[tokio::test]
    async fn cleanup_trims_old_records() {
        let (store, cid) = make_store().await;
        for i in 0..5_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(AdmissionTrainingInput {
                    message_id: Some(MessageId(i)),
                    conversation_id: cid,
                    content: &content,
                    role: "user",
                    composite_score: 0.5,
                    was_admitted: true,
                    features_json: "[]",
                })
                .await
                .expect("record");
        }
        // Keep only 2 most recent.
        store.cleanup_old_training_data(2).await.expect("cleanup");
        let count = store.count_training_records().await.expect("count");
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn cleanup_with_zero_keep_removes_everything() {
        let (store, cid) = make_store().await;
        for i in 0..3_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(sample(cid, Some(MessageId(i)), &content, true))
                .await
                .expect("record");
        }
        store.cleanup_old_training_data(0).await.expect("cleanup");
        let count = store.count_training_records().await.expect("count");
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn save_and_load_rl_weights() {
        let (store, _) = make_store().await;
        store
            .save_rl_weights(r#"{"weights":[0.1,0.2],"bias":0.0}"#, 100)
            .await
            .expect("save");
        let (json, count) = store
            .load_rl_weights()
            .await
            .expect("load")
            .expect("weights must be present after save");
        assert!(json.contains("weights"));
        assert_eq!(count, 100);
    }

    #[tokio::test]
    async fn save_rl_weights_overwrites_previous_weights() {
        let (store, _) = make_store().await;
        store
            .save_rl_weights(r#"{"v":1}"#, 1)
            .await
            .expect("save first");
        store
            .save_rl_weights(r#"{"v":2}"#, 2)
            .await
            .expect("save second");
        let (json, count) = store
            .load_rl_weights()
            .await
            .expect("load")
            .expect("weights must be present after save");
        assert_eq!(json, r#"{"v":2}"#);
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn load_rl_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        let loaded = store.load_rl_weights().await.expect("load");
        assert!(loaded.is_none());
    }

    #[test]
    fn content_hash_is_deterministic() {
        let h1 = content_hash("hello world");
        let h2 = content_hash("hello world");
        assert_eq!(h1, h2);
    }

    #[test]
    fn content_hash_differs_for_different_content() {
        let h1 = content_hash("hello");
        let h2 = content_hash("world");
        assert_ne!(h1, h2);
    }

    #[test]
    fn content_hash_is_16_lowercase_hex_chars() {
        let h = content_hash("anything");
        assert_eq!(h.len(), 16);
        assert!(h
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }
}