Skip to main content

zeph_memory/store/
corrections.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::SqliteStore;
5use crate::error::MemoryError;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
/// A single user correction as persisted in the `user_corrections` table.
///
/// Instances are built by the `SqliteStore` loaders in this module. The field
/// order mirrors the column order of their `SELECT` statements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserCorrectionRow {
    /// Row id (the value produced by `RETURNING id` when the row was inserted).
    pub id: i64,
    /// Session the correction was recorded in, if one was supplied.
    pub session_id: Option<i64>,
    /// The output that the user was correcting.
    pub original_output: String,
    /// The text of the user's correction.
    pub correction_text: String,
    /// Skill the correction is attributed to, if any.
    pub skill_name: Option<String>,
    /// Caller-supplied classification, e.g. `"explicit_rejection"` or `"repetition"`.
    pub correction_kind: String,
    /// Creation timestamp as returned by the database. The insert in this module
    /// does not supply it, and its format is not validated here.
    pub created_at: String,
}
19
/// Raw row decoded from `user_corrections`, in `SELECT` column order:
/// `(id, session_id, original_output, correction_text, skill_name, correction_kind, created_at)`.
type CorrectionTuple = (i64, Option<i64>, String, String, Option<String>, String, String);
29
30fn row_from_tuple(t: CorrectionTuple) -> UserCorrectionRow {
31    UserCorrectionRow {
32        id: t.0,
33        session_id: t.1,
34        original_output: t.2,
35        correction_text: t.3,
36        skill_name: t.4,
37        correction_kind: t.5,
38        created_at: t.6,
39    }
40}
41
42impl SqliteStore {
43    /// Store a user correction and return the new row ID.
44    ///
45    /// # Errors
46    ///
47    /// Returns an error if the insert fails.
48    pub async fn store_user_correction(
49        &self,
50        session_id: Option<i64>,
51        original_output: &str,
52        correction_text: &str,
53        skill_name: Option<&str>,
54        correction_kind: &str,
55    ) -> Result<i64, MemoryError> {
56        let row: (i64,) = zeph_db::query_as(sql!(
57            "INSERT INTO user_corrections \
58             (session_id, original_output, correction_text, skill_name, correction_kind) \
59             VALUES (?, ?, ?, ?, ?) RETURNING id"
60        ))
61        .bind(session_id)
62        .bind(original_output)
63        .bind(correction_text)
64        .bind(skill_name)
65        .bind(correction_kind)
66        .fetch_one(&self.pool)
67        .await?;
68        Ok(row.0)
69    }
70
71    /// Load corrections for a specific skill, newest first.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the query fails.
76    pub async fn load_corrections_for_skill(
77        &self,
78        skill_name: &str,
79        limit: u32,
80    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
81        let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
82            "SELECT id, session_id, original_output, correction_text, \
83             skill_name, correction_kind, created_at \
84             FROM user_corrections WHERE skill_name = ? \
85             ORDER BY id DESC LIMIT ?"
86        ))
87        .bind(skill_name)
88        .bind(limit)
89        .fetch_all(&self.pool)
90        .await?;
91        Ok(rows.into_iter().map(row_from_tuple).collect())
92    }
93
94    /// Load the most recent corrections across all skills.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the query fails.
99    pub async fn load_recent_corrections(
100        &self,
101        limit: u32,
102    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
103        let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
104            "SELECT id, session_id, original_output, correction_text, \
105             skill_name, correction_kind, created_at \
106             FROM user_corrections ORDER BY id DESC LIMIT ?"
107        ))
108        .bind(limit)
109        .fetch_all(&self.pool)
110        .await?;
111        Ok(rows.into_iter().map(row_from_tuple).collect())
112    }
113
114    /// Load a correction by ID (used by vector retrieval path).
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the query fails.
119    pub async fn load_corrections_for_id(
120        &self,
121        id: i64,
122    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
123        let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
124            "SELECT id, session_id, original_output, correction_text, \
125             skill_name, correction_kind, created_at \
126             FROM user_corrections WHERE id = ?"
127        ))
128        .bind(id)
129        .fetch_all(&self.pool)
130        .await?;
131        Ok(rows.into_iter().map(row_from_tuple).collect())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, isolated in-memory store for each test.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn store_and_load_correction() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                Some(1),
                "original assistant output",
                "that was wrong, try again",
                Some("git"),
                "explicit_rejection",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        // Every inserted column must round-trip unchanged.
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].session_id, Some(1));
        assert_eq!(rows[0].original_output, "original assistant output");
        assert_eq!(rows[0].correction_text, "that was wrong, try again");
        assert_eq!(rows[0].correction_kind, "explicit_rejection");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_ordered() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out1", "fix1", None, "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out2", "fix2", None, "alternative_request")
            .await
            .unwrap();

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].correction_text, "fix2");
        assert_eq!(rows[1].correction_text, "fix1");
    }

    #[tokio::test]
    async fn load_corrections_for_id_returns_single() {
        let store = test_store().await;

        let id = store
            .store_user_correction(None, "out", "fix", Some("docker"), "repetition")
            .await
            .unwrap();

        let rows = store.load_corrections_for_id(id).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].session_id, None);
        assert_eq!(rows[0].original_output, "out");
        assert_eq!(rows[0].correction_text, "fix");
        assert_eq!(rows[0].skill_name.as_deref(), Some("docker"));
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_id_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store.load_corrections_for_id(9999).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store
            .load_corrections_for_skill("nonexistent", 10)
            .await
            .unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_filters_other_skills() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out-git", "fix-git", Some("git"), "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-docker", "fix-docker", Some("docker"), "repetition")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-none", "fix-none", None, "repetition")
            .await
            .unwrap();

        // Only the row attributed to "git" may come back.
        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].correction_text, "fix-git");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_empty_table() {
        let store = test_store().await;
        let rows = store.load_recent_corrections(10).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn store_correction_without_skill_name() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                None,
                "original output",
                "correction text",
                None,
                "repetition",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert!(rows[0].skill_name.is_none());
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_skill_respects_limit() {
        let store = test_store().await;

        for i in 0..5 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    Some("git"),
                    "explicit_rejection",
                )
                .await
                .unwrap();
        }

        // The limit must keep the *newest* rows, in descending insertion order.
        let rows = store.load_corrections_for_skill("git", 3).await.unwrap();
        let texts: Vec<&str> = rows.iter().map(|r| r.correction_text.as_str()).collect();
        assert_eq!(texts, ["fix4", "fix3", "fix2"]);
    }
}