1use super::SqliteStore;
5use crate::error::MemoryError;
6
/// One row of the `user_corrections` table, as decoded by the `load_*` correction queries.
///
/// Fields are declared in the same order as the columns those queries select.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserCorrectionRow {
    /// Row id; the same value `store_user_correction` returns for the inserted row.
    pub id: i64,
    /// Optional session id supplied when the correction was stored.
    pub session_id: Option<i64>,
    /// The output the correction refers to.
    pub original_output: String,
    /// The correction text itself.
    pub correction_text: String,
    /// Skill the correction is attached to, if any; this is the column
    /// `load_corrections_for_skill` filters on.
    pub skill_name: Option<String>,
    /// Free-form category string (callers in this module use values such as
    /// `explicit_rejection`, `alternative_request` and `repetition`).
    pub correction_kind: String,
    /// Creation timestamp as raw text. It is not written by `store_user_correction`, so its
    /// value and format come from the table definition (presumably a default — confirm).
    pub created_at: String,
}
17
18type CorrectionTuple = (
19 i64,
20 Option<i64>,
21 String,
22 String,
23 Option<String>,
24 String,
25 String,
26);
27
28fn row_from_tuple(t: CorrectionTuple) -> UserCorrectionRow {
29 UserCorrectionRow {
30 id: t.0,
31 session_id: t.1,
32 original_output: t.2,
33 correction_text: t.3,
34 skill_name: t.4,
35 correction_kind: t.5,
36 created_at: t.6,
37 }
38}
39
40impl SqliteStore {
41 pub async fn store_user_correction(
47 &self,
48 session_id: Option<i64>,
49 original_output: &str,
50 correction_text: &str,
51 skill_name: Option<&str>,
52 correction_kind: &str,
53 ) -> Result<i64, MemoryError> {
54 let row: (i64,) = sqlx::query_as(
55 "INSERT INTO user_corrections \
56 (session_id, original_output, correction_text, skill_name, correction_kind) \
57 VALUES (?, ?, ?, ?, ?) RETURNING id",
58 )
59 .bind(session_id)
60 .bind(original_output)
61 .bind(correction_text)
62 .bind(skill_name)
63 .bind(correction_kind)
64 .fetch_one(&self.pool)
65 .await?;
66 Ok(row.0)
67 }
68
69 pub async fn load_corrections_for_skill(
75 &self,
76 skill_name: &str,
77 limit: u32,
78 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
79 let rows: Vec<CorrectionTuple> = sqlx::query_as(
80 "SELECT id, session_id, original_output, correction_text, \
81 skill_name, correction_kind, created_at \
82 FROM user_corrections WHERE skill_name = ? \
83 ORDER BY id DESC LIMIT ?",
84 )
85 .bind(skill_name)
86 .bind(limit)
87 .fetch_all(&self.pool)
88 .await?;
89 Ok(rows.into_iter().map(row_from_tuple).collect())
90 }
91
92 pub async fn load_recent_corrections(
98 &self,
99 limit: u32,
100 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
101 let rows: Vec<CorrectionTuple> = sqlx::query_as(
102 "SELECT id, session_id, original_output, correction_text, \
103 skill_name, correction_kind, created_at \
104 FROM user_corrections ORDER BY id DESC LIMIT ?",
105 )
106 .bind(limit)
107 .fetch_all(&self.pool)
108 .await?;
109 Ok(rows.into_iter().map(row_from_tuple).collect())
110 }
111
112 pub async fn load_corrections_for_id(
118 &self,
119 id: i64,
120 ) -> Result<Vec<UserCorrectionRow>, MemoryError> {
121 let rows: Vec<CorrectionTuple> = sqlx::query_as(
122 "SELECT id, session_id, original_output, correction_text, \
123 skill_name, correction_kind, created_at \
124 FROM user_corrections WHERE id = ?",
125 )
126 .bind(id)
127 .fetch_all(&self.pool)
128 .await?;
129 Ok(rows.into_iter().map(row_from_tuple).collect())
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SqliteStore` on the `":memory:"` path for a single test.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn store_and_load_correction() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                Some(1),
                "original assistant output",
                "that was wrong, try again",
                Some("git"),
                "explicit_rejection",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        // Every stored column must round-trip unchanged.
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].session_id, Some(1));
        assert_eq!(rows[0].original_output, "original assistant output");
        assert_eq!(rows[0].correction_text, "that was wrong, try again");
        assert_eq!(rows[0].correction_kind, "explicit_rejection");
        assert_eq!(rows[0].skill_name.as_deref(), Some("git"));
    }

    #[tokio::test]
    async fn load_recent_corrections_ordered() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out1", "fix1", None, "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out2", "fix2", None, "alternative_request")
            .await
            .unwrap();

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].correction_text, "fix2");
        assert_eq!(rows[0].correction_kind, "alternative_request");
        assert_eq!(rows[1].correction_text, "fix1");
        assert_eq!(rows[1].correction_kind, "explicit_rejection");
    }

    #[tokio::test]
    async fn load_recent_corrections_respects_limit() {
        let store = test_store().await;

        for i in 0..3 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    None,
                    "repetition",
                )
                .await
                .unwrap();
        }

        // Only the two most recently inserted rows, newest first.
        let rows = store.load_recent_corrections(2).await.unwrap();
        let texts: Vec<&str> = rows.iter().map(|r| r.correction_text.as_str()).collect();
        assert_eq!(texts, ["fix2", "fix1"]);
    }

    #[tokio::test]
    async fn load_corrections_for_id_returns_single() {
        let store = test_store().await;

        let id = store
            .store_user_correction(None, "out", "fix", Some("docker"), "repetition")
            .await
            .unwrap();

        let rows = store.load_corrections_for_id(id).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, id);
        assert_eq!(rows[0].skill_name.as_deref(), Some("docker"));
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_id_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store.load_corrections_for_id(9999).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_unknown_returns_empty() {
        let store = test_store().await;
        let rows = store
            .load_corrections_for_skill("nonexistent", 10)
            .await
            .unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn load_corrections_for_skill_filters_other_skills() {
        let store = test_store().await;

        store
            .store_user_correction(None, "out-git", "fix-git", Some("git"), "explicit_rejection")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-docker", "fix-docker", Some("docker"), "repetition")
            .await
            .unwrap();
        store
            .store_user_correction(None, "out-none", "fix-none", None, "repetition")
            .await
            .unwrap();

        // Rows for other skills and rows without a skill must be excluded.
        let rows = store.load_corrections_for_skill("git", 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].correction_text, "fix-git");
    }

    #[tokio::test]
    async fn load_recent_corrections_empty_table() {
        let store = test_store().await;
        let rows = store.load_recent_corrections(10).await.unwrap();
        assert!(rows.is_empty());
    }

    #[tokio::test]
    async fn store_correction_without_skill_name() {
        let store = test_store().await;

        let id = store
            .store_user_correction(
                None,
                "original output",
                "correction text",
                None,
                "repetition",
            )
            .await
            .unwrap();
        assert!(id > 0);

        let rows = store.load_recent_corrections(10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, id);
        assert!(rows[0].session_id.is_none());
        assert!(rows[0].skill_name.is_none());
        assert_eq!(rows[0].correction_kind, "repetition");
    }

    #[tokio::test]
    async fn load_corrections_for_skill_respects_limit() {
        let store = test_store().await;

        for i in 0..5 {
            store
                .store_user_correction(
                    None,
                    &format!("out{i}"),
                    &format!("fix{i}"),
                    Some("git"),
                    "explicit_rejection",
                )
                .await
                .unwrap();
        }

        // The limit must keep the newest rows (descending id), not arbitrary ones.
        let rows = store.load_corrections_for_skill("git", 3).await.unwrap();
        let texts: Vec<&str> = rows.iter().map(|r| r.correction_text.as_str()).collect();
        assert_eq!(texts, ["fix4", "fix3", "fix2"]);
    }
}