1use std::path::Path;
10
11use rusqlite::{Connection, params};
12use serde::{Deserialize, Serialize};
13
14use crate::{Error, Result};
15
/// DDL for schema version 1.
///
/// Creates the `observations` table, plus an index on `template_hash` (used by
/// `FeedbackStore::history`) and an index on `plan_fingerprint`. `recorded_at`
/// defaults to SQLite's `datetime('now')` (UTC text) when an insert omits it.
/// Every statement uses `IF NOT EXISTS`, so the batch is safe to re-run each
/// time a store is opened.
const SCHEMA_V1: &str = r#"
CREATE TABLE IF NOT EXISTS observations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    template_hash TEXT NOT NULL,
    plan_fingerprint TEXT NOT NULL,
    est_rows INTEGER NOT NULL,
    actual_rows INTEGER NOT NULL,
    latency_ms REAL,
    recorded_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_obs_template ON observations(template_hash);
CREATE INDEX IF NOT EXISTS idx_obs_plan ON observations(plan_fingerprint);
"#;
29
/// Schema version this build writes into SQLite's `PRAGMA user_version`.
///
/// On open, stores reporting a higher version are refused. Stores reporting a
/// lower version (including a fresh database, which reports 0) are re-stamped
/// with this value.
const SCHEMA_USER_VERSION: i32 = 1;
39
/// One feedback record: an estimated row count paired with the observed row
/// count for a given query template and plan, plus an optional latency.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Observation {
    /// Identifier of the query template; `FeedbackStore::history` selects
    /// observations by this value.
    pub template_hash: String,
    /// Identifier of the executed plan (indexed in the store; its format is
    /// not defined in this module).
    pub plan_fingerprint: String,
    /// Estimated row count; compared with `actual_rows` by `q_error`.
    pub est_rows: u64,
    /// Observed row count.
    pub actual_rows: u64,
    /// Optional latency, stored in a nullable `REAL` column. The `_ms` suffix
    /// suggests milliseconds (NOTE(review): confirm with callers).
    pub latency_ms: Option<f64>,
}
49
50impl Observation {
51 pub fn q_error(&self) -> f64 {
69 if self.est_rows == 0 || self.actual_rows == 0 {
70 return f64::INFINITY;
71 }
72 let r = self.actual_rows as f64 / self.est_rows as f64;
73 if r >= 1.0 { r } else { 1.0 / r }
74 }
75}
76
/// SQLite-backed store of [`Observation`]s. All reads and writes go through
/// the single connection held here.
pub struct FeedbackStore {
    /// Connection on which the schema version has been checked and the
    /// `observations` schema applied by the constructors.
    conn: Connection,
}
81
82impl FeedbackStore {
83 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
85 let path_ref = path.as_ref();
86 let conn = Connection::open(path_ref).map_err(map_sqlite)?;
87 conn.execute_batch(SCHEMA_V1).map_err(map_sqlite)?;
88 check_or_stamp_schema_version(&conn)?;
89 #[cfg(unix)]
97 {
98 use std::os::unix::fs::PermissionsExt;
99 if let Err(err) =
100 std::fs::set_permissions(path_ref, std::fs::Permissions::from_mode(0o600))
101 {
102 log::debug!(
103 "feedback store: could not tighten perms on {}: {}",
104 path_ref.display(),
105 err
106 );
107 }
108 }
109 Ok(Self { conn })
110 }
111
112 pub fn open_in_memory() -> Result<Self> {
123 let conn = Connection::open_in_memory().map_err(map_sqlite)?;
124 conn.execute_batch(SCHEMA_V1).map_err(map_sqlite)?;
125 check_or_stamp_schema_version(&conn)?;
126 Ok(Self { conn })
127 }
128
129 pub fn record(&self, obs: &Observation) -> Result<i64> {
149 self.conn
150 .execute(
151 "INSERT INTO observations (template_hash, plan_fingerprint, est_rows, actual_rows, latency_ms)
152 VALUES (?1, ?2, ?3, ?4, ?5)",
153 params![
154 obs.template_hash,
155 obs.plan_fingerprint,
156 obs.est_rows as i64,
157 obs.actual_rows as i64,
158 obs.latency_ms,
159 ],
160 )
161 .map_err(map_sqlite)?;
162 Ok(self.conn.last_insert_rowid())
163 }
164
165 pub fn history(&self, template_hash: &str) -> Result<Vec<Observation>> {
167 let mut stmt = self
168 .conn
169 .prepare(
170 "SELECT template_hash, plan_fingerprint, est_rows, actual_rows, latency_ms
171 FROM observations WHERE template_hash = ?1 ORDER BY id ASC",
172 )
173 .map_err(map_sqlite)?;
174 let rows = stmt
175 .query_map(params![template_hash], |row| {
176 Ok(Observation {
177 template_hash: row.get(0)?,
178 plan_fingerprint: row.get(1)?,
179 est_rows: row.get::<_, i64>(2)? as u64,
180 actual_rows: row.get::<_, i64>(3)? as u64,
181 latency_ms: row.get(4)?,
182 })
183 })
184 .map_err(map_sqlite)?;
185 rows.collect::<std::result::Result<Vec<_>, _>>()
186 .map_err(map_sqlite)
187 }
188
189 pub fn count(&self) -> Result<u64> {
191 self.conn
192 .query_row("SELECT COUNT(*) FROM observations", [], |row| {
193 row.get::<_, i64>(0)
194 })
195 .map(|n| n as u64)
196 .map_err(map_sqlite)
197 }
198}
199
200fn map_sqlite(e: rusqlite::Error) -> Error {
201 Error::Feedback(e.to_string())
202}
203
204fn check_or_stamp_schema_version(conn: &Connection) -> Result<()> {
212 let on_disk: i32 = conn
213 .query_row("PRAGMA user_version", [], |row| row.get(0))
214 .map_err(map_sqlite)?;
215 if on_disk == 0 {
216 conn.execute_batch(&format!("PRAGMA user_version = {SCHEMA_USER_VERSION}"))
221 .map_err(map_sqlite)?;
222 return Ok(());
223 }
224 if on_disk > SCHEMA_USER_VERSION {
225 return Err(Error::Feedback(format!(
226 "feedback store schema version {on_disk} is newer than this build supports \
227 ({SCHEMA_USER_VERSION}); refuse to open to avoid data truncation"
228 )));
229 }
230 if on_disk < SCHEMA_USER_VERSION {
231 conn.execute_batch(&format!("PRAGMA user_version = {SCHEMA_USER_VERSION}"))
235 .map_err(map_sqlite)?;
236 }
237 Ok(())
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process database path in the system temp directory. Any stale file
    /// is removed on creation, and the file is removed again on drop, so a
    /// failing assertion does not leak it.
    struct TempDb(std::path::PathBuf);

    impl TempDb {
        fn new(prefix: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{prefix}-{}.db", std::process::id()));
            let _ = std::fs::remove_file(&path);
            Self(path)
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    fn sample(template: &str, est: u64, actual: u64) -> Observation {
        Observation {
            template_hash: template.into(),
            plan_fingerprint: "plan-abc".into(),
            est_rows: est,
            actual_rows: actual,
            latency_ms: Some(42.0),
        }
    }

    #[test]
    fn record_and_count() {
        let store = FeedbackStore::open_in_memory().unwrap();
        assert_eq!(store.count().unwrap(), 0);
        store.record(&sample("t1", 100, 110)).unwrap();
        store.record(&sample("t1", 100, 90)).unwrap();
        store.record(&sample("t2", 50, 200)).unwrap();
        assert_eq!(store.count().unwrap(), 3);
    }

    #[test]
    fn record_returns_increasing_row_ids() {
        let store = FeedbackStore::open_in_memory().unwrap();
        let first = store.record(&sample("t1", 1, 1)).unwrap();
        let second = store.record(&sample("t1", 1, 1)).unwrap();
        assert!(second > first);
    }

    #[test]
    fn history_filters_by_template() {
        let store = FeedbackStore::open_in_memory().unwrap();
        store.record(&sample("t1", 100, 110)).unwrap();
        store.record(&sample("t2", 50, 200)).unwrap();
        store.record(&sample("t1", 100, 90)).unwrap();
        let t1 = store.history("t1").unwrap();
        assert_eq!(t1.len(), 2);
        assert!(t1.iter().all(|o| o.template_hash == "t1"));
    }

    #[test]
    fn history_round_trips_fields_in_insertion_order() {
        let store = FeedbackStore::open_in_memory().unwrap();
        let first = sample("t1", 100, 110);
        let second = Observation {
            plan_fingerprint: "plan-xyz".into(),
            latency_ms: None,
            ..sample("t1", 7, 3)
        };
        store.record(&first).unwrap();
        store.record(&second).unwrap();
        assert_eq!(store.history("t1").unwrap(), vec![first, second]);
        assert!(store.history("missing").unwrap().is_empty());
    }

    #[test]
    fn schema_version_stamped_on_fresh_store() {
        let store = FeedbackStore::open_in_memory().unwrap();
        let v: i32 = store
            .conn
            .query_row("PRAGMA user_version", [], |row| row.get(0))
            .unwrap();
        assert_eq!(v, SCHEMA_USER_VERSION);
    }

    #[test]
    fn refuses_forward_versioned_store() {
        let tmp = TempDb::new("samkhya-feedback-forward");
        {
            let store = FeedbackStore::open(&tmp.0).unwrap();
            store
                .conn
                .execute_batch(&format!(
                    "PRAGMA user_version = {}",
                    SCHEMA_USER_VERSION + 99
                ))
                .unwrap();
        }
        match FeedbackStore::open(&tmp.0) {
            Ok(_) => panic!("expected forward-version rejection, got Ok"),
            Err(Error::Feedback(msg)) => assert!(
                msg.contains("newer than this build"),
                "expected forward-version rejection, got: {msg}"
            ),
            Err(other) => panic!("expected Error::Feedback, got {other:?}"),
        }
    }

    #[test]
    fn q_error_computes_correctly() {
        let obs_over = sample("t1", 10, 100);
        assert!((obs_over.q_error() - 10.0).abs() < 1e-9);
        let obs_under = sample("t1", 100, 10);
        assert!((obs_under.q_error() - 10.0).abs() < 1e-9);
        let obs_exact = sample("t1", 100, 100);
        assert!((obs_exact.q_error() - 1.0).abs() < 1e-9);
        let obs_zero = sample("t1", 0, 100);
        assert!(obs_zero.q_error().is_infinite());
    }

    #[test]
    fn persists_to_disk() {
        // Declared first so it is dropped (and the file deleted) only after
        // every store below has closed its connection.
        let tmp = TempDb::new("samkhya-test");
        {
            let store = FeedbackStore::open(&tmp.0).unwrap();
            store.record(&sample("t1", 1, 2)).unwrap();
        }
        let store2 = FeedbackStore::open(&tmp.0).unwrap();
        assert_eq!(store2.count().unwrap(), 1);
    }
}