// zeph_memory/store/compression_predictor.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SQLite-backed store for the compression quality predictor (#2460).
5//!
6//! Provides persistence for training samples and model weights following the
7//! same pattern as `admission_training.rs`.
8
9use zeph_db::sql;
10
11use crate::error::MemoryError;
12use crate::store::SqliteStore;
13use crate::types::ConversationId;
14
/// A single training record for the compression quality predictor.
///
/// Mirrors one row of the `compression_predictor_training` table; REAL
/// columns are narrowed from `f64` to `f32` when loaded.
#[derive(Debug, Clone)]
pub struct CompressionTrainingRecord {
    // Primary key assigned by SQLite (RETURNING id on insert).
    pub id: i64,
    // Conversation the probe was run against.
    pub conversation_id: ConversationId,
    // Compression ratio observed by the probe — presumably
    // compressed/original size; confirm against the probe implementation.
    pub compression_ratio: f32,
    // Number of messages in the conversation at probe time.
    pub message_count: i64,
    // Mean message length feature — units (chars vs. tokens) not visible
    // here; TODO confirm at the call site.
    pub avg_message_length: f32,
    // Fraction of content attributed to tool output, expected in [0, 1].
    pub tool_output_fraction: f32,
    // Quality score produced by the compression probe (training target).
    pub probe_score: f32,
    // Timestamp string; inserts never set it, so presumably a DB-side
    // column default — verify against the migration.
    pub created_at: String,
}
27
28impl SqliteStore {
29    /// Record a compression probe result for predictor training.
30    ///
31    /// # Errors
32    ///
33    /// Returns an error if the database insert fails.
34    pub async fn record_compression_training(
35        &self,
36        conversation_id: ConversationId,
37        compression_ratio: f32,
38        message_count: i64,
39        avg_message_length: f32,
40        tool_output_fraction: f32,
41        probe_score: f32,
42    ) -> Result<i64, MemoryError> {
43        let id = zeph_db::query_scalar(sql!(
44            "INSERT INTO compression_predictor_training \
45             (conversation_id, compression_ratio, message_count, \
46              avg_message_length, tool_output_fraction, probe_score) \
47             VALUES (?, ?, ?, ?, ?, ?) \
48             RETURNING id"
49        ))
50        .bind(conversation_id.0)
51        .bind(f64::from(compression_ratio))
52        .bind(message_count)
53        .bind(f64::from(avg_message_length))
54        .bind(f64::from(tool_output_fraction))
55        .bind(f64::from(probe_score))
56        .fetch_one(&self.pool)
57        .await?;
58        Ok(id)
59    }
60
61    /// Count total compression training records.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the query fails.
66    pub async fn count_compression_training_records(&self) -> Result<i64, MemoryError> {
67        let count =
68            zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM compression_predictor_training"))
69                .fetch_one(&self.pool)
70                .await?;
71        Ok(count)
72    }
73
74    /// Get the most recent `limit` training records for model training (sliding window).
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the query fails.
79    pub async fn get_compression_training_batch(
80        &self,
81        limit: usize,
82    ) -> Result<Vec<CompressionTrainingRecord>, MemoryError> {
83        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
84        let rows = zeph_db::query_as::<_, (i64, i64, f64, i64, f64, f64, f64, String)>(sql!(
85            "SELECT id, conversation_id, compression_ratio, message_count, \
86                    avg_message_length, tool_output_fraction, probe_score, created_at \
87             FROM compression_predictor_training \
88             ORDER BY created_at DESC \
89             LIMIT ?"
90        ))
91        .bind(limit)
92        .fetch_all(&self.pool)
93        .await?;
94
95        Ok(rows
96            .into_iter()
97            .map(
98                |(id, cid, ratio, msg_count, avg_len, tool_frac, score, created_at)| {
99                    CompressionTrainingRecord {
100                        id,
101                        conversation_id: ConversationId(cid),
102                        #[expect(clippy::cast_possible_truncation)]
103                        compression_ratio: ratio as f32,
104                        message_count: msg_count,
105                        #[expect(clippy::cast_possible_truncation)]
106                        avg_message_length: avg_len as f32,
107                        #[expect(clippy::cast_possible_truncation)]
108                        tool_output_fraction: tool_frac as f32,
109                        #[expect(clippy::cast_possible_truncation)]
110                        probe_score: score as f32,
111                        created_at,
112                    }
113                },
114            )
115            .collect())
116    }
117
118    /// Trim compression training records, keeping the most recent `keep_recent`.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the delete fails.
123    pub async fn trim_compression_training_data(
124        &self,
125        keep_recent: usize,
126    ) -> Result<(), MemoryError> {
127        let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
128        zeph_db::query(sql!(
129            "DELETE FROM compression_predictor_training \
130             WHERE id NOT IN ( \
131                 SELECT id FROM compression_predictor_training \
132                 ORDER BY created_at DESC \
133                 LIMIT ? \
134             )"
135        ))
136        .bind(keep)
137        .execute(&self.pool)
138        .await?;
139        Ok(())
140    }
141
142    /// Save compression predictor weights (singleton row, id = 1).
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if the upsert fails.
147    pub async fn save_compression_predictor_weights(
148        &self,
149        weights_json: &str,
150    ) -> Result<(), MemoryError> {
151        zeph_db::query(sql!(
152            "INSERT OR REPLACE INTO compression_predictor_weights (id, weights_json, updated_at) \
153             VALUES (1, ?, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))"
154        ))
155        .bind(weights_json)
156        .execute(&self.pool)
157        .await?;
158        Ok(())
159    }
160
161    /// Load compression predictor weights.
162    ///
163    /// Returns `None` if no weights have been saved yet.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if the query fails.
168    pub async fn load_compression_predictor_weights(&self) -> Result<Option<String>, MemoryError> {
169        let row: Option<(String,)> = zeph_db::query_as(sql!(
170            "SELECT weights_json FROM compression_predictor_weights WHERE id = 1"
171        ))
172        .fetch_optional(&self.pool)
173        .await?;
174        Ok(row.map(|(json,)| json))
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Spin up a fresh in-memory store plus one conversation to act as the
    /// foreign-key target for training rows.
    async fn make_store() -> (SqliteStore, ConversationId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let cid = store
            .create_conversation()
            .await
            .expect("create_conversation");
        (store, cid)
    }

    // Every test explicitly awaits `store.pool().close()` before returning.
    // Dropping the pool only *signals* the sqlx-sqlite background connection
    // threads to stop; without the explicit close they can still be alive
    // while a concurrently running plain `#[test]` (e.g.
    // `compression_predictor::tests::*`) executes, and nextest would then
    // attribute the lingering threads to that innocent test as a LEAK.

    #[tokio::test]
    async fn record_and_count_training_data() {
        let (store, cid) = make_store().await;
        store
            .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        assert_eq!(
            store
                .count_compression_training_records()
                .await
                .expect("count"),
            1
        );
        store.pool().close().await;
    }

    #[tokio::test]
    async fn batch_returns_records() {
        let (store, cid) = make_store().await;
        store
            .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
            .await
            .expect("record");
        let batch = store
            .get_compression_training_batch(10)
            .await
            .expect("batch");
        assert_eq!(batch.len(), 1);
        let rec = &batch[0];
        assert!((rec.compression_ratio - 0.5).abs() < 1e-4);
        assert!((rec.probe_score - 0.75).abs() < 1e-4);
        store.pool().close().await;
    }

    #[tokio::test]
    async fn trim_keeps_most_recent() {
        let (store, cid) = make_store().await;
        // Insert more rows than we intend to keep.
        for _ in 0..5 {
            store
                .record_compression_training(cid, 0.5, 20, 150.0, 0.3, 0.75)
                .await
                .expect("record");
        }
        store.trim_compression_training_data(2).await.expect("trim");
        assert_eq!(
            store
                .count_compression_training_records()
                .await
                .expect("count"),
            2
        );
        store.pool().close().await;
    }

    #[tokio::test]
    async fn save_and_load_weights() {
        let (store, _) = make_store().await;
        let payload = r#"{"weights":[0.1,0.2,0.3,0.4],"bias":0.0}"#;
        store
            .save_compression_predictor_weights(payload)
            .await
            .expect("save");
        let loaded = store
            .load_compression_predictor_weights()
            .await
            .expect("load");
        assert!(loaded.is_some());
        assert!(loaded.unwrap().contains("weights"));
        store.pool().close().await;
    }

    #[tokio::test]
    async fn load_weights_returns_none_when_empty() {
        let (store, _) = make_store().await;
        let loaded = store
            .load_compression_predictor_weights()
            .await
            .expect("load");
        assert!(loaded.is_none());
        store.pool().close().await;
    }
}