1#[allow(unused_imports)]
11use zeph_db::sql;
12
13use crate::error::MemoryError;
14use crate::store::SqliteStore;
15use crate::types::{ConversationId, MessageId};
16
17pub struct AdmissionTrainingInput<'a> {
19 pub message_id: Option<MessageId>,
20 pub conversation_id: ConversationId,
21 pub content: &'a str,
22 pub role: &'a str,
23 pub composite_score: f32,
24 pub was_admitted: bool,
25 pub features_json: &'a str,
26}
27
/// One stored admission training sample, as returned by
/// [`SqliteStore::get_training_batch`].
#[derive(Debug, Clone)]
pub struct AdmissionTrainingRecord {
    /// Row id of the sample in `admission_training_data`.
    pub id: i64,
    /// Raw id of the linked message; `None` when the sample was recorded without
    /// one (e.g. for a message that was not admitted).
    pub message_id: Option<i64>,
    /// Conversation the sample belongs to.
    pub conversation_id: ConversationId,
    /// 16-hex-digit fingerprint produced by [`content_hash`]; the raw content is not stored.
    pub content_hash: String,
    /// Role of the message author, as passed when the sample was recorded.
    pub role: String,
    /// Composite admission score; stored as `f64` and narrowed back to `f32` on read.
    pub composite_score: f32,
    /// Whether the message was admitted when the decision was made.
    pub was_admitted: bool,
    /// Whether the message was later recalled (set by
    /// [`SqliteStore::mark_training_recalled`]).
    pub was_recalled: bool,
    /// Feature vector serialized as JSON, exactly as supplied when recorded.
    pub features_json: String,
    /// Row creation timestamp as returned by the database.
    /// NOTE(review): format depends on the column default — TODO confirm.
    pub created_at: String,
}
42
43#[must_use]
48pub fn content_hash(content: &str) -> String {
49 use sha2::{Digest, Sha256};
50 let digest = Sha256::digest(content.as_bytes());
51 let mut bytes = [0u8; 8];
52 bytes.copy_from_slice(&digest[..8]);
53 format!("{:016x}", u64::from_be_bytes(bytes))
54}
55
56impl SqliteStore {
57 pub async fn record_admission_training(
67 &self,
68 input: AdmissionTrainingInput<'_>,
69 ) -> Result<i64, MemoryError> {
70 let hash = content_hash(input.content);
71 let admitted_i = i64::from(input.was_admitted);
72 let msg_id = input.message_id.map(|m| m.0);
73 let (conversation_id, role, composite_score, features_json) = (
74 input.conversation_id,
75 input.role,
76 input.composite_score,
77 input.features_json,
78 );
79 let id = zeph_db::query_scalar(sql!(
80 "INSERT INTO admission_training_data \
81 (message_id, conversation_id, content_hash, role, composite_score, \
82 was_admitted, was_recalled, features_json) \
83 VALUES (?, ?, ?, ?, ?, ?, 0, ?) \
84 RETURNING id"
85 ))
86 .bind(msg_id)
87 .bind(conversation_id.0)
88 .bind(hash)
89 .bind(role)
90 .bind(f64::from(composite_score))
91 .bind(admitted_i)
92 .bind(features_json)
93 .fetch_one(&self.pool)
94 .await?;
95 Ok(id)
96 }
97
98 pub async fn mark_training_recalled(
107 &self,
108 message_ids: &[MessageId],
109 ) -> Result<(), MemoryError> {
110 if message_ids.is_empty() {
111 return Ok(());
112 }
113 let placeholders: String = message_ids
114 .iter()
115 .map(|_| "?")
116 .collect::<Vec<_>>()
117 .join(",");
118 let query = format!(
119 "UPDATE admission_training_data \
120 SET was_recalled = 1, updated_at = datetime('now') \
121 WHERE message_id IN ({placeholders})"
122 );
123 let mut q = zeph_db::query(&query);
124 for id in message_ids {
125 q = q.bind(id.0);
126 }
127 q.execute(&self.pool).await?;
128 Ok(())
129 }
130
131 pub async fn count_training_records(&self) -> Result<i64, MemoryError> {
137 let count = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM admission_training_data"))
138 .fetch_one(&self.pool)
139 .await?;
140 Ok(count)
141 }
142
143 pub async fn get_training_batch(
151 &self,
152 limit: usize,
153 ) -> Result<Vec<AdmissionTrainingRecord>, MemoryError> {
154 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
155 let rows = zeph_db::query_as::<
156 _,
157 (
158 i64,
159 Option<i64>,
160 i64,
161 String,
162 String,
163 f64,
164 i64,
165 i64,
166 String,
167 String,
168 ),
169 >(sql!(
170 "SELECT id, message_id, conversation_id, content_hash, role, \
171 composite_score, was_admitted, was_recalled, features_json, created_at \
172 FROM admission_training_data \
173 ORDER BY created_at ASC \
174 LIMIT ?"
175 ))
176 .bind(limit)
177 .fetch_all(&self.pool)
178 .await?;
179
180 Ok(rows
181 .into_iter()
182 .map(
183 |(id, msg_id, cid, hash, role, score, admitted, recalled, features, created_at)| {
184 AdmissionTrainingRecord {
185 id,
186 message_id: msg_id,
187 conversation_id: ConversationId(cid),
188 content_hash: hash,
189 role,
190 #[expect(clippy::cast_possible_truncation)]
191 composite_score: score as f32,
192 was_admitted: admitted != 0,
193 was_recalled: recalled != 0,
194 features_json: features,
195 created_at,
196 }
197 },
198 )
199 .collect())
200 }
201
202 pub async fn cleanup_old_training_data(&self, keep_recent: usize) -> Result<(), MemoryError> {
212 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
213 zeph_db::query(sql!(
214 "DELETE FROM admission_training_data \
215 WHERE id NOT IN ( \
216 SELECT id FROM admission_training_data \
217 ORDER BY created_at DESC \
218 LIMIT ? \
219 )"
220 ))
221 .bind(keep)
222 .execute(&self.pool)
223 .await?;
224 Ok(())
225 }
226
227 pub async fn save_rl_weights(
236 &self,
237 weights_json: &str,
238 sample_count: i64,
239 ) -> Result<(), MemoryError> {
240 zeph_db::query(sql!(
241 "INSERT INTO admission_rl_weights (id, weights_json, sample_count) \
242 VALUES (1, ?, ?) \
243 ON CONFLICT (id) DO UPDATE SET \
244 weights_json = EXCLUDED.weights_json, \
245 sample_count = EXCLUDED.sample_count"
246 ))
247 .bind(weights_json)
248 .bind(sample_count)
249 .execute(&self.pool)
250 .await?;
251 Ok(())
252 }
253
254 pub async fn load_rl_weights(&self) -> Result<Option<(String, i64)>, MemoryError> {
262 let row: Option<(String, i64)> = zeph_db::query_as(sql!(
263 "SELECT weights_json, sample_count FROM admission_rl_weights \
264 ORDER BY id DESC LIMIT 1"
265 ))
266 .fetch_optional(&self.pool)
267 .await?;
268 Ok(row)
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory store and creates one conversation to attach samples to.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let cid = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, ConversationId(cid.0))
    }

    /// Builds an admitted `"user"` sample with score 0.5 and empty features.
    fn sample(
        cid: ConversationId,
        message_id: Option<MessageId>,
        content: &str,
    ) -> AdmissionTrainingInput<'_> {
        AdmissionTrainingInput {
            message_id,
            conversation_id: cid,
            content,
            role: "user",
            composite_score: 0.5,
            was_admitted: true,
            features_json: "[]",
        }
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: None,
                conversation_id: cid,
                content: "content",
                role: "user",
                composite_score: 0.5,
                was_admitted: false,
                features_json: "[]",
            })
            .await
            .expect("record rejected");
        store
            .record_admission_training(AdmissionTrainingInput {
                message_id: Some(MessageId(1)),
                conversation_id: cid,
                content: "content2",
                role: "assistant",
                composite_score: 0.8,
                was_admitted: true,
                features_json: "[]",
            })
            .await
            .expect("record admitted");
        let count = store.count_training_records().await.expect("count");
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn mark_recalled_sets_flag() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                composite_score: 0.7,
                ..sample(cid, Some(MessageId(42)), "recalled content")
            })
            .await
            .expect("record");
        store
            .mark_training_recalled(&[MessageId(42)])
            .await
            .expect("mark recalled");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(
            batch[0].was_recalled,
            "was_recalled must be true after marking"
        );
    }

    #[tokio::test]
    async fn mark_recalled_only_touches_listed_messages() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(sample(cid, Some(MessageId(1)), "first"))
            .await
            .expect("record first");
        store
            .record_admission_training(sample(cid, Some(MessageId(2)), "second"))
            .await
            .expect("record second");
        store
            .mark_training_recalled(&[MessageId(2)])
            .await
            .expect("mark recalled");
        let batch = store.get_training_batch(10).await.expect("batch");
        let recalled: Vec<Option<i64>> = batch
            .iter()
            .filter(|r| r.was_recalled)
            .map(|r| r.message_id)
            .collect();
        assert_eq!(recalled, vec![Some(2)]);
    }

    #[tokio::test]
    async fn mark_recalled_with_empty_slice_is_noop() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(sample(cid, Some(MessageId(7)), "untouched"))
            .await
            .expect("record");
        store
            .mark_training_recalled(&[])
            .await
            .expect("empty mark must succeed");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(!batch[0].was_recalled);
    }

    #[tokio::test]
    async fn rejected_message_has_no_message_id() {
        let (store, cid) = make_store().await;
        store
            .record_admission_training(AdmissionTrainingInput {
                composite_score: 0.2,
                was_admitted: false,
                ..sample(cid, None, "rejected")
            })
            .await
            .expect("record");
        let batch = store.get_training_batch(10).await.expect("batch");
        assert_eq!(batch.len(), 1);
        assert!(!batch[0].was_admitted);
        assert!(batch[0].message_id.is_none());
    }

    #[tokio::test]
    async fn training_batch_respects_limit() {
        let (store, cid) = make_store().await;
        for i in 0..3_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(sample(cid, Some(MessageId(i)), &content))
                .await
                .expect("record");
        }
        let batch = store.get_training_batch(2).await.expect("batch");
        assert_eq!(batch.len(), 2);
    }

    #[tokio::test]
    async fn cleanup_trims_old_records() {
        let (store, cid) = make_store().await;
        for i in 0..5_i64 {
            let content = format!("content {i}");
            store
                .record_admission_training(sample(cid, Some(MessageId(i)), &content))
                .await
                .expect("record");
        }
        store.cleanup_old_training_data(2).await.expect("cleanup");
        let count = store.count_training_records().await.expect("count");
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn save_and_load_rl_weights() {
        let (store, _) = make_store().await;
        store
            .save_rl_weights(r#"{"weights":[0.1,0.2],"bias":0.0}"#, 100)
            .await
            .expect("save");
        let (json, count) = store
            .load_rl_weights()
            .await
            .expect("load")
            .expect("weights must be present after save");
        assert!(json.contains("weights"));
        assert_eq!(count, 100);
    }

    #[tokio::test]
    async fn save_rl_weights_overwrites_previous_snapshot() {
        let (store, _) = make_store().await;
        store
            .save_rl_weights(r#"{"weights":[0.1],"bias":0.0}"#, 10)
            .await
            .expect("first save");
        store
            .save_rl_weights(r#"{"weights":[0.9],"bias":0.5}"#, 20)
            .await
            .expect("second save");
        let loaded = store.load_rl_weights().await.expect("load");
        assert_eq!(
            loaded,
            Some((r#"{"weights":[0.9],"bias":0.5}"#.to_owned(), 20))
        );
    }

    #[tokio::test]
    async fn load_rl_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        let loaded = store.load_rl_weights().await.expect("load");
        assert!(loaded.is_none());
    }

    #[test]
    fn content_hash_is_deterministic() {
        let h1 = content_hash("hello world");
        let h2 = content_hash("hello world");
        assert_eq!(h1, h2);
    }

    #[test]
    fn content_hash_differs_for_different_content() {
        let h1 = content_hash("hello");
        let h2 = content_hash("world");
        assert_ne!(h1, h2);
    }

    #[test]
    fn content_hash_matches_known_sha256_prefix() {
        // First 8 bytes of SHA-256("abc") and SHA-256(""), big-endian, lowercase hex.
        assert_eq!(content_hash("abc"), "ba7816bf8f01cfea");
        assert_eq!(content_hash(""), "e3b0c44298fc1c14");
    }
}