1use super::SqliteStore;
5use crate::error::MemoryError;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
/// One row of the `user_corrections` table, as returned by the correction
/// loaders on `SqliteStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserCorrectionRow {
    /// Row id; the same value `store_user_correction` returns via `RETURNING id`.
    pub id: i64,
    /// Session the correction was recorded in, or `None` when not tied to one.
    pub session_id: Option<i64>,
    /// The output the correction refers to.
    pub original_output: String,
    /// The correction text supplied by the user.
    pub correction_text: String,
    /// Skill the correction is attributed to, or `None` when not tied to a skill.
    pub skill_name: Option<String>,
    /// Free-form kind label (the tests use e.g. `explicit_rejection`,
    /// `alternative_request`, `repetition`).
    pub correction_kind: String,
    /// Raw `created_at` column value, read back as text. It is not written by
    /// `store_user_correction`; presumably a schema default — confirm.
    pub created_at: String,
}
19
/// Positional shape of the loaders' `SELECT` column list:
/// `(id, session_id, original_output, correction_text, skill_name, correction_kind, created_at)`.
type CorrectionTuple = (i64, Option<i64>, String, String, Option<String>, String, String);
29
30fn row_from_tuple(t: CorrectionTuple) -> UserCorrectionRow {
31 UserCorrectionRow {
32 id: t.0,
33 session_id: t.1,
34 original_output: t.2,
35 correction_text: t.3,
36 skill_name: t.4,
37 correction_kind: t.5,
38 created_at: t.6,
39 }
40}
41
42impl SqliteStore {
43 pub async fn store_user_correction(
49 &self,
50 session_id: Option<i64>,
51 original_output: &str,
52 correction_text: &str,
53 skill_name: Option<&str>,
54 correction_kind: &str,
55 ) -> Result<i64, MemoryError> {
56 let row: (i64,) = zeph_db::query_as(sql!(
57 "INSERT INTO user_corrections \
58 (session_id, original_output, correction_text, skill_name, correction_kind) \
59 VALUES (?, ?, ?, ?, ?) RETURNING id"
60 ))
61 .bind(session_id)
62 .bind(original_output)
63 .bind(correction_text)
64 .bind(skill_name)
65 .bind(correction_kind)
66 .fetch_one(&self.pool)
67 .await?;
68 Ok(row.0)
69 }
70
71 pub async fn load_corrections_for_skill(
77 &self,
78 skill_name: &str,
79 limit: u32,
80 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
81 let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
82 "SELECT id, session_id, original_output, correction_text, \
83 skill_name, correction_kind, created_at \
84 FROM user_corrections WHERE skill_name = ? \
85 ORDER BY id DESC LIMIT ?"
86 ))
87 .bind(skill_name)
88 .bind(limit)
89 .fetch_all(&self.pool)
90 .await?;
91 Ok(rows.into_iter().map(row_from_tuple).collect())
92 }
93
94 pub async fn load_recent_corrections(
100 &self,
101 limit: u32,
102 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
103 let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
104 "SELECT id, session_id, original_output, correction_text, \
105 skill_name, correction_kind, created_at \
106 FROM user_corrections ORDER BY id DESC LIMIT ?"
107 ))
108 .bind(limit)
109 .fetch_all(&self.pool)
110 .await?;
111 Ok(rows.into_iter().map(row_from_tuple).collect())
112 }
113
114 pub async fn load_corrections_for_id(
120 &self,
121 id: i64,
122 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
123 let rows: Vec<CorrectionTuple> = zeph_db::query_as(sql!(
124 "SELECT id, session_id, original_output, correction_text, \
125 skill_name, correction_kind, created_at \
126 FROM user_corrections WHERE id = ?"
127 ))
128 .bind(id)
129 .fetch_all(&self.pool)
130 .await?;
131 Ok(rows.into_iter().map(row_from_tuple).collect())
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store backed by SQLite's `:memory:` database.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Collects the `correction_text` of each row, in returned order.
    fn texts(rows: &[UserCorrectionRow]) -> Vec<&str> {
        rows.iter().map(|r| r.correction_text.as_str()).collect()
    }

    #[tokio::test]
    async fn store_and_load_correction() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                Some(1),
                "original assistant output",
                "that was wrong, try again",
                Some("git"),
                "explicit_rejection",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        // Every bound column must round-trip unchanged.
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].session_id, Some(1));
        assert_eq!(rows[0].original_output, "original assistant output");
        assert_eq!(rows[0].correction_text, "that was wrong, try again");
        assert_eq!(rows[0].correction_kind, "explicit_rejection");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_ordered() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out1", "fix1", None, "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out2", "fix2", None, "alternative_request")
            .await
            .unwrap();

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(texts(&rows), ["fix2", "fix1"]);
    }

    #[tokio::test]
    async fn load_recent_corrections_respects_limit() {
        let store = test_store().await;

        for i in 0..3 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    None,
                    "repetition",
                )
                .await
                .unwrap();
        }

        // The limit must keep the newest rows, not an arbitrary subset.
        let rows = store.load_recent_corrections(2).await.unwrap();
        assert_eq!(texts(&rows), ["fix2", "fix1"]);
    }

    #[tokio::test]
    async fn load_corrections_for_id_returns_single() {
        let store = test_store().await;

        let id = store
            .store_user_correction(None, "out", "fix", Some("docker"), "repetition")
            .await
            .unwrap();

        let rows = store.load_corrections_for_id(id).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].skill_name.as_deref(), Some("docker"));
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_id_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store.load_corrections_for_id(9999).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store
            .load_corrections_for_skill("nonexistent", 10)
            .await
            .unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_ignores_other_skills() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out-git", "fix-git", Some("git"), "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-docker", "fix-docker", Some("docker"), "repetition")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-none", "fix-none", None, "repetition")
            .await
            .unwrap();

        // Rows for other skills and rows without a skill must not leak in.
        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(texts(&rows), ["fix-git"]);
    }

    #[tokio::test]
    async fn load_recent_corrections_empty_table() {
        let store = test_store().await;
        let rows = store.load_recent_corrections(10).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn store_correction_without_skill_name() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                None,
                "original output",
                "correction text",
                None,
                "repetition",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert!(rows[0].skill_name.is_none());
        assert!(rows[0].session_id.is_none());
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_skill_respects_limit() {
        let store = test_store().await;

        for i in 0..5 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    Some("git"),
                    "explicit_rejection",
                )
                .await
                .unwrap();
        }

        // `ORDER BY id DESC LIMIT 3` must return the three newest rows, newest first.
        let rows = store.load_corrections_for_skill("git", 3).await.unwrap();
        assert_eq!(texts(&rows), ["fix4", "fix3", "fix2"]);
    }
}