Skip to main content

samkhya_core/
feedback.rs

1//! Feedback recorder — captures `(plan, estimate, actual)` triples to a
2//! SQLite sidecar so the residual correction model can learn from real
3//! engine behavior. Inspired by Bao / AutoSteer.
4//!
5//! The store is deliberately minimal: one process, one connection, one
6//! table per concern. The schema is forward-compatible — new optional
7//! columns can be added with `ALTER TABLE` migrations later.
8
9use std::path::Path;
10
11use rusqlite::{Connection, params};
12use serde::{Deserialize, Serialize};
13
14use crate::{Error, Result};
15
/// DDL for the v1 feedback schema: the `observations` table plus one
/// lookup index each on `template_hash` and `plan_fingerprint`.
///
/// Every statement is `IF NOT EXISTS`, and the batch is executed each
/// time a store is opened (on disk or in memory), so it must stay safe
/// to re-run against a database that already has the schema.
const SCHEMA_V1: &str = r#"
CREATE TABLE IF NOT EXISTS observations (
    id              INTEGER PRIMARY KEY AUTOINCREMENT,
    template_hash   TEXT NOT NULL,
    plan_fingerprint TEXT NOT NULL,
    est_rows        INTEGER NOT NULL,
    actual_rows     INTEGER NOT NULL,
    latency_ms      REAL,
    recorded_at     TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_obs_template ON observations(template_hash);
CREATE INDEX IF NOT EXISTS idx_obs_plan ON observations(plan_fingerprint);
"#;
29
/// Schema version stamped into SQLite's `PRAGMA user_version`.
///
/// Bumped only when the on-disk schema changes in a backwards-incompatible
/// way. Stores written by an older binary (with `user_version = 0`,
/// i.e. unset) are silently upgraded by writing the current value;
/// stores written by a newer binary (with a strictly larger version)
/// are rejected so we never silently truncate forward-versioned data.
/// See `documents/SECURITY-REVIEW-2026-05-17.md` item L3.
///
/// Typed as `i32` because that is the type the PRAGMA value is read back
/// as in `check_or_stamp_schema_version`, where the two are compared.
const SCHEMA_USER_VERSION: i32 = 1;
39
/// A single observation captured at query end.
///
/// One value of this type corresponds to one row of the `observations`
/// table; the `id` and `recorded_at` columns are filled in by SQLite and
/// are not part of this struct.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Observation {
    /// Identifier of the query template; [`FeedbackStore::history`]
    /// filters on this value.
    pub template_hash: String,
    /// Fingerprint of the plan that was executed. May carry schema
    /// details or filter values, so treat it as sensitive.
    pub plan_fingerprint: String,
    /// Estimated row count.
    pub est_rows: u64,
    /// Row count actually observed.
    pub actual_rows: u64,
    /// Latency in milliseconds, or `None` when it was not captured
    /// (stored in a nullable `REAL` column).
    pub latency_ms: Option<f64>,
}
49
50impl Observation {
51    /// Multiplicative q-error: `max(actual/est, est/actual)`. Returns `f64::INFINITY` if either is 0.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// use samkhya_core::feedback::Observation;
57    ///
58    /// // 10× underestimate: est=10, actual=100 → q-error = 10.
59    /// let obs = Observation {
60    ///     template_hash: "t".into(),
61    ///     plan_fingerprint: "p".into(),
62    ///     est_rows: 10,
63    ///     actual_rows: 100,
64    ///     latency_ms: None,
65    /// };
66    /// assert!((obs.q_error() - 10.0).abs() < 1e-9);
67    /// ```
68    pub fn q_error(&self) -> f64 {
69        if self.est_rows == 0 || self.actual_rows == 0 {
70            return f64::INFINITY;
71        }
72        let r = self.actual_rows as f64 / self.est_rows as f64;
73        if r >= 1.0 { r } else { 1.0 / r }
74    }
75}
76
/// SQLite-backed feedback store.
///
/// Wraps a single `rusqlite` [`Connection`]; every method on the store
/// runs its SQL through that one connection.
pub struct FeedbackStore {
    // The only handle to the underlying database (file-backed or in-memory).
    conn: Connection,
}
81
82impl FeedbackStore {
83    /// Open or create a store at `path`.
84    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
85        let path_ref = path.as_ref();
86        let conn = Connection::open(path_ref).map_err(map_sqlite)?;
87        conn.execute_batch(SCHEMA_V1).map_err(map_sqlite)?;
88        check_or_stamp_schema_version(&conn)?;
89        // SECURITY-REVIEW-2026-05-17.md (M2): the feedback store records
90        // plan fingerprints which may carry schema details or filter
91        // values. Tighten the file mode to 0o600 (owner-only) so a
92        // shared-system reader cannot snapshot the store. Best-effort:
93        // a failure here (e.g., the file does not exist because we are
94        // running against an in-memory or VFS-special path) is logged
95        // but not promoted to an error — the store is still usable.
96        #[cfg(unix)]
97        {
98            use std::os::unix::fs::PermissionsExt;
99            if let Err(err) =
100                std::fs::set_permissions(path_ref, std::fs::Permissions::from_mode(0o600))
101            {
102                log::debug!(
103                    "feedback store: could not tighten perms on {}: {}",
104                    path_ref.display(),
105                    err
106                );
107            }
108        }
109        Ok(Self { conn })
110    }
111
112    /// Open an in-memory store (test / ephemeral).
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use samkhya_core::feedback::FeedbackStore;
118    ///
119    /// let store = FeedbackStore::open_in_memory().unwrap();
120    /// assert_eq!(store.count().unwrap(), 0);
121    /// ```
122    pub fn open_in_memory() -> Result<Self> {
123        let conn = Connection::open_in_memory().map_err(map_sqlite)?;
124        conn.execute_batch(SCHEMA_V1).map_err(map_sqlite)?;
125        check_or_stamp_schema_version(&conn)?;
126        Ok(Self { conn })
127    }
128
129    /// Record an observation.
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// use samkhya_core::feedback::{FeedbackStore, Observation};
135    ///
136    /// let store = FeedbackStore::open_in_memory().unwrap();
137    /// let obs = Observation {
138    ///     template_hash: "tpch-q1".into(),
139    ///     plan_fingerprint: "hash-join#42".into(),
140    ///     est_rows: 1000,
141    ///     actual_rows: 950,
142    ///     latency_ms: Some(12.5),
143    /// };
144    /// let id = store.record(&obs).unwrap();
145    /// assert!(id > 0);
146    /// assert_eq!(store.count().unwrap(), 1);
147    /// ```
148    pub fn record(&self, obs: &Observation) -> Result<i64> {
149        self.conn
150            .execute(
151                "INSERT INTO observations (template_hash, plan_fingerprint, est_rows, actual_rows, latency_ms)
152                 VALUES (?1, ?2, ?3, ?4, ?5)",
153                params![
154                    obs.template_hash,
155                    obs.plan_fingerprint,
156                    obs.est_rows as i64,
157                    obs.actual_rows as i64,
158                    obs.latency_ms,
159                ],
160            )
161            .map_err(map_sqlite)?;
162        Ok(self.conn.last_insert_rowid())
163    }
164
165    /// Return all observations for a given query template, oldest first.
166    pub fn history(&self, template_hash: &str) -> Result<Vec<Observation>> {
167        let mut stmt = self
168            .conn
169            .prepare(
170                "SELECT template_hash, plan_fingerprint, est_rows, actual_rows, latency_ms
171                 FROM observations WHERE template_hash = ?1 ORDER BY id ASC",
172            )
173            .map_err(map_sqlite)?;
174        let rows = stmt
175            .query_map(params![template_hash], |row| {
176                Ok(Observation {
177                    template_hash: row.get(0)?,
178                    plan_fingerprint: row.get(1)?,
179                    est_rows: row.get::<_, i64>(2)? as u64,
180                    actual_rows: row.get::<_, i64>(3)? as u64,
181                    latency_ms: row.get(4)?,
182                })
183            })
184            .map_err(map_sqlite)?;
185        rows.collect::<std::result::Result<Vec<_>, _>>()
186            .map_err(map_sqlite)
187    }
188
189    /// Number of observations stored.
190    pub fn count(&self) -> Result<u64> {
191        self.conn
192            .query_row("SELECT COUNT(*) FROM observations", [], |row| {
193                row.get::<_, i64>(0)
194            })
195            .map(|n| n as u64)
196            .map_err(map_sqlite)
197    }
198}
199
200fn map_sqlite(e: rusqlite::Error) -> Error {
201    Error::Feedback(e.to_string())
202}
203
204/// Read the SQLite `user_version` PRAGMA and either stamp it (if unset)
205/// or reject the store (if it carries a strictly larger version).
206///
207/// See `documents/SECURITY-REVIEW-2026-05-17.md` item L3: a previously
208/// malicious or simply newer-schema `.db` opened by an older samkhya
209/// would silently mismatch row shape on read; the new PRAGMA check
210/// makes that visible.
211fn check_or_stamp_schema_version(conn: &Connection) -> Result<()> {
212    let on_disk: i32 = conn
213        .query_row("PRAGMA user_version", [], |row| row.get(0))
214        .map_err(map_sqlite)?;
215    if on_disk == 0 {
216        // Fresh / pre-versioning store. Stamp the current version so
217        // future opens see a match. Using `execute_batch` because
218        // `PRAGMA user_version = N` is not a parameterised statement
219        // (SQLite refuses bind params on PRAGMA writes).
220        conn.execute_batch(&format!("PRAGMA user_version = {SCHEMA_USER_VERSION}"))
221            .map_err(map_sqlite)?;
222        return Ok(());
223    }
224    if on_disk > SCHEMA_USER_VERSION {
225        return Err(Error::Feedback(format!(
226            "feedback store schema version {on_disk} is newer than this build supports \
227             ({SCHEMA_USER_VERSION}); refuse to open to avoid data truncation"
228        )));
229    }
230    if on_disk < SCHEMA_USER_VERSION {
231        // Older but compatible. No migrations yet (we are on v1), so
232        // just bump the marker. Future versions will run migration
233        // SQL here before bumping.
234        conn.execute_batch(&format!("PRAGMA user_version = {SCHEMA_USER_VERSION}"))
235            .map_err(map_sqlite)?;
236    }
237    Ok(())
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an observation with a fixed plan fingerprint and latency so
    /// tests only have to vary the template and the two row counts.
    fn sample(template: &str, est: u64, actual: u64) -> Observation {
        Observation {
            template_hash: template.into(),
            plan_fingerprint: "plan-abc".into(),
            est_rows: est,
            actual_rows: actual,
            latency_ms: Some(42.0),
        }
    }

    #[test]
    fn record_and_count() {
        let store = FeedbackStore::open_in_memory().unwrap();
        assert_eq!(store.count().unwrap(), 0);
        store.record(&sample("t1", 100, 110)).unwrap();
        store.record(&sample("t1", 100, 90)).unwrap();
        store.record(&sample("t2", 50, 200)).unwrap();
        assert_eq!(store.count().unwrap(), 3);
    }

    #[test]
    fn history_filters_by_template() {
        let store = FeedbackStore::open_in_memory().unwrap();
        store.record(&sample("t1", 100, 110)).unwrap();
        store.record(&sample("t2", 50, 200)).unwrap();
        store.record(&sample("t1", 100, 90)).unwrap();
        let t1 = store.history("t1").unwrap();
        assert_eq!(t1.len(), 2);
        assert!(t1.iter().all(|o| o.template_hash == "t1"));
        // `history` promises oldest-first order and a lossless round trip.
        assert_eq!(t1, vec![sample("t1", 100, 110), sample("t1", 100, 90)]);
        // A template that was never recorded yields an empty history.
        assert!(store.history("missing").unwrap().is_empty());
    }

    #[test]
    fn schema_version_stamped_on_fresh_store() {
        let store = FeedbackStore::open_in_memory().unwrap();
        let v: i32 = store
            .conn
            .query_row("PRAGMA user_version", [], |row| row.get(0))
            .unwrap();
        assert_eq!(v, SCHEMA_USER_VERSION);
    }

    #[test]
    fn refuses_forward_versioned_store() {
        // Open once to stamp the schema, then manually bump the
        // user_version past what this binary supports and re-open.
        let tmp = std::env::temp_dir().join(format!(
            "samkhya-feedback-forward-{}.db",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&tmp);
        {
            let store = FeedbackStore::open(&tmp).unwrap();
            store
                .conn
                .execute_batch(&format!(
                    "PRAGMA user_version = {}",
                    SCHEMA_USER_VERSION + 99
                ))
                .unwrap();
        }
        match FeedbackStore::open(&tmp) {
            Ok(_) => panic!("expected forward-version rejection, got Ok"),
            Err(Error::Feedback(msg)) => assert!(
                msg.contains("newer than this build"),
                "expected forward-version rejection, got: {msg}"
            ),
            Err(other) => panic!("expected Error::Feedback, got {other:?}"),
        }
        let _ = std::fs::remove_file(&tmp);
    }

    #[test]
    fn q_error_computes_correctly() {
        let underestimate = sample("t1", 10, 100); // est is 10× too low
        assert!((underestimate.q_error() - 10.0).abs() < 1e-9);
        let overestimate = sample("t1", 100, 10); // est is 10× too high
        assert!((overestimate.q_error() - 10.0).abs() < 1e-9);
        let exact = sample("t1", 100, 100);
        assert!((exact.q_error() - 1.0).abs() < 1e-9);
        // A zero on either side (or both) is reported as infinite error.
        let zero_est = sample("t1", 0, 100);
        assert!(zero_est.q_error().is_infinite());
        let zero_actual = sample("t1", 100, 0);
        assert!(zero_actual.q_error().is_infinite());
        let both_zero = sample("t1", 0, 0);
        assert!(both_zero.q_error().is_infinite());
    }

    #[test]
    fn persists_to_disk() {
        let tmp = std::env::temp_dir().join(format!("samkhya-test-{}.db", std::process::id()));
        // ensure clean start
        let _ = std::fs::remove_file(&tmp);
        {
            let store = FeedbackStore::open(&tmp).unwrap();
            store.record(&sample("t1", 1, 2)).unwrap();
        }
        let store2 = FeedbackStore::open(&tmp).unwrap();
        assert_eq!(store2.count().unwrap(), 1);
        let _ = std::fs::remove_file(&tmp);
    }
}