1use super::SqliteStore;
5use crate::error::MemoryError;
6
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// When `s` already fits within `max_bytes`, it is returned unchanged. The
/// result may be shorter than `max_bytes` if the limit falls inside a
/// multi-byte character; it is empty if no boundary other than 0 fits.
fn truncate_to_bytes(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Search downward from the byte limit for the nearest char boundary so the
    // slice never splits a code point. Index 0 is always a boundary, so the
    // search cannot come up empty; `unwrap_or(0)` just makes that explicit.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
19
/// One row of the `learned_preferences` table, as produced by
/// `SqliteStore::load_learned_preferences`.
#[derive(Debug, Clone)]
pub struct LearnedPreferenceRow {
    /// Value of the `id` column.
    pub id: i64,
    /// Preference name; `upsert_learned_preference` resolves conflicts on this column.
    pub preference_key: String,
    /// Stored preference value (values written via `upsert_learned_preference`
    /// are capped at 256 bytes).
    pub preference_value: String,
    /// Confidence score as supplied by the writer.
    pub confidence: f64,
    /// Evidence count as supplied by the writer.
    pub evidence_count: i64,
    /// `updated_at` column text; `upsert_learned_preference` sets it with
    /// SQLite's `datetime('now')`.
    pub updated_at: String,
}

/// Decoded column tuple, in the same order as the preference `SELECT` list:
/// `(id, preference_key, preference_value, confidence, evidence_count, updated_at)`.
type PreferenceTuple = (i64, String, String, f64, i64, String);

/// Converts a decoded column tuple into a [`LearnedPreferenceRow`].
fn row_from_tuple(tuple: PreferenceTuple) -> LearnedPreferenceRow {
    let (id, preference_key, preference_value, confidence, evidence_count, updated_at) = tuple;
    LearnedPreferenceRow {
        id,
        preference_key,
        preference_value,
        confidence,
        evidence_count,
        updated_at,
    }
}
42
43impl SqliteStore {
44 pub async fn upsert_learned_preference(
57 &self,
58 key: &str,
59 value: &str,
60 confidence: f64,
61 evidence_count: i64,
62 ) -> Result<(), MemoryError> {
63 const MAX_KEY_BYTES: usize = 128;
64 const MAX_VALUE_BYTES: usize = 256;
65 let key_trunc = truncate_to_bytes(key, MAX_KEY_BYTES);
66 let value_trunc = truncate_to_bytes(value, MAX_VALUE_BYTES);
67 if key_trunc.len() < key.len() {
68 tracing::warn!(
69 original_len = key.len(),
70 "learned_preferences: key truncated to 128 bytes"
71 );
72 }
73 if value_trunc.len() < value.len() {
74 tracing::warn!(
75 original_len = value.len(),
76 "learned_preferences: value truncated to 256 bytes"
77 );
78 }
79 sqlx::query(
80 "INSERT INTO learned_preferences \
81 (preference_key, preference_value, confidence, evidence_count, updated_at) \
82 VALUES (?, ?, ?, ?, datetime('now')) \
83 ON CONFLICT(preference_key) DO UPDATE SET \
84 preference_value = excluded.preference_value, \
85 confidence = excluded.confidence, \
86 evidence_count = excluded.evidence_count, \
87 updated_at = datetime('now')",
88 )
89 .bind(key_trunc)
90 .bind(value_trunc)
91 .bind(confidence)
92 .bind(evidence_count)
93 .execute(&self.pool)
94 .await?;
95 Ok(())
96 }
97
98 pub async fn load_learned_preferences(&self) -> Result<Vec<LearnedPreferenceRow>, MemoryError> {
104 let rows: Vec<PreferenceTuple> = sqlx::query_as(
105 "SELECT id, preference_key, preference_value, confidence, evidence_count, updated_at \
106 FROM learned_preferences \
107 ORDER BY confidence DESC",
108 )
109 .fetch_all(&self.pool)
110 .await?;
111 Ok(rows.into_iter().map(row_from_tuple).collect())
112 }
113
114 pub async fn load_corrections_after(
123 &self,
124 after_id: i64,
125 limit: u32,
126 ) -> Result<Vec<super::corrections::UserCorrectionRow>, MemoryError> {
127 use super::corrections::UserCorrectionRow;
128
129 type Tuple = (
130 i64,
131 Option<i64>,
132 String,
133 String,
134 Option<String>,
135 String,
136 String,
137 );
138
139 let rows: Vec<Tuple> = sqlx::query_as(
140 "SELECT id, session_id, original_output, correction_text, \
141 skill_name, correction_kind, created_at \
142 FROM user_corrections \
143 WHERE id > ? \
144 ORDER BY id ASC LIMIT ?",
145 )
146 .bind(after_id)
147 .bind(limit)
148 .fetch_all(&self.pool)
149 .await?;
150
151 Ok(rows
152 .into_iter()
153 .map(|t| UserCorrectionRow {
154 id: t.0,
155 session_id: t.1,
156 original_output: t.2,
157 correction_text: t.3,
158 skill_name: t.4,
159 correction_kind: t.5,
160 created_at: t.6,
161 })
162 .collect())
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store per test so tests do not share state.
    async fn store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[test]
    fn truncate_to_bytes_keeps_char_boundaries() {
        assert_eq!(truncate_to_bytes("hello", 10), "hello");
        assert_eq!(truncate_to_bytes("hello", 3), "hel");
        // 'é' occupies bytes 1..3, so a 2-byte limit must back off to byte 1.
        assert_eq!(truncate_to_bytes("héllo", 2), "h");
        assert_eq!(truncate_to_bytes("héllo", 3), "hé");
        assert_eq!(truncate_to_bytes("é", 1), "");
    }

    #[tokio::test]
    async fn upsert_and_load() {
        let s = store().await;
        s.upsert_learned_preference("verbosity", "concise", 0.9, 5)
            .await
            .unwrap();
        let rows = s.load_learned_preferences().await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].preference_key, "verbosity");
        assert_eq!(rows[0].preference_value, "concise");
        assert!((rows[0].confidence - 0.9).abs() < 1e-9);
        assert_eq!(rows[0].evidence_count, 5);
    }

    #[tokio::test]
    async fn upsert_updates_existing() {
        let s = store().await;
        s.upsert_learned_preference("verbosity", "concise", 0.8, 3)
            .await
            .unwrap();
        s.upsert_learned_preference("verbosity", "verbose", 0.95, 8)
            .await
            .unwrap();
        let rows = s.load_learned_preferences().await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].preference_value, "verbose");
        assert!((rows[0].confidence - 0.95).abs() < 1e-9);
        assert_eq!(rows[0].evidence_count, 8);
    }

    #[tokio::test]
    async fn upsert_truncates_oversized_key_and_value() {
        let s = store().await;
        let long_key = "k".repeat(200);
        let long_value = "v".repeat(300);
        s.upsert_learned_preference(&long_key, &long_value, 0.5, 1)
            .await
            .unwrap();
        let rows = s.load_learned_preferences().await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].preference_key, "k".repeat(128));
        assert_eq!(rows[0].preference_value, "v".repeat(256));
    }

    #[tokio::test]
    async fn load_ordered_by_confidence() {
        let s = store().await;
        s.upsert_learned_preference("format_preference", "bullet points", 0.75, 3)
            .await
            .unwrap();
        s.upsert_learned_preference("verbosity", "concise", 0.9, 5)
            .await
            .unwrap();
        let rows = s.load_learned_preferences().await.unwrap();
        assert_eq!(rows[0].preference_key, "verbosity");
        assert_eq!(rows[1].preference_key, "format_preference");
    }

    #[tokio::test]
    async fn load_corrections_after_watermark() {
        let s = store().await;
        // Use the id returned by the first insert as the watermark rather than
        // assuming the second row's id is exactly one greater.
        let id1 = s
            .store_user_correction(None, "output", "be brief", None, "explicit_rejection")
            .await
            .unwrap();
        s.store_user_correction(None, "output2", "use bullets", None, "alternative_request")
            .await
            .unwrap();
        let rows = s.load_corrections_after(id1, 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].correction_text, "use bullets");
    }
}