1use zeph_common::text::truncate_to_bytes_ref;
5
6use super::SqliteStore;
7use crate::error::MemoryError;
8
/// One row of the `learned_preferences` table.
///
/// Derives `PartialEq` so rows can be compared directly (for example with
/// `assert_eq!`). `Eq` is intentionally not derived because `confidence` is
/// an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct LearnedPreferenceRow {
    /// Value of the `id` column.
    pub id: i64,
    /// Preference name; the upsert statement treats this column as the
    /// conflict target.
    pub preference_key: String,
    /// Value currently stored for the preference.
    pub preference_value: String,
    /// Confidence score stored with the preference.
    pub confidence: f64,
    /// Evidence count stored with the preference.
    pub evidence_count: i64,
    /// Timestamp text written by SQLite's `datetime('now')` on insert/update.
    pub updated_at: String,
}
18
19type PreferenceTuple = (i64, String, String, f64, i64, String);
20
21fn row_from_tuple(t: PreferenceTuple) -> LearnedPreferenceRow {
22 LearnedPreferenceRow {
23 id: t.0,
24 preference_key: t.1,
25 preference_value: t.2,
26 confidence: t.3,
27 evidence_count: t.4,
28 updated_at: t.5,
29 }
30}
31
32impl SqliteStore {
33 pub async fn upsert_learned_preference(
46 &self,
47 key: &str,
48 value: &str,
49 confidence: f64,
50 evidence_count: i64,
51 ) -> Result<(), MemoryError> {
52 const MAX_KEY_BYTES: usize = 128;
53 const MAX_VALUE_BYTES: usize = 256;
54 let key_trunc = truncate_to_bytes_ref(key, MAX_KEY_BYTES);
55 let value_trunc = truncate_to_bytes_ref(value, MAX_VALUE_BYTES);
56 if key_trunc.len() < key.len() {
57 tracing::warn!(
58 original_len = key.len(),
59 "learned_preferences: key truncated to 128 bytes"
60 );
61 }
62 if value_trunc.len() < value.len() {
63 tracing::warn!(
64 original_len = value.len(),
65 "learned_preferences: value truncated to 256 bytes"
66 );
67 }
68 sqlx::query(
69 "INSERT INTO learned_preferences \
70 (preference_key, preference_value, confidence, evidence_count, updated_at) \
71 VALUES (?, ?, ?, ?, datetime('now')) \
72 ON CONFLICT(preference_key) DO UPDATE SET \
73 preference_value = excluded.preference_value, \
74 confidence = excluded.confidence, \
75 evidence_count = excluded.evidence_count, \
76 updated_at = datetime('now')",
77 )
78 .bind(key_trunc)
79 .bind(value_trunc)
80 .bind(confidence)
81 .bind(evidence_count)
82 .execute(&self.pool)
83 .await?;
84 Ok(())
85 }
86
87 pub async fn load_learned_preferences(&self) -> Result<Vec<LearnedPreferenceRow>, MemoryError> {
93 let rows: Vec<PreferenceTuple> = sqlx::query_as(
94 "SELECT id, preference_key, preference_value, confidence, evidence_count, updated_at \
95 FROM learned_preferences \
96 ORDER BY confidence DESC",
97 )
98 .fetch_all(&self.pool)
99 .await?;
100 Ok(rows.into_iter().map(row_from_tuple).collect())
101 }
102
103 pub async fn load_corrections_after(
112 &self,
113 after_id: i64,
114 limit: u32,
115 ) -> Result<Vec<super::corrections::UserCorrectionRow>, MemoryError> {
116 use super::corrections::UserCorrectionRow;
117
118 type Tuple = (
119 i64,
120 Option<i64>,
121 String,
122 String,
123 Option<String>,
124 String,
125 String,
126 );
127
128 let rows: Vec<Tuple> = sqlx::query_as(
129 "SELECT id, session_id, original_output, correction_text, \
130 skill_name, correction_kind, created_at \
131 FROM user_corrections \
132 WHERE id > ? \
133 ORDER BY id ASC LIMIT ?",
134 )
135 .bind(after_id)
136 .bind(limit)
137 .fetch_all(&self.pool)
138 .await?;
139
140 Ok(rows
141 .into_iter()
142 .map(|t| UserCorrectionRow {
143 id: t.0,
144 session_id: t.1,
145 original_output: t.2,
146 correction_text: t.3,
147 skill_name: t.4,
148 correction_kind: t.5,
149 created_at: t.6,
150 })
151 .collect())
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store from the `:memory:` path for a single test.
    async fn store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// A single upsert is readable back with every supplied field intact.
    #[tokio::test]
    async fn upsert_and_load() {
        let db = store().await;
        db.upsert_learned_preference("verbosity", "concise", 0.9, 5)
            .await
            .unwrap();

        let prefs = db.load_learned_preferences().await.unwrap();
        assert_eq!(prefs.len(), 1);

        let only = &prefs[0];
        assert_eq!(only.preference_key, "verbosity");
        assert_eq!(only.preference_value, "concise");
        assert!((only.confidence - 0.9).abs() < 1e-9);
        assert_eq!(only.evidence_count, 5);
    }

    /// Upserting the same key twice keeps one row holding the latest values.
    #[tokio::test]
    async fn upsert_updates_existing() {
        let db = store().await;
        db.upsert_learned_preference("verbosity", "concise", 0.8, 3)
            .await
            .unwrap();
        db.upsert_learned_preference("verbosity", "verbose", 0.95, 8)
            .await
            .unwrap();

        let prefs = db.load_learned_preferences().await.unwrap();
        assert_eq!(prefs.len(), 1);

        let latest = &prefs[0];
        assert_eq!(latest.preference_value, "verbose");
        assert!((latest.confidence - 0.95).abs() < 1e-9);
        assert_eq!(latest.evidence_count, 8);
    }

    /// Rows come back with the higher-confidence preference first,
    /// regardless of insertion order.
    #[tokio::test]
    async fn load_ordered_by_confidence() {
        let db = store().await;
        db.upsert_learned_preference("format_preference", "bullet points", 0.75, 3)
            .await
            .unwrap();
        db.upsert_learned_preference("verbosity", "concise", 0.9, 5)
            .await
            .unwrap();

        let prefs = db.load_learned_preferences().await.unwrap();
        let keys: Vec<&str> = prefs.iter().map(|p| p.preference_key.as_str()).collect();
        assert_eq!(keys[0], "verbosity");
        assert_eq!(keys[1], "format_preference");
    }

    /// Only corrections with an id above the watermark are returned.
    #[tokio::test]
    async fn load_corrections_after_watermark() {
        let db = store().await;
        db.store_user_correction(None, "output", "be brief", None, "explicit_rejection")
            .await
            .unwrap();
        let second_id = db
            .store_user_correction(None, "output2", "use bullets", None, "alternative_request")
            .await
            .unwrap();

        // Watermark sits just below the second correction's id.
        let newer = db.load_corrections_after(second_id - 1, 10).await.unwrap();
        assert_eq!(newer.len(), 1);
        assert_eq!(newer[0].correction_text, "use bullets");
    }
}