Skip to main content

zeph_memory/sqlite/
corrections.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::SqliteStore;
5use crate::error::MemoryError;
6
/// One row of the `user_corrections` table, as returned by the
/// `load_*corrections*` methods on `SqliteStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserCorrectionRow {
    /// Row id (the value produced by `RETURNING id` on insert).
    pub id: i64,
    /// Session the correction belongs to, if one was supplied.
    pub session_id: Option<i64>,
    /// The output that the user corrected.
    pub original_output: String,
    /// The user's correction text.
    pub correction_text: String,
    /// Skill the correction is attributed to, if any.
    pub skill_name: Option<String>,
    /// Free-form kind label (e.g. `"explicit_rejection"`, `"repetition"`).
    pub correction_kind: String,
    /// Creation timestamp as stored text; its format is set by the table
    /// schema and is not parsed here.
    pub created_at: String,
}

/// Raw column tuple decoded from the `user_corrections` SELECT queries in this
/// module; positions follow the SELECT column order.
type CorrectionTuple = (
    i64,            // id
    Option<i64>,    // session_id
    String,         // original_output
    String,         // correction_text
    Option<String>, // skill_name
    String,         // correction_kind
    String,         // created_at
);

28fn row_from_tuple(t: CorrectionTuple) -> UserCorrectionRow {
29    UserCorrectionRow {
30        id: t.0,
31        session_id: t.1,
32        original_output: t.2,
33        correction_text: t.3,
34        skill_name: t.4,
35        correction_kind: t.5,
36        created_at: t.6,
37    }
38}
39
40impl SqliteStore {
41    /// Store a user correction and return the new row ID.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the insert fails.
46    pub async fn store_user_correction(
47        &self,
48        session_id: Option<i64>,
49        original_output: &str,
50        correction_text: &str,
51        skill_name: Option<&str>,
52        correction_kind: &str,
53    ) -> Result<i64, MemoryError> {
54        let row: (i64,) = sqlx::query_as(
55            "INSERT INTO user_corrections \
56             (session_id, original_output, correction_text, skill_name, correction_kind) \
57             VALUES (?, ?, ?, ?, ?) RETURNING id",
58        )
59        .bind(session_id)
60        .bind(original_output)
61        .bind(correction_text)
62        .bind(skill_name)
63        .bind(correction_kind)
64        .fetch_one(&self.pool)
65        .await?;
66        Ok(row.0)
67    }
68
69    /// Load corrections for a specific skill, newest first.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if the query fails.
74    pub async fn load_corrections_for_skill(
75        &self,
76        skill_name: &str,
77        limit: u32,
78    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
79        let rows: Vec<CorrectionTuple> = sqlx::query_as(
80            "SELECT id, session_id, original_output, correction_text, \
81             skill_name, correction_kind, created_at \
82             FROM user_corrections WHERE skill_name = ? \
83             ORDER BY id DESC LIMIT ?",
84        )
85        .bind(skill_name)
86        .bind(limit)
87        .fetch_all(&self.pool)
88        .await?;
89        Ok(rows.into_iter().map(row_from_tuple).collect())
90    }
91
92    /// Load the most recent corrections across all skills.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the query fails.
97    pub async fn load_recent_corrections(
98        &self,
99        limit: u32,
100    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
101        let rows: Vec<CorrectionTuple> = sqlx::query_as(
102            "SELECT id, session_id, original_output, correction_text, \
103             skill_name, correction_kind, created_at \
104             FROM user_corrections ORDER BY id DESC LIMIT ?",
105        )
106        .bind(limit)
107        .fetch_all(&self.pool)
108        .await?;
109        Ok(rows.into_iter().map(row_from_tuple).collect())
110    }
111
112    /// Load a correction by ID (used by vector retrieval path).
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if the query fails.
117    pub async fn load_corrections_for_id(
118        &self,
119        id: i64,
120    ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
121        let rows: Vec<CorrectionTuple> = sqlx::query_as(
122            "SELECT id, session_id, original_output, correction_text, \
123             skill_name, correction_kind, created_at \
124             FROM user_corrections WHERE id = ?",
125        )
126        .bind(id)
127        .fetch_all(&self.pool)
128        .await?;
129        Ok(rows.into_iter().map(row_from_tuple).collect())
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store backed by an in-memory SQLite database.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn store_and_load_correction() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                Some(1),
                "original assistant output",
                "that was wrong, try again",
                Some("git"),
                "explicit_rejection",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        // Every inserted column must round-trip unchanged.
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].session_id, Some(1));
        assert_eq!(rows[0].original_output, "original assistant output");
        assert_eq!(rows[0].correction_text, "that was wrong, try again");
        assert_eq!(rows[0].correction_kind, "explicit_rejection");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_ordered() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out1", "fix1", None, "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out2", "fix2", None, "alternative_request")
            .await
            .unwrap();

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].correction_text, "fix2");
        assert_eq!(rows[1].correction_text, "fix1");
    }

    #[tokio::test]
    async fn load_recent_corrections_respects_limit() {
        let store = test_store().await;

        for i in 0..3 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    None,
                    "repetition",
                )
                .await
                .unwrap();
        }

        let rows = store.load_recent_corrections(2).await.unwrap();
        let texts: Vec<&str> = rows.iter().map(|r| r.correction_text.as_str()).collect();
        // Only the two newest rows, newest first.
        assert_eq!(texts, ["fix2", "fix1"]);
    }

    #[tokio::test]
    async fn load_corrections_for_id_returns_single() {
        let store = test_store().await;

        let id = store
            .store_user_correction(None, "out", "fix", Some("docker"), "repetition")
            .await
            .unwrap();

        let rows = store.load_corrections_for_id(id).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, id);
    }

    #[tokio::test]
    async fn load_corrections_for_id_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store.load_corrections_for_id(9999).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store
            .load_corrections_for_skill("nonexistent", 10)
            .await
            .unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_filters_other_skills() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out-git", "fix-git", Some("git"), "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-docker", "fix-docker", Some("docker"), "repetition")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-none", "fix-none", None, "repetition")
            .await
            .unwrap();

        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].correction_text, "fix-git");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_empty_table() {
        let store = test_store().await;
        let rows = store.load_recent_corrections(10).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn store_correction_without_skill_name() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                None,
                "original output",
                "correction text",
                None,
                "repetition",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert!(rows[0].skill_name.is_none());
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_skill_respects_limit() {
        let store = test_store().await;

        for i in 0..5 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    Some("git"),
                    "explicit_rejection",
                )
                .await
                .unwrap();
        }

        let rows = store.load_corrections_for_skill("git", 3).await.unwrap();
        assert_eq!(rows.len(), 3);
        // The limit keeps the newest rows, returned newest first.
        let texts: Vec<&str> = rows.iter().map(|r| r.correction_text.as_str()).collect();
        assert_eq!(texts, ["fix4", "fix3", "fix2"]);
    }
}