Skip to main content

synwire_agent/experience/
pool.rs

1//! SQLite-backed experience pool with global + project-local tiers.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Mutex;
6
7/// Error type for experience pool operations.
8#[non_exhaustive]
9#[derive(Debug, thiserror::Error)]
10pub enum ExperienceError {
11    /// `SQLite` error.
12    #[error("SQLite error: {0}")]
13    Sqlite(String),
14    /// I/O error.
15    #[error("I/O error: {0}")]
16    Io(String),
17}
18
19impl From<rusqlite::Error> for ExperienceError {
20    fn from(e: rusqlite::Error) -> Self {
21        Self::Sqlite(e.to_string())
22    }
23}
24
25impl From<std::io::Error> for ExperienceError {
26    fn from(e: std::io::Error) -> Self {
27        Self::Io(e.to_string())
28    }
29}
30
/// A single experience entry linking a task description to modified files.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal; use [`ExperienceEntry::new`] instead.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExperienceEntry {
    /// Task description (from the agent prompt).
    pub task_description: String,
    /// Files modified in this edit session.
    pub files_modified: Vec<String>,
    /// Timestamp (ISO 8601).
    pub recorded_at: String,
}

impl ExperienceEntry {
    /// Build an entry from its parts.
    ///
    /// `files_modified` accepts any iterable of string-like values. `recorded_at` is
    /// stored verbatim and is not validated; callers should pass an ISO 8601 timestamp.
    #[must_use]
    pub fn new(
        task_description: impl Into<String>,
        files_modified: impl IntoIterator<Item = impl Into<String>>,
        recorded_at: impl Into<String>,
    ) -> Self {
        Self {
            task_description: task_description.into(),
            files_modified: files_modified.into_iter().map(Into::into).collect(),
            recorded_at: recorded_at.into(),
        }
    }
}
41
/// Stop words to skip when tokenising a task description for keyword matching.
///
/// Entries of three characters or fewer are currently redundant, because [`keywords`]
/// already drops such words; they are kept so the list stays complete if that length
/// threshold is ever lowered.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "with", "that", "this", "from", "are", "was", "were", "have", "has",
    "been", "will", "would", "could", "should", "into", "onto", "over", "under", "also", "then",
    "than", "when",
];

/// Split a task description into lowercase search keywords.
///
/// Words are maximal runs of alphanumeric characters. A word is kept only when it has
/// more than three characters and its lowercase form is not in [`STOP_WORDS`]. Length
/// is counted in `char`s rather than bytes, so non-ASCII words are judged by how many
/// characters they have.
///
/// Each keyword is returned once, in order of first appearance. Repeating a word in the
/// description therefore neither adds weight to it nor causes a redundant lookup.
fn keywords(description: &str) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    description
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| w.chars().count() > 3)
        .map(str::to_lowercase)
        .filter(|w| !STOP_WORDS.contains(&w.as_str()))
        // `insert` returns false for a keyword already emitted, dropping duplicates.
        .filter(|w| seen.insert(w.clone()))
        .collect()
}
57
58fn init_schema(conn: &rusqlite::Connection) -> Result<(), ExperienceError> {
59    conn.execute_batch(
60        "PRAGMA journal_mode=WAL;
61         PRAGMA synchronous=NORMAL;
62         CREATE TABLE IF NOT EXISTS experiences (
63             id INTEGER PRIMARY KEY AUTOINCREMENT,
64             task_description TEXT NOT NULL,
65             file_path TEXT NOT NULL,
66             recorded_at TEXT NOT NULL
67         );
68         CREATE INDEX IF NOT EXISTS idx_task ON experiences(task_description);
69         CREATE INDEX IF NOT EXISTS idx_file ON experiences(file_path);",
70    )?;
71    Ok(())
72}
73
74/// SQLite-backed experience pool.
75///
76/// Records task-to-file associations and supports keyword-based file retrieval.
77pub struct ExperiencePool {
78    conn: Mutex<rusqlite::Connection>,
79}
80
81impl ExperiencePool {
82    /// Open or create the experience database at the given path.
83    pub fn open(path: &Path) -> Result<Self, ExperienceError> {
84        if let Some(parent) = path.parent() {
85            std::fs::create_dir_all(parent)?;
86        }
87        let conn = rusqlite::Connection::open(path)?;
88        init_schema(&conn)?;
89        Ok(Self {
90            conn: Mutex::new(conn),
91        })
92    }
93
94    /// Record an edit event, inserting one row per file in [`ExperienceEntry::files_modified`].
95    #[allow(clippy::significant_drop_tightening)]
96    pub fn record(&self, entry: &ExperienceEntry) -> Result<(), ExperienceError> {
97        let conn = self
98            .conn
99            .lock()
100            .map_err(|e| ExperienceError::Sqlite(e.to_string()))?;
101        for file in &entry.files_modified {
102            let _rows_inserted = conn.execute(
103                "INSERT INTO experiences (task_description, file_path, recorded_at) VALUES (?1, ?2, ?3)",
104                rusqlite::params![entry.task_description, file, entry.recorded_at],
105            )?;
106        }
107        Ok(())
108    }
109
110    /// Query relevant files for a task description.
111    ///
112    /// Uses keyword matching against stored task descriptions.
113    /// Returns `(file_path, count)` pairs sorted by frequency descending.
114    ///
115    /// # Example
116    ///
117    /// ```no_run
118    /// # use synwire_agent::experience::{ExperiencePool, ExperienceEntry};
119    /// # let pool = ExperiencePool::open(std::path::Path::new("/tmp/exp.db")).unwrap();
120    /// let files = pool.query_files("fix authentication bug").unwrap();
121    /// for (path, count) in files {
122    ///     println!("{path}: {count}");
123    /// }
124    /// ```
125    #[allow(clippy::significant_drop_tightening)]
126    pub fn query_files(&self, description: &str) -> Result<Vec<(String, u32)>, ExperienceError> {
127        let words = keywords(description);
128        if words.is_empty() {
129            return Ok(Vec::new());
130        }
131
132        let conn = self
133            .conn
134            .lock()
135            .map_err(|e| ExperienceError::Sqlite(e.to_string()))?;
136        let mut totals: HashMap<String, u32> = HashMap::new();
137
138        for keyword in &words {
139            let pattern = format!("%{keyword}%");
140            let mut stmt = conn.prepare(
141                "SELECT file_path, COUNT(*) as cnt FROM experiences
142                 WHERE task_description LIKE ?1
143                 GROUP BY file_path",
144            )?;
145            let rows = stmt.query_map(rusqlite::params![pattern], |row| {
146                Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?))
147            })?;
148            for row in rows {
149                let (file, cnt) = row?;
150                *totals.entry(file).or_insert(0) += cnt;
151            }
152        }
153
154        let mut result: Vec<(String, u32)> = totals.into_iter().collect();
155        result.sort_by(|a, b| b.1.cmp(&a.1));
156        Ok(result)
157    }
158}
159
160/// Two-tier experience pool: project-local first, global fallback.
161///
162/// Queries local pool first; falls back to global when local yields no results.
163/// Records to both tiers on every edit event.
164pub struct TieredExperiencePool {
165    local: ExperiencePool,
166    global: ExperiencePool,
167}
168
169impl TieredExperiencePool {
170    /// Open both local and global experience pool tiers.
171    pub fn open(local_path: &Path, global_path: &Path) -> Result<Self, ExperienceError> {
172        Ok(Self {
173            local: ExperiencePool::open(local_path)?,
174            global: ExperiencePool::open(global_path)?,
175        })
176    }
177
178    /// Query with local-first fallback to global.
179    ///
180    /// Returns local results if any exist; otherwise returns global results.
181    pub fn query_files(&self, description: &str) -> Result<Vec<(String, u32)>, ExperienceError> {
182        let local = self.local.query_files(description)?;
183        if !local.is_empty() {
184            return Ok(local);
185        }
186        self.global.query_files(description)
187    }
188
189    /// Record to both local and global pools.
190    pub fn record(&self, entry: &ExperienceEntry) -> Result<(), ExperienceError> {
191        self.local.record(entry)?;
192        self.global.record(entry)?;
193        Ok(())
194    }
195}
196
197/// Record an edit completion event to the experience pool.
198///
199/// Called by the agent runtime after a successful edit directive completes.
200///
201/// # Errors
202///
203/// Returns [`ExperienceError`] if the database write fails.
204pub fn record_edit_completion(
205    pool: &ExperiencePool,
206    description: &str,
207    files: &[&str],
208) -> Result<(), ExperienceError> {
209    let entry = ExperienceEntry {
210        task_description: description.to_owned(),
211        files_modified: files.iter().map(|s| (*s).to_string()).collect(),
212        recorded_at: chrono::Utc::now().to_rfc3339(),
213    };
214    pool.record(&entry)
215}
216
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Build an entry with a fixed timestamp for deterministic fixtures.
    fn sample_entry(description: &str, files: &[&str]) -> ExperienceEntry {
        ExperienceEntry {
            task_description: description.to_owned(),
            files_modified: files.iter().map(|f| (*f).to_owned()).collect(),
            recorded_at: "2024-01-01T00:00:00Z".to_owned(),
        }
    }

    #[test]
    fn experience_pool_cross_session() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("exp.db");

        // Session 1: record, then drop the pool so its connection is closed.
        {
            let pool = ExperiencePool::open(&db_path).unwrap();
            pool.record(&sample_entry(
                "fix authentication bug",
                &["src/auth.rs", "src/middleware.rs"],
            ))
            .unwrap();
        }

        // Session 2: a fresh connection must see what session 1 wrote.
        let pool = ExperiencePool::open(&db_path).unwrap();
        let files = pool.query_files("authentication").unwrap();
        assert_eq!(files.len(), 2);
        assert!(files.iter().any(|(f, _)| f == "src/auth.rs"));
        assert!(files.iter().any(|(f, _)| f == "src/middleware.rs"));
    }

    #[test]
    fn query_scores_sum_across_keywords() {
        let dir = tempfile::tempdir().unwrap();
        let pool = ExperiencePool::open(dir.path().join("exp.db").as_path()).unwrap();
        pool.record(&sample_entry("refactor parser logic", &["src/parser.rs"]))
            .unwrap();
        pool.record(&sample_entry(
            "parser error messages",
            &["src/parser.rs", "src/errors.rs"],
        ))
        .unwrap();

        // parser.rs: "parser" matches 2 rows + "logic" matches 1 row = 3; errors.rs: 1.
        let files = pool.query_files("parser logic").unwrap();
        assert_eq!(
            files,
            vec![
                ("src/parser.rs".to_owned(), 3),
                ("src/errors.rs".to_owned(), 1),
            ]
        );
    }

    #[test]
    fn query_without_keywords_returns_nothing() {
        let dir = tempfile::tempdir().unwrap();
        let pool = ExperiencePool::open(dir.path().join("exp.db").as_path()).unwrap();
        pool.record(&sample_entry("the and for", &["src/x.rs"])).unwrap();

        assert!(pool.query_files("the and for").unwrap().is_empty());
        assert!(pool.query_files("").unwrap().is_empty());
    }

    #[test]
    fn tiered_pool_falls_back_to_global() {
        let dir = tempfile::tempdir().unwrap();
        let local_path = dir.path().join("local.db");
        let global_path = dir.path().join("global.db");
        let tiered = TieredExperiencePool::open(&local_path, &global_path).unwrap();

        // Record only to global via a separate connection to the global database.
        let global = ExperiencePool::open(&global_path).unwrap();
        global
            .record(&sample_entry("network timeout handling", &["src/network.rs"]))
            .unwrap();

        // Local yields nothing; global fallback should fire.
        let files = tiered.query_files("network timeout").unwrap();
        assert!(files.iter().any(|(f, _)| f == "src/network.rs"));
    }

    #[test]
    fn tiered_pool_prefers_local_and_records_both() {
        let dir = tempfile::tempdir().unwrap();
        let local_path = dir.path().join("local.db");
        let global_path = dir.path().join("global.db");
        let tiered = TieredExperiencePool::open(&local_path, &global_path).unwrap();

        let global = ExperiencePool::open(&global_path).unwrap();
        global
            .record(&sample_entry("cache eviction policy", &["src/global_cache.rs"]))
            .unwrap();
        tiered
            .record(&sample_entry("cache eviction policy", &["src/local_cache.rs"]))
            .unwrap();

        // Local has a match, so the global-only file must not appear.
        let files = tiered.query_files("cache eviction").unwrap();
        assert_eq!(files, vec![("src/local_cache.rs".to_owned(), 2)]);

        // The tiered record also reached the global tier.
        let global_files = global.query_files("cache eviction").unwrap();
        assert!(global_files.iter().any(|(f, _)| f == "src/local_cache.rs"));
    }

    #[test]
    fn record_edit_completion_helper() {
        let dir = tempfile::tempdir().unwrap();
        let pool = ExperiencePool::open(dir.path().join("exp.db").as_path()).unwrap();
        record_edit_completion(&pool, "refactor parser logic", &["src/parser.rs"]).unwrap();
        let files = pool.query_files("parser logic").unwrap();
        assert!(files.iter().any(|(f, _)| f == "src/parser.rs"));
    }
}