Skip to main content

zeph_memory/store/
experiments.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::SqliteStore;
5use crate::error::MemoryError;
6#[allow(unused_imports)]
7use zeph_db::sql;
8
/// One stored experiment result, as read back from the `experiment_results` table.
///
/// The fields appear in the same order as the columns selected by the queries
/// in this module.
#[derive(Debug, Clone)]
pub struct ExperimentResultRow {
    /// Row ID of the stored result.
    pub id: i64,
    /// Session the result was recorded under.
    pub session_id: String,
    /// Name of the parameter the result was recorded for.
    pub parameter: String,
    /// Contents of the `value_json` column; this module never parses it.
    pub value_json: String,
    /// Stored baseline score.
    pub baseline_score: f64,
    /// Stored candidate score.
    pub candidate_score: f64,
    /// Stored score delta (presumably candidate minus baseline — confirm with the writer).
    pub delta: f64,
    /// Stored latency value (milliseconds, going by the column name).
    pub latency_ms: i64,
    /// Stored token count.
    pub tokens_used: i64,
    /// Whether the result is flagged as accepted.
    pub accepted: bool,
    /// Contents of the `source` column.
    pub source: String,
    /// Contents of the `created_at` column, kept as text.
    pub created_at: String,
}
24
/// Borrowed payload describing an experiment result that is about to be inserted.
///
/// It has no `id` or `created_at` field; those only exist on rows read back
/// from the table.
#[derive(Debug, Clone)]
pub struct NewExperimentResult<'a> {
    /// Session the result belongs to.
    pub session_id: &'a str,
    /// Name of the parameter the result is recorded for.
    pub parameter: &'a str,
    /// Text stored in the `value_json` column (presumably a JSON document — not checked here).
    pub value_json: &'a str,
    /// Baseline score to store.
    pub baseline_score: f64,
    /// Candidate score to store.
    pub candidate_score: f64,
    /// Score delta to store; taken as given, not recomputed from the two scores.
    pub delta: f64,
    /// Latency value to store (milliseconds, going by the field name).
    pub latency_ms: i64,
    /// Token count to store.
    pub tokens_used: i64,
    /// Whether the result is flagged as accepted.
    pub accepted: bool,
    /// Text stored in the `source` column.
    pub source: &'a str,
}
38
/// Aggregated figures for all stored results of a single session.
#[derive(Debug, Clone)]
pub struct SessionSummaryRow {
    /// Session the figures were aggregated for.
    pub session_id: String,
    /// Number of stored results for the session.
    pub total: i64,
    /// Number of those results flagged as accepted.
    pub accepted_count: i64,
    /// Highest `delta` among the accepted results.
    pub best_delta: f64,
    /// Sum of `tokens_used` over the session's results.
    pub total_tokens: i64,
}
47
48/// Validate that `s` looks like `YYYY-MM-DD HH:MM:SS` or `YYYY-MM-DDTHH:MM:SS`.
49fn validate_timestamp(s: &str) -> Result<(), MemoryError> {
50    let bytes = s.as_bytes();
51    // Minimum length: "2000-01-01 00:00:00" = 19 chars
52    if bytes.len() < 19 {
53        return Err(MemoryError::Other(format!(
54            "invalid timestamp format (too short): {s:?}"
55        )));
56    }
57    let sep = bytes[10];
58    if sep != b' ' && sep != b'T' {
59        return Err(MemoryError::Other(format!(
60            "invalid timestamp format (expected space or T at position 10): {s:?}"
61        )));
62    }
63    // Check digit positions: YYYY-MM-DD HH:MM:SS
64    let digits_at = [0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18];
65    let dashes_at = [4, 7];
66    let colons_at = [13, 16];
67    for i in digits_at {
68        if !bytes[i].is_ascii_digit() {
69            return Err(MemoryError::Other(format!(
70                "invalid timestamp format (expected digit at {i}): {s:?}"
71            )));
72        }
73    }
74    for i in dashes_at {
75        if bytes[i] != b'-' {
76            return Err(MemoryError::Other(format!(
77                "invalid timestamp format (expected '-' at {i}): {s:?}"
78            )));
79        }
80    }
81    for i in colons_at {
82        if bytes[i] != b':' {
83            return Err(MemoryError::Other(format!(
84                "invalid timestamp format (expected ':' at {i}): {s:?}"
85            )));
86        }
87    }
88    Ok(())
89}
90
/// Raw tuple form of one `experiment_results` row.
///
/// The element order matches the column list used by the `SELECT` statements
/// in this module: `id`, `session_id`, `parameter`, `value_json`,
/// `baseline_score`, `candidate_score`, `delta`, `latency_ms`, `tokens_used`,
/// `accepted`, `source`, `created_at`.
type ResultTuple = (
    i64,    // id
    String, // session_id
    String, // parameter
    String, // value_json
    f64,    // baseline_score
    f64,    // candidate_score
    f64,    // delta
    i64,    // latency_ms
    i64,    // tokens_used
    bool,   // accepted
    String, // source
    String, // created_at
);
105
106fn row_from_tuple(t: ResultTuple) -> ExperimentResultRow {
107    ExperimentResultRow {
108        id: t.0,
109        session_id: t.1,
110        parameter: t.2,
111        value_json: t.3,
112        baseline_score: t.4,
113        candidate_score: t.5,
114        delta: t.6,
115        latency_ms: t.7,
116        tokens_used: t.8,
117        accepted: t.9,
118        source: t.10,
119        created_at: t.11,
120    }
121}
122
123impl SqliteStore {
124    /// Insert an experiment result and return the new row ID.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if the insert fails.
129    pub async fn insert_experiment_result(
130        &self,
131        result: &NewExperimentResult<'_>,
132    ) -> Result<i64, MemoryError> {
133        let row: (i64,) = zeph_db::query_as(sql!(
134            "INSERT INTO experiment_results \
135             (session_id, parameter, value_json, baseline_score, candidate_score, \
136              delta, latency_ms, tokens_used, accepted, source) \
137             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id"
138        ))
139        .bind(result.session_id)
140        .bind(result.parameter)
141        .bind(result.value_json)
142        .bind(result.baseline_score)
143        .bind(result.candidate_score)
144        .bind(result.delta)
145        .bind(result.latency_ms)
146        .bind(result.tokens_used)
147        .bind(result.accepted)
148        .bind(result.source)
149        .fetch_one(&self.pool)
150        .await?;
151        Ok(row.0)
152    }
153
154    /// List experiment results, optionally filtered by `session_id`, newest first.
155    ///
156    /// # Errors
157    ///
158    /// Returns an error if the query fails.
159    pub async fn list_experiment_results(
160        &self,
161        session_id: Option<&str>,
162        limit: u32,
163    ) -> Result<Vec<ExperimentResultRow>, MemoryError> {
164        let rows: Vec<ResultTuple> = if let Some(sid) = session_id {
165            zeph_db::query_as(sql!(
166                "SELECT id, session_id, parameter, value_json, baseline_score, candidate_score, \
167                 delta, latency_ms, tokens_used, accepted, source, created_at \
168                 FROM experiment_results WHERE session_id = ? ORDER BY id DESC LIMIT ?"
169            ))
170            .bind(sid)
171            .bind(limit)
172            .fetch_all(&self.pool)
173            .await?
174        } else {
175            zeph_db::query_as(sql!(
176                "SELECT id, session_id, parameter, value_json, baseline_score, candidate_score, \
177                 delta, latency_ms, tokens_used, accepted, source, created_at \
178                 FROM experiment_results ORDER BY id DESC LIMIT ?"
179            ))
180            .bind(limit)
181            .fetch_all(&self.pool)
182            .await?
183        };
184        Ok(rows.into_iter().map(row_from_tuple).collect())
185    }
186
187    /// Get the best accepted result (highest delta), optionally filtered by parameter.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if the query fails.
192    pub async fn best_experiment_result(
193        &self,
194        parameter: Option<&str>,
195    ) -> Result<Option<ExperimentResultRow>, MemoryError> {
196        let row: Option<ResultTuple> = if let Some(param) = parameter {
197            zeph_db::query_as(sql!(
198                "SELECT id, session_id, parameter, value_json, baseline_score, candidate_score, \
199                 delta, latency_ms, tokens_used, accepted, source, created_at \
200                 FROM experiment_results \
201                 WHERE accepted = 1 AND parameter = ? ORDER BY delta DESC LIMIT 1"
202            ))
203            .bind(param)
204            .fetch_optional(&self.pool)
205            .await?
206        } else {
207            zeph_db::query_as(sql!(
208                "SELECT id, session_id, parameter, value_json, baseline_score, candidate_score, \
209                 delta, latency_ms, tokens_used, accepted, source, created_at \
210                 FROM experiment_results \
211                 WHERE accepted = 1 ORDER BY delta DESC LIMIT 1"
212            ))
213            .fetch_optional(&self.pool)
214            .await?
215        };
216        Ok(row.map(row_from_tuple))
217    }
218
219    /// Get all results since a given ISO-8601 timestamp (`YYYY-MM-DD HH:MM:SS` or `YYYY-MM-DDTHH:MM:SS`).
220    ///
221    /// # Errors
222    ///
223    /// Returns `MemoryError::Other` if `since` does not match the expected timestamp format.
224    /// Returns an error if the query fails.
225    pub async fn experiment_results_since(
226        &self,
227        since: &str,
228    ) -> Result<Vec<ExperimentResultRow>, MemoryError> {
229        validate_timestamp(since)?;
230        let rows: Vec<ResultTuple> = zeph_db::query_as(sql!(
231            "SELECT id, session_id, parameter, value_json, baseline_score, candidate_score, \
232             delta, latency_ms, tokens_used, accepted, source, created_at \
233             FROM experiment_results WHERE created_at >= ? ORDER BY id DESC LIMIT 10000"
234        ))
235        .bind(since)
236        .fetch_all(&self.pool)
237        .await?;
238        Ok(rows.into_iter().map(row_from_tuple).collect())
239    }
240
241    /// Get a summary for a specific session.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if the query fails.
246    pub async fn experiment_session_summary(
247        &self,
248        session_id: &str,
249    ) -> Result<Option<SessionSummaryRow>, MemoryError> {
250        let row: Option<(String, i64, i64, Option<f64>, i64)> = zeph_db::query_as(sql!(
251            "SELECT session_id, COUNT(*) as total, \
252             SUM(CASE WHEN accepted = 1 THEN 1 ELSE 0 END) as accepted_count, \
253             MAX(CASE WHEN accepted = 1 THEN delta ELSE NULL END) as best_delta, \
254             SUM(tokens_used) as total_tokens \
255             FROM experiment_results WHERE session_id = ? GROUP BY session_id"
256        ))
257        .bind(session_id)
258        .fetch_optional(&self.pool)
259        .await?;
260        Ok(row.map(
261            |(sid, total, accepted_count, best_delta, total_tokens)| SessionSummaryRow {
262                session_id: sid,
263                total,
264                accepted_count,
265                best_delta: best_delta.unwrap_or(0.0),
266                total_tokens,
267            },
268        ))
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory store for one test.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Build an insert payload with fixed baseline, latency, token and source
    /// values; only the session, parameter, acceptance flag and delta vary.
    fn make_result<'a>(
        session_id: &'a str,
        parameter: &'a str,
        accepted: bool,
        delta: f64,
    ) -> NewExperimentResult<'a> {
        let baseline_score = 7.0;
        NewExperimentResult {
            session_id,
            parameter,
            value_json: r#"{"type":"Float","value":0.7}"#,
            baseline_score,
            candidate_score: baseline_score + delta,
            delta,
            latency_ms: 500,
            tokens_used: 100,
            accepted,
            source: "manual",
        }
    }

    /// Insert one result built by `make_result` and return its row ID.
    async fn seed(
        store: &SqliteStore,
        session_id: &str,
        parameter: &str,
        accepted: bool,
        delta: f64,
    ) -> i64 {
        let payload = make_result(session_id, parameter, accepted, delta);
        store.insert_experiment_result(&payload).await.unwrap()
    }

    #[tokio::test]
    async fn insert_and_list_results() {
        let store = test_store().await;
        let new_id = seed(&store, "session-1", "temperature", true, 1.0).await;
        assert!(new_id > 0);

        let listed = store
            .list_experiment_results(Some("session-1"), 10)
            .await
            .unwrap();
        assert_eq!(listed.len(), 1);

        let only = &listed[0];
        assert_eq!(only.session_id, "session-1");
        assert_eq!(only.parameter, "temperature");
        assert!(only.accepted);
        assert!((only.delta - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn list_results_no_filter_returns_all() {
        let store = test_store().await;
        seed(&store, "s1", "temperature", true, 1.0).await;
        seed(&store, "s2", "top_p", false, -0.2).await;

        let listed = store.list_experiment_results(None, 10).await.unwrap();
        assert_eq!(listed.len(), 2);
        // The most recently inserted row comes first.
        assert_eq!(listed[0].session_id, "s2");
    }

    #[tokio::test]
    async fn best_result_returns_accepted_highest_delta() {
        let store = test_store().await;
        // The rejected row has the largest delta and must be ignored.
        seed(&store, "s1", "temperature", false, 2.0).await;
        seed(&store, "s1", "temperature", true, 0.5).await;
        seed(&store, "s1", "temperature", true, 1.5).await;

        let winner = store.best_experiment_result(None).await.unwrap().unwrap();
        assert!(winner.accepted);
        assert!((winner.delta - 1.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn best_result_filtered_by_parameter() {
        let store = test_store().await;
        seed(&store, "s1", "temperature", true, 2.0).await;
        seed(&store, "s1", "top_p", true, 1.0).await;

        let winner = store
            .best_experiment_result(Some("top_p"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(winner.parameter, "top_p");
    }

    #[tokio::test]
    async fn best_result_no_accepted_returns_none() {
        let store = test_store().await;
        seed(&store, "s1", "temperature", false, 2.0).await;

        let winner = store.best_experiment_result(None).await.unwrap();
        assert!(winner.is_none());
    }

    #[tokio::test]
    async fn session_summary_aggregation() {
        let store = test_store().await;
        seed(&store, "sess", "temperature", true, 1.0).await;
        seed(&store, "sess", "top_p", false, -0.2).await;
        seed(&store, "sess", "top_k", true, 0.8).await;

        let totals = store
            .experiment_session_summary("sess")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(totals.session_id, "sess");
        assert_eq!(totals.total, 3);
        assert_eq!(totals.accepted_count, 2);
        assert!((totals.best_delta - 1.0).abs() < f64::EPSILON);
        assert_eq!(totals.total_tokens, 300);
    }

    #[tokio::test]
    async fn session_summary_unknown_session_returns_none() {
        let store = test_store().await;
        let totals = store
            .experiment_session_summary("nonexistent")
            .await
            .unwrap();
        assert!(totals.is_none());
    }

    #[tokio::test]
    async fn results_since_time_filtering() {
        let store = test_store().await;
        seed(&store, "s1", "temperature", true, 1.0).await;

        // A bound far in the future excludes the freshly inserted row.
        let after_future = store
            .experiment_results_since("2099-01-01 00:00:00")
            .await
            .unwrap();
        assert!(after_future.is_empty());

        // A bound far in the past includes it.
        let after_past = store
            .experiment_results_since("2000-01-01 00:00:00")
            .await
            .unwrap();
        assert_eq!(after_past.len(), 1);
    }

    #[tokio::test]
    async fn results_since_rejects_invalid_timestamp() {
        let store = test_store().await;
        let malformed = [
            "",
            "not-a-date",
            "0000-00-00",
            "2024-01-01",
            "2024/01/01 00:00:00",
        ];
        for ts in malformed {
            let outcome = store.experiment_results_since(ts).await;
            assert!(outcome.is_err(), "expected error for timestamp: {ts:?}");
        }

        // The `T` separator is accepted; an empty store yields no rows.
        let store2 = test_store().await;
        let listed = store2
            .experiment_results_since("2000-01-01T00:00:00")
            .await
            .unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn list_results_respects_limit() {
        let store = test_store().await;
        for i in 0..5 {
            seed(&store, "s1", "temperature", i % 2 == 0, f64::from(i)).await;
        }

        let listed = store.list_experiment_results(None, 3).await.unwrap();
        assert_eq!(listed.len(), 3);
    }
}
489}