1use crate::error::{MemoryError, MemoryResult};
4use crate::types::{Episode, EpisodeOutcome, Feedback};
5use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection, OptionalExtension};
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use uuid::Uuid;
10
11pub struct EpisodeStorage {
13 conn: Arc<Mutex<Connection>>,
14}
15
16impl EpisodeStorage {
17 pub fn open(path: &Path) -> MemoryResult<Self> {
19 let conn = Connection::open(path)?;
20 let storage = Self {
21 conn: Arc::new(Mutex::new(conn)),
22 };
23 storage.init_schema()?;
24 Ok(storage)
25 }
26
27 pub fn in_memory() -> MemoryResult<Self> {
29 let conn = Connection::open_in_memory()?;
30 let storage = Self {
31 conn: Arc::new(Mutex::new(conn)),
32 };
33 storage.init_schema()?;
34 Ok(storage)
35 }
36
37 fn init_schema(&self) -> MemoryResult<()> {
39 let conn = self
40 .conn
41 .lock()
42 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
43
44 conn.execute_batch(
45 r#"
46 CREATE TABLE IF NOT EXISTS episodes (
47 id TEXT PRIMARY KEY,
48 created_at TEXT NOT NULL,
49 project TEXT,
50 summary TEXT NOT NULL,
51 task_type TEXT NOT NULL,
52 outcome TEXT NOT NULL,
53 files_modified TEXT NOT NULL,
54 errors_resolved TEXT NOT NULL,
55 tags TEXT NOT NULL,
56 intent_id TEXT,
57 delta_id TEXT,
58 commit_sha TEXT,
59 utility REAL NOT NULL DEFAULT 0.5,
60 helpful_count INTEGER NOT NULL DEFAULT 0,
61 feedback_count INTEGER NOT NULL DEFAULT 0
62 );
63
64 CREATE TABLE IF NOT EXISTS feedback (
65 id INTEGER PRIMARY KEY AUTOINCREMENT,
66 episode_id TEXT NOT NULL,
67 timestamp TEXT NOT NULL,
68 helpful INTEGER NOT NULL,
69 FOREIGN KEY (episode_id) REFERENCES episodes(id)
70 );
71
72 CREATE INDEX IF NOT EXISTS idx_episodes_project ON episodes(project);
73 CREATE INDEX IF NOT EXISTS idx_episodes_task_type ON episodes(task_type);
74 CREATE INDEX IF NOT EXISTS idx_episodes_created_at ON episodes(created_at);
75 CREATE INDEX IF NOT EXISTS idx_feedback_episode_id ON feedback(episode_id);
76 "#,
77 )?;
78
79 Ok(())
80 }
81
82 pub fn store_episode(&self, episode: &Episode) -> MemoryResult<()> {
84 let conn = self
85 .conn
86 .lock()
87 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
88
89 conn.execute(
90 r#"
91 INSERT OR REPLACE INTO episodes
92 (id, created_at, project, summary, task_type, outcome,
93 files_modified, errors_resolved, tags, intent_id, delta_id,
94 commit_sha, utility, helpful_count, feedback_count)
95 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)
96 "#,
97 params![
98 episode.id.to_string(),
99 episode.created_at.to_rfc3339(),
100 episode.project,
101 episode.summary,
102 episode.task_type,
103 outcome_to_str(&episode.outcome),
104 serde_json::to_string(&episode.files_modified)?,
105 serde_json::to_string(&episode.errors_resolved)?,
106 serde_json::to_string(&episode.tags)?,
107 episode.intent_id.map(|id| id.to_string()),
108 episode.delta_id.map(|id| id.to_string()),
109 episode.commit_sha,
110 episode.utility,
111 episode.helpful_count,
112 episode.feedback_count,
113 ],
114 )?;
115
116 Ok(())
117 }
118
119 pub fn get_episode(&self, id: Uuid) -> MemoryResult<Option<Episode>> {
121 let conn = self
122 .conn
123 .lock()
124 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
125
126 let mut stmt = conn.prepare(
127 r#"
128 SELECT id, created_at, project, summary, task_type, outcome,
129 files_modified, errors_resolved, tags, intent_id, delta_id,
130 commit_sha, utility, helpful_count, feedback_count
131 FROM episodes WHERE id = ?1
132 "#,
133 )?;
134
135 let result = stmt
136 .query_row([id.to_string()], |row| Ok(row_to_episode_raw(row)))
137 .optional()?;
138
139 match result {
140 Some(ep) => Ok(Some(ep)),
141 None => Ok(None),
142 }
143 }
144
145 pub fn list_episodes(&self, project: Option<&str>) -> MemoryResult<Vec<Episode>> {
147 let conn = self
148 .conn
149 .lock()
150 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
151
152 let mut episodes = Vec::new();
153
154 if let Some(proj) = project {
155 let mut stmt = conn.prepare(
156 r#"
157 SELECT id, created_at, project, summary, task_type, outcome,
158 files_modified, errors_resolved, tags, intent_id, delta_id,
159 commit_sha, utility, helpful_count, feedback_count
160 FROM episodes WHERE project = ?1
161 ORDER BY created_at DESC
162 "#,
163 )?;
164
165 let rows = stmt.query_map([proj], |row| Ok(row_to_episode_raw(row)))?;
166
167 for row in rows {
168 episodes.push(row?);
169 }
170 } else {
171 let mut stmt = conn.prepare(
172 r#"
173 SELECT id, created_at, project, summary, task_type, outcome,
174 files_modified, errors_resolved, tags, intent_id, delta_id,
175 commit_sha, utility, helpful_count, feedback_count
176 FROM episodes ORDER BY created_at DESC
177 "#,
178 )?;
179
180 let rows = stmt.query_map([], |row| Ok(row_to_episode_raw(row)))?;
181
182 for row in rows {
183 episodes.push(row?);
184 }
185 }
186
187 Ok(episodes)
188 }
189
190 pub fn update_utility(&self, id: Uuid, utility: f64) -> MemoryResult<()> {
192 let conn = self
193 .conn
194 .lock()
195 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
196
197 conn.execute(
198 "UPDATE episodes SET utility = ?1 WHERE id = ?2",
199 params![utility, id.to_string()],
200 )?;
201
202 Ok(())
203 }
204
205 pub fn record_feedback(&self, episode_id: Uuid, helpful: bool) -> MemoryResult<()> {
207 let conn = self
208 .conn
209 .lock()
210 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
211
212 conn.execute(
214 r#"
215 INSERT INTO feedback (episode_id, timestamp, helpful)
216 VALUES (?1, ?2, ?3)
217 "#,
218 params![
219 episode_id.to_string(),
220 Utc::now().to_rfc3339(),
221 helpful as i32,
222 ],
223 )?;
224
225 if helpful {
227 conn.execute(
228 r#"
229 UPDATE episodes
230 SET helpful_count = helpful_count + 1,
231 feedback_count = feedback_count + 1
232 WHERE id = ?1
233 "#,
234 [episode_id.to_string()],
235 )?;
236 } else {
237 conn.execute(
238 "UPDATE episodes SET feedback_count = feedback_count + 1 WHERE id = ?1",
239 [episode_id.to_string()],
240 )?;
241 }
242
243 Ok(())
244 }
245
246 pub fn get_feedback(&self, episode_id: Uuid) -> MemoryResult<Vec<Feedback>> {
248 let conn = self
249 .conn
250 .lock()
251 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
252
253 let mut stmt = conn
254 .prepare("SELECT episode_id, timestamp, helpful FROM feedback WHERE episode_id = ?1")?;
255
256 let mut feedback = Vec::new();
257 let rows = stmt.query_map([episode_id.to_string()], |row| {
258 let episode_id_str: String = row.get(0)?;
259 let timestamp_str: String = row.get(1)?;
260 let helpful: i32 = row.get(2)?;
261
262 Ok(Feedback {
263 episode_id: Uuid::parse_str(&episode_id_str).unwrap_or(Uuid::nil()),
264 timestamp: DateTime::parse_from_rfc3339(×tamp_str)
265 .map(|dt| dt.with_timezone(&Utc))
266 .unwrap_or_else(|_| Utc::now()),
267 helpful: helpful != 0,
268 })
269 })?;
270
271 for row in rows {
272 feedback.push(row?);
273 }
274
275 Ok(feedback)
276 }
277
278 pub fn get_all_episode_ids(&self) -> MemoryResult<Vec<Uuid>> {
280 let conn = self
281 .conn
282 .lock()
283 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
284
285 let mut stmt = conn.prepare("SELECT id FROM episodes")?;
286 let mut ids = Vec::new();
287
288 let rows = stmt.query_map([], |row| {
289 let id_str: String = row.get(0)?;
290 Ok(Uuid::parse_str(&id_str).unwrap_or(Uuid::nil()))
291 })?;
292
293 for row in rows {
294 ids.push(row?);
295 }
296
297 Ok(ids)
298 }
299
300 pub fn get_stats(&self, project: Option<&str>) -> MemoryResult<MemoryStats> {
302 let conn = self
303 .conn
304 .lock()
305 .map_err(|e| MemoryError::Storage(format!("Failed to acquire lock: {}", e)))?;
306
307 let (total_episodes, total_feedback, avg_utility) = if let Some(proj) = project {
308 let mut stmt = conn.prepare(
309 r#"
310 SELECT COUNT(*), SUM(feedback_count), AVG(utility)
311 FROM episodes WHERE project = ?1
312 "#,
313 )?;
314 stmt.query_row([proj], |row| {
315 Ok((
316 row.get::<_, i64>(0)? as usize,
317 row.get::<_, Option<i64>>(1)?.unwrap_or(0) as usize,
318 row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
319 ))
320 })?
321 } else {
322 let mut stmt =
323 conn.prepare("SELECT COUNT(*), SUM(feedback_count), AVG(utility) FROM episodes")?;
324 stmt.query_row([], |row| {
325 Ok((
326 row.get::<_, i64>(0)? as usize,
327 row.get::<_, Option<i64>>(1)?.unwrap_or(0) as usize,
328 row.get::<_, Option<f64>>(2)?.unwrap_or(0.0),
329 ))
330 })?
331 };
332
333 Ok(MemoryStats {
334 total_episodes,
335 total_feedback,
336 avg_utility,
337 })
338 }
339}
340
/// Aggregate statistics over stored episodes, as produced by `EpisodeStorage::get_stats`.
///
/// `Default` is the all-zero value, which is also what an empty store reports.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct MemoryStats {
    /// Number of episodes counted.
    pub total_episodes: usize,
    /// Sum of `feedback_count` across those episodes (0 when there are none).
    pub total_feedback: usize,
    /// Mean utility across those episodes (0.0 when there are none).
    pub avg_utility: f64,
}
348
349fn outcome_to_str(outcome: &EpisodeOutcome) -> &'static str {
350 match outcome {
351 EpisodeOutcome::Success => "success",
352 EpisodeOutcome::Partial => "partial",
353 EpisodeOutcome::Failure => "failure",
354 }
355}
356
357fn str_to_outcome(s: &str) -> EpisodeOutcome {
358 match s {
359 "success" => EpisodeOutcome::Success,
360 "partial" => EpisodeOutcome::Partial,
361 "failure" => EpisodeOutcome::Failure,
362 _ => EpisodeOutcome::Partial,
363 }
364}
365
366fn row_to_episode_raw(row: &rusqlite::Row) -> Episode {
368 let id_str: String = row.get(0).unwrap_or_default();
369 let created_at_str: String = row.get(1).unwrap_or_default();
370 let project: Option<String> = row.get(2).ok();
371 let summary: String = row.get(3).unwrap_or_default();
372 let task_type: String = row.get(4).unwrap_or_default();
373 let outcome_str: String = row.get(5).unwrap_or_default();
374 let files_json: String = row.get(6).unwrap_or_default();
375 let errors_json: String = row.get(7).unwrap_or_default();
376 let tags_json: String = row.get(8).unwrap_or_default();
377 let intent_id_str: Option<String> = row.get(9).ok();
378 let delta_id_str: Option<String> = row.get(10).ok();
379 let commit_sha: Option<String> = row.get(11).ok();
380 let utility: f64 = row.get(12).unwrap_or(0.5);
381 let helpful_count: u32 = row.get(13).unwrap_or(0);
382 let feedback_count: u32 = row.get(14).unwrap_or(0);
383
384 Episode {
385 id: Uuid::parse_str(&id_str).unwrap_or(Uuid::nil()),
386 created_at: DateTime::parse_from_rfc3339(&created_at_str)
387 .map(|dt| dt.with_timezone(&Utc))
388 .unwrap_or_else(|_| Utc::now()),
389 project,
390 summary,
391 task_type,
392 outcome: str_to_outcome(&outcome_str),
393 files_modified: serde_json::from_str(&files_json).unwrap_or_default(),
394 errors_resolved: serde_json::from_str(&errors_json).unwrap_or_default(),
395 tags: serde_json::from_str(&tags_json).unwrap_or_default(),
396 intent_id: intent_id_str.and_then(|s| Uuid::parse_str(&s).ok()),
397 delta_id: delta_id_str.and_then(|s| Uuid::parse_str(&s).ok()),
398 commit_sha,
399 utility,
400 helpful_count,
401 feedback_count,
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an episode with the given summary, task type "test" and a successful outcome.
    fn make_episode(summary: &str) -> Episode {
        Episode::new(
            summary.to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        )
    }

    #[test]
    fn test_store_and_get_episode() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = Episode::new(
            "Test episode".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("test-project".to_string())
        .with_tags(vec!["rust".to_string()]);

        storage.store_episode(&episode).unwrap();

        let retrieved = storage.get_episode(episode.id).unwrap().unwrap();
        assert_eq!(retrieved.summary, "Test episode");
        assert_eq!(retrieved.project, Some("test-project".to_string()));
    }

    #[test]
    fn test_get_missing_episode_returns_none() {
        let storage = EpisodeStorage::in_memory().unwrap();
        assert!(storage.get_episode(Uuid::nil()).unwrap().is_none());
    }

    #[test]
    fn test_list_episodes() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let ep1 = Episode::new(
            "Episode 1".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("proj-a".to_string());
        let ep2 = Episode::new(
            "Episode 2".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("proj-b".to_string());

        storage.store_episode(&ep1).unwrap();
        storage.store_episode(&ep2).unwrap();

        let all = storage.list_episodes(None).unwrap();
        assert_eq!(all.len(), 2);

        let proj_a = storage.list_episodes(Some("proj-a")).unwrap();
        assert_eq!(proj_a.len(), 1);
        assert_eq!(proj_a[0].summary, "Episode 1");
    }

    #[test]
    fn test_feedback() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = make_episode("Test");
        storage.store_episode(&episode).unwrap();

        storage.record_feedback(episode.id, true).unwrap();
        storage.record_feedback(episode.id, true).unwrap();
        storage.record_feedback(episode.id, false).unwrap();

        let updated = storage.get_episode(episode.id).unwrap().unwrap();
        assert_eq!(updated.helpful_count, 2);
        assert_eq!(updated.feedback_count, 3);

        let feedback = storage.get_feedback(episode.id).unwrap();
        assert_eq!(feedback.len(), 3);
        let flags: Vec<bool> = feedback.iter().map(|f| f.helpful).collect();
        assert_eq!(flags, vec![true, true, false]);
        assert!(feedback.iter().all(|f| f.episode_id == episode.id));
    }

    #[test]
    fn test_update_utility() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let episode = make_episode("Test");
        storage.store_episode(&episode).unwrap();

        storage.update_utility(episode.id, 0.85).unwrap();

        let updated = storage.get_episode(episode.id).unwrap().unwrap();
        assert!((updated.utility - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_stats() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let ep1 = make_episode("Ep1");
        let ep2 = Episode::new(
            "Ep2".to_string(),
            "test".to_string(),
            EpisodeOutcome::Partial,
        );

        storage.store_episode(&ep1).unwrap();
        storage.store_episode(&ep2).unwrap();

        let stats = storage.get_stats(None).unwrap();
        assert_eq!(stats.total_episodes, 2);
    }

    #[test]
    fn test_stats_empty_store() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let stats = storage.get_stats(None).unwrap();
        assert_eq!(stats.total_episodes, 0);
        assert_eq!(stats.total_feedback, 0);
        assert!(stats.avg_utility.abs() < 1e-9);
    }

    #[test]
    fn test_stats_filtered_by_project() {
        let storage = EpisodeStorage::in_memory().unwrap();

        let a1 = make_episode("A1").with_project("proj-a".to_string());
        let a2 = make_episode("A2").with_project("proj-a".to_string());
        let b1 = make_episode("B1").with_project("proj-b".to_string());
        for ep in [&a1, &a2, &b1] {
            storage.store_episode(ep).unwrap();
        }
        storage.update_utility(a1.id, 0.2).unwrap();
        storage.update_utility(a2.id, 0.6).unwrap();
        storage.record_feedback(a1.id, true).unwrap();

        let stats = storage.get_stats(Some("proj-a")).unwrap();
        assert_eq!(stats.total_episodes, 2);
        assert_eq!(stats.total_feedback, 1);
        assert!((stats.avg_utility - 0.4).abs() < 1e-9);

        assert_eq!(storage.get_stats(None).unwrap().total_episodes, 3);
    }
}