Skip to main content

smelt_memory/storage/
sqlite.rs

1//! SQLite storage for episode metadata
2
3use crate::error::{MemoryError, MemoryResult};
4use crate::types::{Episode, EpisodeOutcome, Feedback};
5use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection, OptionalExtension};
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use uuid::Uuid;
10
11/// SQLite-based storage for episode metadata
12pub struct EpisodeStorage {
13    conn: Arc<Mutex<Connection>>,
14}
15
16impl EpisodeStorage {
17    /// Open or create a storage database
18    pub fn open(path: &Path) -> MemoryResult<Self> {
19        let conn = Connection::open(path)?;
20        let storage = Self {
21            conn: Arc::new(Mutex::new(conn)),
22        };
23        storage.init_schema()?;
24        Ok(storage)
25    }
26
27    /// Create an in-memory storage (for testing)
28    pub fn in_memory() -> MemoryResult<Self> {
29        let conn = Connection::open_in_memory()?;
30        let storage = Self {
31            conn: Arc::new(Mutex::new(conn)),
32        };
33        storage.init_schema()?;
34        Ok(storage)
35    }
36
37    /// Initialize the database schema
38    fn init_schema(&self) -> MemoryResult<()> {
39        let conn = self
40            .conn
41            .lock()
42            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
43
44        conn.execute_batch(
45            r#"
46            CREATE TABLE IF NOT EXISTS episodes (
47                id TEXT PRIMARY KEY,
48                created_at TEXT NOT NULL,
49                project TEXT,
50                summary TEXT NOT NULL,
51                task_type TEXT NOT NULL,
52                outcome TEXT NOT NULL,
53                files_modified TEXT NOT NULL,
54                errors_resolved TEXT NOT NULL,
55                tags TEXT NOT NULL,
56                intent_id TEXT,
57                delta_id TEXT,
58                commit_sha TEXT,
59                utility REAL NOT NULL DEFAULT 0.5,
60                helpful_count INTEGER NOT NULL DEFAULT 0,
61                feedback_count INTEGER NOT NULL DEFAULT 0
62            );
63
64            CREATE TABLE IF NOT EXISTS feedback (
65                id INTEGER PRIMARY KEY AUTOINCREMENT,
66                episode_id TEXT NOT NULL,
67                timestamp TEXT NOT NULL,
68                helpful INTEGER NOT NULL,
69                FOREIGN KEY (episode_id) REFERENCES episodes(id)
70            );
71
72            CREATE INDEX IF NOT EXISTS idx_episodes_project ON episodes(project);
73            CREATE INDEX IF NOT EXISTS idx_episodes_task_type ON episodes(task_type);
74            CREATE INDEX IF NOT EXISTS idx_episodes_created_at ON episodes(created_at);
75            CREATE INDEX IF NOT EXISTS idx_feedback_episode_id ON feedback(episode_id);
76            "#,
77        )?;
78
79        Ok(())
80    }
81
82    /// Store an episode
83    pub fn store_episode(&self, episode: &Episode) -> MemoryResult<()> {
84        let conn = self
85            .conn
86            .lock()
87            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
88
89        conn.execute(
90            r#"
91            INSERT OR REPLACE INTO episodes
92            (id, created_at, project, summary, task_type, outcome,
93             files_modified, errors_resolved, tags, intent_id, delta_id,
94             commit_sha, utility, helpful_count, feedback_count)
95            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)
96            "#,
97            params![
98                episode.id.to_string(),
99                episode.created_at.to_rfc3339(),
100                episode.project,
101                episode.summary,
102                episode.task_type,
103                outcome_to_str(&episode.outcome),
104                serde_json::to_string(&episode.files_modified)?,
105                serde_json::to_string(&episode.errors_resolved)?,
106                serde_json::to_string(&episode.tags)?,
107                episode.intent_id.map(|id| id.to_string()),
108                episode.delta_id.map(|id| id.to_string()),
109                episode.commit_sha,
110                episode.utility,
111                episode.helpful_count,
112                episode.feedback_count,
113            ],
114        )?;
115
116        Ok(())
117    }
118
119    /// Get an episode by ID
120    pub fn get_episode(&self, id: Uuid) -> MemoryResult<Option<Episode>> {
121        let conn = self
122            .conn
123            .lock()
124            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
125
126        let mut stmt = conn.prepare(
127            r#"
128            SELECT id, created_at, project, summary, task_type, outcome,
129                   files_modified, errors_resolved, tags, intent_id, delta_id,
130                   commit_sha, utility, helpful_count, feedback_count
131            FROM episodes WHERE id = ?1
132            "#,
133        )?;
134
135        let result = stmt
136            .query_row([id.to_string()], |row| Ok(row_to_episode_raw(row)))
137            .optional()?;
138
139        match result {
140            Some(ep) => Ok(Some(ep)),
141            None => Ok(None),
142        }
143    }
144
145    /// List all episodes, optionally filtered by project
146    pub fn list_episodes(&self, project: Option<&str>) -> MemoryResult<Vec<Episode>> {
147        let conn = self
148            .conn
149            .lock()
150            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
151
152        let mut episodes = Vec::new();
153
154        if let Some(proj) = project {
155            let mut stmt = conn.prepare(
156                r#"
157                SELECT id, created_at, project, summary, task_type, outcome,
158                       files_modified, errors_resolved, tags, intent_id, delta_id,
159                       commit_sha, utility, helpful_count, feedback_count
160                FROM episodes WHERE project = ?1
161                ORDER BY created_at DESC
162                "#,
163            )?;
164
165            let rows = stmt.query_map([proj], |row| Ok(row_to_episode_raw(row)))?;
166
167            for row in rows {
168                episodes.push(row?);
169            }
170        } else {
171            let mut stmt = conn.prepare(
172                r#"
173                SELECT id, created_at, project, summary, task_type, outcome,
174                       files_modified, errors_resolved, tags, intent_id, delta_id,
175                       commit_sha, utility, helpful_count, feedback_count
176                FROM episodes ORDER BY created_at DESC
177                "#,
178            )?;
179
180            let rows = stmt.query_map([], |row| Ok(row_to_episode_raw(row)))?;
181
182            for row in rows {
183                episodes.push(row?);
184            }
185        }
186
187        Ok(episodes)
188    }
189
190    /// Update episode utility score
191    pub fn update_utility(&self, id: Uuid, utility: f64) -> MemoryResult<()> {
192        let conn = self
193            .conn
194            .lock()
195            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
196
197        conn.execute(
198            "UPDATE episodes SET utility = ?1 WHERE id = ?2",
199            params![utility, id.to_string()],
200        )?;
201
202        Ok(())
203    }
204
205    /// Record feedback for an episode
206    pub fn record_feedback(&self, episode_id: Uuid, helpful: bool) -> MemoryResult<()> {
207        let conn = self
208            .conn
209            .lock()
210            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
211
212        // Insert feedback record
213        conn.execute(
214            r#"
215            INSERT INTO feedback (episode_id, timestamp, helpful)
216            VALUES (?1, ?2, ?3)
217            "#,
218            params![
219                episode_id.to_string(),
220                Utc::now().to_rfc3339(),
221                helpful as i32,
222            ],
223        )?;
224
225        // Update episode counts
226        if helpful {
227            conn.execute(
228                r#"
229                UPDATE episodes
230                SET helpful_count = helpful_count + 1,
231                    feedback_count = feedback_count + 1
232                WHERE id = ?1
233                "#,
234                [episode_id.to_string()],
235            )?;
236        } else {
237            conn.execute(
238                "UPDATE episodes SET feedback_count = feedback_count + 1 WHERE id = ?1",
239                [episode_id.to_string()],
240            )?;
241        }
242
243        Ok(())
244    }
245
246    /// Get all feedback for an episode
247    pub fn get_feedback(&self, episode_id: Uuid) -> MemoryResult<Vec<Feedback>> {
248        let conn = self
249            .conn
250            .lock()
251            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
252
253        let mut stmt = conn
254            .prepare("SELECT episode_id, timestamp, helpful FROM feedback WHERE episode_id = ?1")?;
255
256        let mut feedback = Vec::new();
257        let rows = stmt.query_map([episode_id.to_string()], |row| {
258            let episode_id_str: String = row.get(0)?;
259            let timestamp_str: String = row.get(1)?;
260            let helpful: i32 = row.get(2)?;
261
262            Ok(Feedback {
263                episode_id: Uuid::parse_str(&episode_id_str).unwrap_or(Uuid::nil()),
264                timestamp: DateTime::parse_from_rfc3339(&timestamp_str)
265                    .map(|dt| dt.with_timezone(&Utc))
266                    .unwrap_or_else(|_| Utc::now()),
267                helpful: helpful != 0,
268            })
269        })?;
270
271        for row in rows {
272            feedback.push(row?);
273        }
274
275        Ok(feedback)
276    }
277
278    /// Get episode IDs for utility propagation
279    pub fn get_all_episode_ids(&self) -> MemoryResult<Vec<Uuid>> {
280        let conn = self
281            .conn
282            .lock()
283            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
284
285        let mut stmt = conn.prepare("SELECT id FROM episodes")?;
286        let mut ids = Vec::new();
287
288        let rows = stmt.query_map([], |row| {
289            let id_str: String = row.get(0)?;
290            Ok(Uuid::parse_str(&id_str).unwrap_or(Uuid::nil()))
291        })?;
292
293        for row in rows {
294            ids.push(row?);
295        }
296
297        Ok(ids)
298    }
299
300    /// Get statistics about the memory
301    pub fn get_stats(&self, project: Option<&str>) -> MemoryResult<MemoryStats> {
302        let conn = self
303            .conn
304            .lock()
305            .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
306
307        let (total_episodes, total_feedback, avg_utility) = if let Some(proj) = project {
308            let mut stmt = conn.prepare(
309                r#"
310                SELECT COUNT(*), SUM(feedback_count), AVG(utility)
311                FROM episodes WHERE project = ?1
312                "#,
313            )?;
314            stmt.query_row([proj], |row| {
315                Ok((
316                    row.get::<_, i64>(0)? as usize,
317                    row.get::<_, Option<i64>>(1)?.unwrap_or(0) as usize,
318                    row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
319                ))
320            })?
321        } else {
322            let mut stmt =
323                conn.prepare("SELECT COUNT(*), SUM(feedback_count), AVG(utility) FROM episodes")?;
324            stmt.query_row([], |row| {
325                Ok((
326                    row.get::<_, i64>(0)? as usize,
327                    row.get::<_, Option<i64>>(1)?.unwrap_or(0) as usize,
328                    row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
329                ))
330            })?
331        };
332
333        Ok(MemoryStats {
334            total_episodes,
335            total_feedback,
336            avg_utility,
337        })
338    }
339}
340
/// Aggregate statistics over stored episodes, as produced by
/// `EpisodeStorage::get_stats`.
///
/// `Default` is the all-zero value, which is also what `get_stats` reports
/// when no episodes match.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct MemoryStats {
    /// Number of matching episodes (`COUNT(*)`).
    pub total_episodes: usize,
    /// Sum of `feedback_count` over matching episodes; 0 when none match.
    pub total_feedback: usize,
    /// Mean `utility` over matching episodes; 0.0 when none match.
    pub avg_utility: f64,
}
348
349fn outcome_to_str(outcome: &EpisodeOutcome) -> &'static str {
350    match outcome {
351        EpisodeOutcome::Success => "success",
352        EpisodeOutcome::Partial => "partial",
353        EpisodeOutcome::Failure => "failure",
354    }
355}
356
357fn str_to_outcome(s: &str) -> EpisodeOutcome {
358    match s {
359        "success" => EpisodeOutcome::Success,
360        "partial" => EpisodeOutcome::Partial,
361        "failure" => EpisodeOutcome::Failure,
362        _ => EpisodeOutcome::Partial,
363    }
364}
365
366/// Convert a row to an Episode - infallible version that uses defaults for parse errors
367fn row_to_episode_raw(row: &rusqlite::Row) -> Episode {
368    let id_str: String = row.get(0).unwrap_or_default();
369    let created_at_str: String = row.get(1).unwrap_or_default();
370    let project: Option<String> = row.get(2).ok();
371    let summary: String = row.get(3).unwrap_or_default();
372    let task_type: String = row.get(4).unwrap_or_default();
373    let outcome_str: String = row.get(5).unwrap_or_default();
374    let files_json: String = row.get(6).unwrap_or_default();
375    let errors_json: String = row.get(7).unwrap_or_default();
376    let tags_json: String = row.get(8).unwrap_or_default();
377    let intent_id_str: Option<String> = row.get(9).ok();
378    let delta_id_str: Option<String> = row.get(10).ok();
379    let commit_sha: Option<String> = row.get(11).ok();
380    let utility: f64 = row.get(12).unwrap_or(0.5);
381    let helpful_count: u32 = row.get(13).unwrap_or(0);
382    let feedback_count: u32 = row.get(14).unwrap_or(0);
383
384    Episode {
385        id: Uuid::parse_str(&id_str).unwrap_or(Uuid::nil()),
386        created_at: DateTime::parse_from_rfc3339(&created_at_str)
387            .map(|dt| dt.with_timezone(&Utc))
388            .unwrap_or_else(|_| Utc::now()),
389        project,
390        summary,
391        task_type,
392        outcome: str_to_outcome(&outcome_str),
393        files_modified: serde_json::from_str(&files_json).unwrap_or_default(),
394        errors_resolved: serde_json::from_str(&errors_json).unwrap_or_default(),
395        tags: serde_json::from_str(&tags_json).unwrap_or_default(),
396        intent_id: intent_id_str.and_then(|s| Uuid::parse_str(&s).ok()),
397        delta_id: delta_id_str.and_then(|s| Uuid::parse_str(&s).ok()),
398        commit_sha,
399        utility,
400        helpful_count,
401        feedback_count,
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_store_and_get_episode() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = Episode::new(
            "Test episode".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("test-project".to_string())
        .with_tags(vec!["rust".to_string()]);

        storage.store_episode(&episode).unwrap();

        let retrieved = storage.get_episode(episode.id).unwrap().unwrap();
        assert_eq!(retrieved.summary, "Test episode");
        assert_eq!(retrieved.project, Some("test-project".to_string()));
    }

    #[test]
    fn test_get_missing_episode_returns_none() {
        let storage = EpisodeStorage::in_memory().unwrap();
        assert!(storage.get_episode(Uuid::nil()).unwrap().is_none());
    }

    #[test]
    fn test_store_episode_replaces_existing() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let mut episode = Episode::new(
            "Original".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        storage.store_episode(&episode).unwrap();

        // Storing the same id again must overwrite, not duplicate.
        episode.summary = "Revised".to_string();
        storage.store_episode(&episode).unwrap();

        assert_eq!(storage.list_episodes(None).unwrap().len(), 1);
        let retrieved = storage.get_episode(episode.id).unwrap().unwrap();
        assert_eq!(retrieved.summary, "Revised");
    }

    #[test]
    fn test_outcome_round_trip() {
        let storage = EpisodeStorage::in_memory().unwrap();

        for outcome in [
            EpisodeOutcome::Success,
            EpisodeOutcome::Partial,
            EpisodeOutcome::Failure,
        ] {
            let expected = outcome_to_str(&outcome);
            let episode = Episode::new("Outcome".to_string(), "test".to_string(), outcome);
            storage.store_episode(&episode).unwrap();

            let retrieved = storage.get_episode(episode.id).unwrap().unwrap();
            assert_eq!(outcome_to_str(&retrieved.outcome), expected);
        }
    }

    #[test]
    fn test_list_episodes() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let ep1 = Episode::new(
            "Episode 1".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("proj-a".to_string());
        let ep2 = Episode::new(
            "Episode 2".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("proj-b".to_string());

        storage.store_episode(&ep1).unwrap();
        storage.store_episode(&ep2).unwrap();

        // List all
        let all = storage.list_episodes(None).unwrap();
        assert_eq!(all.len(), 2);

        // List by project
        let proj_a = storage.list_episodes(Some("proj-a")).unwrap();
        assert_eq!(proj_a.len(), 1);
        assert_eq!(proj_a[0].summary, "Episode 1");
    }

    #[test]
    fn test_feedback() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = Episode::new(
            "Test".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        storage.store_episode(&episode).unwrap();

        // Two helpful votes, one unhelpful.
        storage.record_feedback(episode.id, true).unwrap();
        storage.record_feedback(episode.id, true).unwrap();
        storage.record_feedback(episode.id, false).unwrap();

        let updated = storage.get_episode(episode.id).unwrap().unwrap();
        assert_eq!(updated.helpful_count, 2);
        assert_eq!(updated.feedback_count, 3);

        let feedback = storage.get_feedback(episode.id).unwrap();
        assert_eq!(feedback.len(), 3);
        assert_eq!(feedback.iter().filter(|f| f.helpful).count(), 2);
        assert!(feedback.iter().all(|f| f.episode_id == episode.id));
    }

    #[test]
    fn test_update_utility() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = Episode::new(
            "Test".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        storage.store_episode(&episode).unwrap();

        storage.update_utility(episode.id, 0.85).unwrap();

        let updated = storage.get_episode(episode.id).unwrap().unwrap();
        assert!((updated.utility - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_stats() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let ep1 = Episode::new(
            "Ep1".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        let ep2 = Episode::new(
            "Ep2".to_string(),
            "test".to_string(),
            EpisodeOutcome::Partial,
        );

        storage.store_episode(&ep1).unwrap();
        storage.store_episode(&ep2).unwrap();

        let stats = storage.get_stats(None).unwrap();
        assert_eq!(stats.total_episodes, 2);
    }

    #[test]
    fn test_stats_filtered_by_project() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let a1 = Episode::new("A1".to_string(), "test".to_string(), EpisodeOutcome::Success)
            .with_project("proj-a".to_string());
        let a2 = Episode::new("A2".to_string(), "test".to_string(), EpisodeOutcome::Success)
            .with_project("proj-a".to_string());
        let b1 = Episode::new("B1".to_string(), "test".to_string(), EpisodeOutcome::Success)
            .with_project("proj-b".to_string());

        storage.store_episode(&a1).unwrap();
        storage.store_episode(&a2).unwrap();
        storage.store_episode(&b1).unwrap();

        storage.update_utility(a1.id, 0.2).unwrap();
        storage.update_utility(a2.id, 0.6).unwrap();
        storage.record_feedback(a1.id, true).unwrap();
        storage.record_feedback(a2.id, false).unwrap();
        storage.record_feedback(b1.id, true).unwrap();

        let stats_a = storage.get_stats(Some("proj-a")).unwrap();
        assert_eq!(stats_a.total_episodes, 2);
        assert_eq!(stats_a.total_feedback, 2);
        assert!((stats_a.avg_utility - 0.4).abs() < 1e-9);

        let stats_all = storage.get_stats(None).unwrap();
        assert_eq!(stats_all.total_episodes, 3);
        assert_eq!(stats_all.total_feedback, 3);

        // No matching rows: SUM/AVG are NULL in SQL and must surface as zeros.
        let empty = storage.get_stats(Some("missing")).unwrap();
        assert_eq!(empty.total_episodes, 0);
        assert_eq!(empty.total_feedback, 0);
        assert!(empty.avg_utility.abs() < 1e-12);
    }
}