1#[allow(unused_imports)]
11use zeph_db::sql;
12
13use crate::error::MemoryError;
14use crate::store::SqliteStore;
15use crate::types::{ConversationId, MessageId};
16
/// Borrowed input describing one admission decision to be recorded as a training sample.
///
/// Consumed by `SqliteStore::record_admission_training`.
pub struct AdmissionTrainingInput<'a> {
    /// Id of the stored message, if any. The tests in this module pass `None` for
    /// rejected messages.
    pub message_id: Option<MessageId>,
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// Raw message text. `record_admission_training` only uses it to compute
    /// `content_hash`; the text itself is not bound into the insert.
    pub content: &'a str,
    /// Message role, stored as given (the tests use `"user"` and `"assistant"`).
    pub role: &'a str,
    /// Admission score for the message; widened to `f64` when written.
    pub composite_score: f32,
    /// Whether the message was admitted.
    pub was_admitted: bool,
    /// Serialized feature payload, stored verbatim.
    /// NOTE(review): presumably a JSON document (the tests pass `"[]"`) — confirm the
    /// schema with the producer.
    pub features_json: &'a str,
}
27
/// One row of `admission_training_data` as returned by `SqliteStore::get_training_batch`.
#[derive(Debug, Clone)]
pub struct AdmissionTrainingRecord {
    /// Row id of the training sample.
    pub id: i64,
    /// Raw id of the associated message, or `None` when the sample was recorded
    /// without one.
    pub message_id: Option<i64>,
    /// Conversation the sample belongs to.
    pub conversation_id: ConversationId,
    /// Fingerprint of the message content as produced by [`content_hash`].
    pub content_hash: String,
    /// Message role as it was recorded.
    pub role: String,
    /// Admission score; read back as `f64` and narrowed to `f32`.
    pub composite_score: f32,
    /// Whether the message was admitted (`true` for any non-zero stored value).
    pub was_admitted: bool,
    /// Whether the sample has been flagged via `SqliteStore::mark_training_recalled`
    /// (`true` for any non-zero stored value).
    pub was_recalled: bool,
    /// Serialized feature payload exactly as it was recorded.
    pub features_json: String,
    /// Creation timestamp as the text value returned by the database.
    pub created_at: String,
}
42
43#[must_use]
48pub fn content_hash(content: &str) -> String {
49 use sha2::{Digest, Sha256};
50 let digest = Sha256::digest(content.as_bytes());
51 let mut bytes = [0u8; 8];
52 bytes.copy_from_slice(&digest[..8]);
53 format!("{:016x}", u64::from_be_bytes(bytes))
54}
55
56impl SqliteStore {
57 pub async fn record_admission_training(
67 &self,
68 input: AdmissionTrainingInput<'_>,
69 ) -> Result<i64, MemoryError> {
70 let hash = content_hash(input.content);
71 let admitted_i = i64::from(input.was_admitted);
72 let msg_id = input.message_id.map(|m| m.0);
73 let (conversation_id, role, composite_score, features_json) = (
74 input.conversation_id,
75 input.role,
76 input.composite_score,
77 input.features_json,
78 );
79 let id = zeph_db::query_scalar(sql!(
80 "INSERT INTO admission_training_data \
81 (message_id, conversation_id, content_hash, role, composite_score, \
82 was_admitted, was_recalled, features_json) \
83 VALUES (?, ?, ?, ?, ?, ?, 0, ?) \
84 RETURNING id"
85 ))
86 .bind(msg_id)
87 .bind(conversation_id.0)
88 .bind(hash)
89 .bind(role)
90 .bind(f64::from(composite_score))
91 .bind(admitted_i)
92 .bind(features_json)
93 .fetch_one(&self.pool)
94 .await?;
95 Ok(id)
96 }
97
98 pub async fn mark_training_recalled(
107 &self,
108 message_ids: &[MessageId],
109 ) -> Result<(), MemoryError> {
110 if message_ids.is_empty() {
111 return Ok(());
112 }
113 let placeholders: String = message_ids
114 .iter()
115 .map(|_| "?")
116 .collect::<Vec<_>>()
117 .join(",");
118 let query = format!(
119 "UPDATE admission_training_data \
120 SET was_recalled = 1, updated_at = datetime('now') \
121 WHERE message_id IN ({placeholders})"
122 );
123 let mut q = zeph_db::query(&query);
124 for id in message_ids {
125 q = q.bind(id.0);
126 }
127 q.execute(&self.pool).await?;
128 Ok(())
129 }
130
131 pub async fn count_training_records(&self) -> Result<i64, MemoryError> {
137 let count = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM admission_training_data"))
138 .fetch_one(&self.pool)
139 .await?;
140 Ok(count)
141 }
142
143 pub async fn get_training_batch(
151 &self,
152 limit: usize,
153 ) -> Result<Vec<AdmissionTrainingRecord>, MemoryError> {
154 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
155 let rows = zeph_db::query_as::<
156 _,
157 (
158 i64,
159 Option<i64>,
160 i64,
161 String,
162 String,
163 f64,
164 i64,
165 i64,
166 String,
167 String,
168 ),
169 >(sql!(
170 "SELECT id, message_id, conversation_id, content_hash, role, \
171 composite_score, was_admitted, was_recalled, features_json, created_at \
172 FROM admission_training_data \
173 ORDER BY created_at ASC \
174 LIMIT ?"
175 ))
176 .bind(limit)
177 .fetch_all(&self.pool)
178 .await?;
179
180 Ok(rows
181 .into_iter()
182 .map(
183 |(id, msg_id, cid, hash, role, score, admitted, recalled, features, created_at)| {
184 AdmissionTrainingRecord {
185 id,
186 message_id: msg_id,
187 conversation_id: ConversationId(cid),
188 content_hash: hash,
189 role,
190 #[expect(clippy::cast_possible_truncation)]
191 composite_score: score as f32,
192 was_admitted: admitted != 0,
193 was_recalled: recalled != 0,
194 features_json: features,
195 created_at,
196 }
197 },
198 )
199 .collect())
200 }
201
202 pub async fn cleanup_old_training_data(&self, keep_recent: usize) -> Result<(), MemoryError> {
212 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
213 zeph_db::query(sql!(
214 "DELETE FROM admission_training_data \
215 WHERE id NOT IN ( \
216 SELECT id FROM admission_training_data \
217 ORDER BY created_at DESC \
218 LIMIT ? \
219 )"
220 ))
221 .bind(keep)
222 .execute(&self.pool)
223 .await?;
224 Ok(())
225 }
226
227 pub async fn save_rl_weights(
236 &self,
237 weights_json: &str,
238 sample_count: i64,
239 ) -> Result<(), MemoryError> {
240 zeph_db::query(sql!(
241 "INSERT OR REPLACE INTO admission_rl_weights (id, weights_json, sample_count) \
242 VALUES (1, ?, ?)"
243 ))
244 .bind(weights_json)
245 .bind(sample_count)
246 .execute(&self.pool)
247 .await?;
248 Ok(())
249 }
250
251 pub async fn load_rl_weights(&self) -> Result<Option<(String, i64)>, MemoryError> {
259 let row: Option<(String, i64)> = zeph_db::query_as(sql!(
260 "SELECT weights_json, sample_count FROM admission_rl_weights \
261 ORDER BY id DESC LIMIT 1"
262 ))
263 .fetch_optional(&self.pool)
264 .await?;
265 Ok(row)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Open an in-memory store and create one conversation to attach samples to.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let conversation = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, ConversationId(conversation.0))
    }

    /// Build a training input with the empty feature payload every test here uses.
    fn training_input<'a>(
        conversation_id: ConversationId,
        message_id: Option<MessageId>,
        content: &'a str,
        role: &'a str,
        composite_score: f32,
        was_admitted: bool,
    ) -> AdmissionTrainingInput<'a> {
        AdmissionTrainingInput {
            message_id,
            conversation_id,
            content,
            role,
            composite_score,
            was_admitted,
            features_json: "[]",
        }
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, cid) = make_store().await;

        let rejected = training_input(cid, None, "content", "user", 0.5, false);
        store
            .record_admission_training(rejected)
            .await
            .expect("record rejected");

        let admitted = training_input(
            cid,
            Some(MessageId(1)),
            "content2",
            "assistant",
            0.8,
            true,
        );
        store
            .record_admission_training(admitted)
            .await
            .expect("record admitted");

        let total = store.count_training_records().await.expect("count");
        assert_eq!(total, 2);
    }

    #[tokio::test]
    async fn mark_recalled_sets_flag() {
        let (store, cid) = make_store().await;

        let input = training_input(
            cid,
            Some(MessageId(42)),
            "recalled content",
            "user",
            0.7,
            true,
        );
        store.record_admission_training(input).await.expect("record");
        store
            .mark_training_recalled(&[MessageId(42)])
            .await
            .expect("mark recalled");

        let records = store.get_training_batch(10).await.expect("batch");
        assert_eq!(records.len(), 1);
        assert!(
            records[0].was_recalled,
            "was_recalled must be true after marking"
        );
    }

    #[tokio::test]
    async fn rejected_message_has_no_message_id() {
        let (store, cid) = make_store().await;

        let input = training_input(cid, None, "rejected", "user", 0.2, false);
        store.record_admission_training(input).await.expect("record");

        let records = store.get_training_batch(10).await.expect("batch");
        assert_eq!(records.len(), 1);
        let only = &records[0];
        assert!(!only.was_admitted);
        assert!(only.message_id.is_none());
    }

    #[tokio::test]
    async fn cleanup_trims_old_records() {
        let (store, cid) = make_store().await;

        for i in 0..5_i64 {
            let content = format!("content {i}");
            let input = training_input(cid, Some(MessageId(i)), &content, "user", 0.5, true);
            store.record_admission_training(input).await.expect("record");
        }

        store.cleanup_old_training_data(2).await.expect("cleanup");

        let remaining = store.count_training_records().await.expect("count");
        assert_eq!(remaining, 2);
    }

    #[tokio::test]
    async fn save_and_load_rl_weights() {
        let (store, _) = make_store().await;
        let weights = r#"{"weights":[0.1,0.2],"bias":0.0}"#;

        store.save_rl_weights(weights, 100).await.expect("save");

        let loaded = store.load_rl_weights().await.expect("load");
        assert!(loaded.is_some());
        let (json, samples) = loaded.unwrap();
        assert!(json.contains("weights"));
        assert_eq!(samples, 100);
    }

    #[tokio::test]
    async fn load_rl_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        let loaded = store.load_rl_weights().await.expect("load");
        assert!(loaded.is_none());
    }

    #[test]
    fn content_hash_is_deterministic() {
        let first = content_hash("hello world");
        let second = content_hash("hello world");
        assert_eq!(first, second);
    }

    #[test]
    fn content_hash_differs_for_different_content() {
        let hello = content_hash("hello");
        let world = content_hash("world");
        assert_ne!(hello, world);
    }
}