1use zeph_db::sql;
10
11use crate::error::MemoryError;
12use crate::store::SqliteStore;
13use crate::types::ConversationId;
14
/// One training example for the compression predictor, as read back from
/// the `compression_predictor_training` table.
///
/// Float features are stored as SQLite REAL (f64) and narrowed to `f32`
/// on load; see `SqliteStore::get_compression_training_batch`.
#[derive(Debug, Clone)]
pub struct CompressionTrainingRecord {
    // Primary key of the training row (SQLite rowid).
    pub id: i64,
    // Conversation the example was collected from.
    pub conversation_id: ConversationId,
    // Compression ratio achieved for that conversation.
    pub compression_ratio: f32,
    // Number of messages in the conversation when the example was recorded.
    pub message_count: i64,
    // Mean message length — units not visible here; presumably characters
    // or tokens, TODO confirm against the caller that records examples.
    pub avg_message_length: f32,
    // Fraction of conversation content that was tool output.
    pub tool_output_fraction: f32,
    // Probe score feature/label; semantics defined by the predictor trainer
    // (not visible in this file).
    pub probe_score: f32,
    // Insertion timestamp as stored by SQLite (string form).
    pub created_at: String,
}
27
28impl SqliteStore {
29 pub async fn record_compression_training(
35 &self,
36 conversation_id: ConversationId,
37 compression_ratio: f32,
38 message_count: i64,
39 avg_message_length: f32,
40 tool_output_fraction: f32,
41 probe_score: f32,
42 ) -> Result<i64, MemoryError> {
43 let id = zeph_db::query_scalar(sql!(
44 "INSERT INTO compression_predictor_training \
45 (conversation_id, compression_ratio, message_count, \
46 avg_message_length, tool_output_fraction, probe_score) \
47 VALUES (?, ?, ?, ?, ?, ?) \
48 RETURNING id"
49 ))
50 .bind(conversation_id.0)
51 .bind(f64::from(compression_ratio))
52 .bind(message_count)
53 .bind(f64::from(avg_message_length))
54 .bind(f64::from(tool_output_fraction))
55 .bind(f64::from(probe_score))
56 .fetch_one(&self.pool)
57 .await?;
58 Ok(id)
59 }
60
61 pub async fn count_compression_training_records(&self) -> Result<i64, MemoryError> {
67 let count =
68 zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM compression_predictor_training"))
69 .fetch_one(&self.pool)
70 .await?;
71 Ok(count)
72 }
73
74 pub async fn get_compression_training_batch(
80 &self,
81 limit: usize,
82 ) -> Result<Vec<CompressionTrainingRecord>, MemoryError> {
83 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
84 let rows = zeph_db::query_as::<_, (i64, i64, f64, i64, f64, f64, f64, String)>(sql!(
85 "SELECT id, conversation_id, compression_ratio, message_count, \
86 avg_message_length, tool_output_fraction, probe_score, created_at \
87 FROM compression_predictor_training \
88 ORDER BY created_at DESC \
89 LIMIT ?"
90 ))
91 .bind(limit)
92 .fetch_all(&self.pool)
93 .await?;
94
95 Ok(rows
96 .into_iter()
97 .map(
98 |(id, cid, ratio, msg_count, avg_len, tool_frac, score, created_at)| {
99 CompressionTrainingRecord {
100 id,
101 conversation_id: ConversationId(cid),
102 #[expect(clippy::cast_possible_truncation)]
103 compression_ratio: ratio as f32,
104 message_count: msg_count,
105 #[expect(clippy::cast_possible_truncation)]
106 avg_message_length: avg_len as f32,
107 #[expect(clippy::cast_possible_truncation)]
108 tool_output_fraction: tool_frac as f32,
109 #[expect(clippy::cast_possible_truncation)]
110 probe_score: score as f32,
111 created_at,
112 }
113 },
114 )
115 .collect())
116 }
117
118 pub async fn trim_compression_training_data(
124 &self,
125 keep_recent: usize,
126 ) -> Result<(), MemoryError> {
127 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
128 zeph_db::query(sql!(
129 "DELETE FROM compression_predictor_training \
130 WHERE id NOT IN ( \
131 SELECT id FROM compression_predictor_training \
132 ORDER BY created_at DESC \
133 LIMIT ? \
134 )"
135 ))
136 .bind(keep)
137 .execute(&self.pool)
138 .await?;
139 Ok(())
140 }
141
142 pub async fn save_compression_predictor_weights(
148 &self,
149 weights_json: &str,
150 ) -> Result<(), MemoryError> {
151 zeph_db::query(sql!(
152 "INSERT OR REPLACE INTO compression_predictor_weights (id, weights_json, updated_at) \
153 VALUES (1, ?, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))"
154 ))
155 .bind(weights_json)
156 .execute(&self.pool)
157 .await?;
158 Ok(())
159 }
160
161 pub async fn load_compression_predictor_weights(&self) -> Result<Option<String>, MemoryError> {
169 let row: Option<(String,)> = zeph_db::query_as(sql!(
170 "SELECT weights_json FROM compression_predictor_weights WHERE id = 1"
171 ))
172 .fetch_optional(&self.pool)
173 .await?;
174 Ok(row.map(|(json,)| json))
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    // Spins up an in-memory store plus one conversation for the
    // training rows to reference.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:").await.expect("SqliteStore::new");
        let cid = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, cid)
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (db, conv) = make_store().await;
        db.record_compression_training(conv, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        let total = db
            .count_compression_training_records()
            .await
            .expect("count");
        assert_eq!(total, 1);
        db.pool().close().await;
    }

    #[tokio::test]
    async fn batch_returns_records() {
        let (db, conv) = make_store().await;
        db.record_compression_training(conv, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        let records = db.get_compression_training_batch(10).await.expect("batch");
        assert_eq!(records.len(), 1);
        let first = &records[0];
        assert!((first.compression_ratio - 0.5).abs() < 1e-4);
        assert!((first.probe_score - 0.75).abs() < 1e-4);
        db.pool().close().await;
    }

    #[tokio::test]
    async fn trim_keeps_most_recent() {
        let (db, conv) = make_store().await;
        for _ in 0..5 {
            db.record_compression_training(conv, 0.5, 20, 150.0, 0.3, 0.75)
                .await
                .expect("record");
        }
        db.trim_compression_training_data(2).await.expect("trim");
        let remaining = db
            .count_compression_training_records()
            .await
            .expect("count");
        assert_eq!(remaining, 2);
        db.pool().close().await;
    }

    #[tokio::test]
    async fn save_and_load_weights() {
        let (db, _) = make_store().await;
        db.save_compression_predictor_weights(r#"{"weights":[0.1,0.2,0.3,0.4],"bias":0.0}"#)
            .await
            .expect("save");
        let stored = db.load_compression_predictor_weights().await.expect("load");
        assert!(stored.is_some());
        assert!(stored.unwrap().contains("weights"));
        db.pool().close().await;
    }

    #[tokio::test]
    async fn load_weights_returns_none_when_empty() {
        let (db, _) = make_store().await;
        let stored = db.load_compression_predictor_weights().await.expect("load");
        assert!(stored.is_none());
        db.pool().close().await;
    }
}