use zeph_db::sql;

use crate::error::MemoryError;
use crate::store::SqliteStore;
use crate::types::ConversationId;

15#[derive(Debug, Clone)]
17pub struct CompressionTrainingRecord {
18 pub id: i64,
19 pub conversation_id: ConversationId,
20 pub compression_ratio: f32,
21 pub message_count: i64,
22 pub avg_message_length: f32,
23 pub tool_output_fraction: f32,
24 pub probe_score: f32,
25 pub created_at: String,
26}
27
28impl SqliteStore {
29 pub async fn record_compression_training(
35 &self,
36 conversation_id: ConversationId,
37 compression_ratio: f32,
38 message_count: i64,
39 avg_message_length: f32,
40 tool_output_fraction: f32,
41 probe_score: f32,
42 ) -> Result<i64, MemoryError> {
43 let id = zeph_db::query_scalar(sql!(
44 "INSERT INTO compression_predictor_training \
45 (conversation_id, compression_ratio, message_count, \
46 avg_message_length, tool_output_fraction, probe_score) \
47 VALUES (?, ?, ?, ?, ?, ?) \
48 RETURNING id"
49 ))
50 .bind(conversation_id.0)
51 .bind(f64::from(compression_ratio))
52 .bind(message_count)
53 .bind(f64::from(avg_message_length))
54 .bind(f64::from(tool_output_fraction))
55 .bind(f64::from(probe_score))
56 .fetch_one(&self.pool)
57 .await?;
58 Ok(id)
59 }
60
61 pub async fn count_compression_training_records(&self) -> Result<i64, MemoryError> {
67 let count =
68 zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM compression_predictor_training"))
69 .fetch_one(&self.pool)
70 .await?;
71 Ok(count)
72 }
73
74 pub async fn get_compression_training_batch(
80 &self,
81 limit: usize,
82 ) -> Result<Vec<CompressionTrainingRecord>, MemoryError> {
83 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
84 let rows = zeph_db::query_as::<_, (i64, i64, f64, i64, f64, f64, f64, String)>(sql!(
85 "SELECT id, conversation_id, compression_ratio, message_count, \
86 avg_message_length, tool_output_fraction, probe_score, created_at \
87 FROM compression_predictor_training \
88 ORDER BY created_at DESC \
89 LIMIT ?"
90 ))
91 .bind(limit)
92 .fetch_all(&self.pool)
93 .await?;
94
95 Ok(rows
96 .into_iter()
97 .map(
98 |(id, cid, ratio, msg_count, avg_len, tool_frac, score, created_at)| {
99 CompressionTrainingRecord {
100 id,
101 conversation_id: ConversationId(cid),
102 #[expect(clippy::cast_possible_truncation)]
103 compression_ratio: ratio as f32,
104 message_count: msg_count,
105 #[expect(clippy::cast_possible_truncation)]
106 avg_message_length: avg_len as f32,
107 #[expect(clippy::cast_possible_truncation)]
108 tool_output_fraction: tool_frac as f32,
109 #[expect(clippy::cast_possible_truncation)]
110 probe_score: score as f32,
111 created_at,
112 }
113 },
114 )
115 .collect())
116 }
117
118 pub async fn trim_compression_training_data(
124 &self,
125 keep_recent: usize,
126 ) -> Result<(), MemoryError> {
127 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
128 zeph_db::query(sql!(
129 "DELETE FROM compression_predictor_training \
130 WHERE id NOT IN ( \
131 SELECT id FROM compression_predictor_training \
132 ORDER BY created_at DESC \
133 LIMIT ? \
134 )"
135 ))
136 .bind(keep)
137 .execute(&self.pool)
138 .await?;
139 Ok(())
140 }
141
142 pub async fn save_compression_predictor_weights(
148 &self,
149 weights_json: &str,
150 ) -> Result<(), MemoryError> {
151 zeph_db::query(sql!(
152 "INSERT OR REPLACE INTO compression_predictor_weights (id, weights_json, updated_at) \
153 VALUES (1, ?, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))"
154 ))
155 .bind(weights_json)
156 .execute(&self.pool)
157 .await?;
158 Ok(())
159 }
160
161 pub async fn load_compression_predictor_weights(&self) -> Result<Option<String>, MemoryError> {
169 let row: Option<(String,)> = zeph_db::query_as(sql!(
170 "SELECT weights_json FROM compression_predictor_weights WHERE id = 1"
171 ))
172 .fetch_optional(&self.pool)
173 .await?;
174 Ok(row.map(|(json,)| json))
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory store with one conversation to attach
    /// training examples to.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let cid = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, cid)
    }

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, cid) = make_store().await;
        store
            .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        let count = store
            .count_compression_training_records()
            .await
            .expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn batch_returns_records() {
        let (store, cid) = make_store().await;
        store
            .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        let batch = store
            .get_compression_training_batch(10)
            .await
            .expect("batch");
        assert_eq!(batch.len(), 1);
        // Compare with a tolerance: values round-trip through f64 REAL
        // columns, so avoid exact float equality.
        assert!((batch[0].compression_ratio - 0.5).abs() < 1e-4);
        assert!((batch[0].probe_score - 0.75).abs() < 1e-4);
    }

    #[tokio::test]
    async fn trim_keeps_most_recent() {
        let (store, cid) = make_store().await;
        for _ in 0..5 {
            store
                .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
                .await
                .expect("record");
        }
        store.trim_compression_training_data(2).await.expect("trim");
        let count = store
            .count_compression_training_records()
            .await
            .expect("count");
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn save_and_load_weights() {
        let (store, _) = make_store().await;
        store
            .save_compression_predictor_weights(r#"{"weights":[0.1,0.2,0.3,0.4],"bias":0.0}"#)
            .await
            .expect("save");
        let loaded = store
            .load_compression_predictor_weights()
            .await
            .expect("load");
        assert!(loaded.is_some());
        assert!(loaded.unwrap().contains("weights"));
    }

    #[tokio::test]
    async fn load_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        let loaded = store
            .load_compression_predictor_weights()
            .await
            .expect("load");
        assert!(loaded.is_none());
    }
}