Skip to main content

sediment/
access.rs

1//! Access tracking for memory decay scoring.
2//!
//! Uses a SQLite sidecar database to track per-item access counts, last-access
//! timestamps, and validation counts, enabling freshness- and frequency-based
//! scoring without modifying LanceDB.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use rusqlite::{Connection, params};
10
11use crate::error::{Result, SedimentError};
12
/// Access history for a single item, as returned by `AccessTracker::get_accesses`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessRecord {
    /// Number of times the item has been accessed.
    pub access_count: u32,
    /// Unix timestamp (seconds, UTC) of the most recent recorded access or validation.
    pub last_accessed_at: i64,
    /// Creation timestamp supplied by the caller when the row was first inserted.
    /// NOTE(review): presumably Unix seconds like `last_accessed_at` — confirm with callers.
    pub created_at: i64,
}
20
/// Combined access and validation data for decay scoring, as returned by
/// `AccessTracker::get_decay_data`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecayData {
    /// Number of times the item has been accessed.
    pub access_count: u32,
    /// Unix timestamp (seconds, UTC) of the most recent recorded access or validation.
    pub last_accessed_at: i64,
    /// Creation timestamp supplied by the caller when the row was first inserted.
    /// NOTE(review): presumably Unix seconds like `last_accessed_at` — confirm with callers.
    pub created_at: i64,
    /// Number of times the item has been validated (replaced/confirmed).
    pub validation_count: u32,
}
29
30/// Tracks item access history in SQLite for decay scoring.
31pub struct AccessTracker {
32    conn: Connection,
33}
34
35impl AccessTracker {
36    /// Open or create the access tracking database.
37    pub fn open(path: &Path) -> Result<Self> {
38        let conn = Connection::open(path).map_err(|e| {
39            SedimentError::Database(format!("Failed to open access database: {}", e))
40        })?;
41
42        if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;") {
43            tracing::warn!("Failed to set SQLite PRAGMAs (access): {}", e);
44        }
45
46        conn.execute_batch(
47            "CREATE TABLE IF NOT EXISTS access_log (
48                item_id TEXT PRIMARY KEY,
49                access_count INTEGER NOT NULL DEFAULT 0,
50                last_accessed_at INTEGER NOT NULL,
51                created_at INTEGER NOT NULL
52            );",
53        )
54        .map_err(|e| {
55            SedimentError::Database(format!("Failed to create access_log table: {}", e))
56        })?;
57
58        // Idempotent schema migration: add validation_count column
59        if let Err(e) = conn.execute_batch(
60            "ALTER TABLE access_log ADD COLUMN validation_count INTEGER NOT NULL DEFAULT 0;",
61        ) {
62            let msg = e.to_string();
63            if !msg.contains("duplicate column") {
64                tracing::warn!("access_log migration unexpected error: {}", msg);
65            }
66        }
67
68        Ok(Self { conn })
69    }
70
71    /// Record an access for an item. If no record exists, creates one with the given created_at.
72    pub fn record_access(&self, item_id: &str, created_at: i64) -> Result<()> {
73        let now = chrono::Utc::now().timestamp();
74        self.conn
75            .execute(
76                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at)
77                 VALUES (?1, 1, ?2, ?3)
78                 ON CONFLICT(item_id) DO UPDATE SET
79                     access_count = access_count + 1,
80                     last_accessed_at = ?2",
81                params![item_id, now, created_at],
82            )
83            .map_err(|e| SedimentError::Database(format!("Failed to record access: {}", e)))?;
84        Ok(())
85    }
86
87    /// Record a validation (replace/confirm) for an item.
88    pub fn record_validation(&self, item_id: &str, created_at: i64) -> Result<()> {
89        let now = chrono::Utc::now().timestamp();
90        self.conn
91            .execute(
92                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at, validation_count)
93                 VALUES (?1, 0, ?2, ?3, 1)
94                 ON CONFLICT(item_id) DO UPDATE SET
95                     validation_count = validation_count + 1,
96                     last_accessed_at = ?2",
97                params![item_id, now, created_at],
98            )
99            .map_err(|e| {
100                SedimentError::Database(format!("Failed to record validation: {}", e))
101            })?;
102        Ok(())
103    }
104
105    /// Get validation count for an item.
106    pub fn get_validation_count(&self, item_id: &str) -> Result<u32> {
107        let count: u32 = self
108            .conn
109            .query_row(
110                "SELECT COALESCE(validation_count, 0) FROM access_log WHERE item_id = ?1",
111                params![item_id],
112                |row| row.get(0),
113            )
114            .unwrap_or(0);
115        Ok(count)
116    }
117
118    /// Get validation counts for multiple items in a single query.
119    pub fn get_validation_counts(&self, item_ids: &[&str]) -> Result<HashMap<String, u32>> {
120        if item_ids.is_empty() {
121            return Ok(HashMap::new());
122        }
123
124        let placeholders: Vec<String> = item_ids
125            .iter()
126            .enumerate()
127            .map(|(i, _)| format!("?{}", i + 1))
128            .collect();
129        let sql = format!(
130            "SELECT item_id, COALESCE(validation_count, 0) FROM access_log WHERE item_id IN ({})",
131            placeholders.join(", ")
132        );
133
134        let mut stmt = self
135            .conn
136            .prepare(&sql)
137            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
138
139        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
140            .iter()
141            .map(|id| id as &dyn rusqlite::types::ToSql)
142            .collect();
143
144        let rows = stmt
145            .query_map(params.as_slice(), |row| {
146                Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?))
147            })
148            .map_err(|e| {
149                SedimentError::Database(format!("Failed to query validation counts: {}", e))
150            })?;
151
152        let mut map = HashMap::new();
153        for row in rows {
154            let (id, count) = row.map_err(|e| {
155                SedimentError::Database(format!("Failed to read validation count: {}", e))
156            })?;
157            map.insert(id, count);
158        }
159
160        Ok(map)
161    }
162
163    /// Get combined access and validation data for decay scoring in a single query.
164    pub fn get_decay_data(&self, item_ids: &[&str]) -> Result<HashMap<String, DecayData>> {
165        if item_ids.is_empty() {
166            return Ok(HashMap::new());
167        }
168
169        let placeholders: Vec<String> = item_ids
170            .iter()
171            .enumerate()
172            .map(|(i, _)| format!("?{}", i + 1))
173            .collect();
174        let sql = format!(
175            "SELECT item_id, access_count, last_accessed_at, created_at, COALESCE(validation_count, 0) FROM access_log WHERE item_id IN ({})",
176            placeholders.join(", ")
177        );
178
179        let mut stmt = self
180            .conn
181            .prepare(&sql)
182            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
183
184        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
185            .iter()
186            .map(|id| id as &dyn rusqlite::types::ToSql)
187            .collect();
188
189        let rows = stmt
190            .query_map(params.as_slice(), |row| {
191                Ok((
192                    row.get::<_, String>(0)?,
193                    DecayData {
194                        access_count: row.get::<_, u32>(1)?,
195                        last_accessed_at: row.get::<_, i64>(2)?,
196                        created_at: row.get::<_, i64>(3)?,
197                        validation_count: row.get::<_, u32>(4)?,
198                    },
199                ))
200            })
201            .map_err(|e| SedimentError::Database(format!("Failed to query decay data: {}", e)))?;
202
203        let mut map = HashMap::new();
204        for row in rows {
205            let (id, data) = row.map_err(|e| {
206                SedimentError::Database(format!("Failed to read decay data: {}", e))
207            })?;
208            map.insert(id, data);
209        }
210
211        Ok(map)
212    }
213
214    /// Get access records for a batch of item IDs.
215    pub fn get_accesses(&self, item_ids: &[&str]) -> Result<HashMap<String, AccessRecord>> {
216        if item_ids.is_empty() {
217            return Ok(HashMap::new());
218        }
219
220        let placeholders: Vec<String> = item_ids
221            .iter()
222            .enumerate()
223            .map(|(i, _)| format!("?{}", i + 1))
224            .collect();
225        let sql = format!(
226            "SELECT item_id, access_count, last_accessed_at, created_at FROM access_log WHERE item_id IN ({})",
227            placeholders.join(", ")
228        );
229
230        let mut stmt = self
231            .conn
232            .prepare(&sql)
233            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
234
235        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
236            .iter()
237            .map(|id| id as &dyn rusqlite::types::ToSql)
238            .collect();
239
240        let rows = stmt
241            .query_map(params.as_slice(), |row| {
242                Ok((
243                    row.get::<_, String>(0)?,
244                    AccessRecord {
245                        access_count: row.get::<_, u32>(1)?,
246                        last_accessed_at: row.get::<_, i64>(2)?,
247                        created_at: row.get::<_, i64>(3)?,
248                    },
249                ))
250            })
251            .map_err(|e| SedimentError::Database(format!("Failed to query accesses: {}", e)))?;
252
253        let mut map = HashMap::new();
254        for row in rows {
255            let (id, record) = row.map_err(|e| {
256                SedimentError::Database(format!("Failed to read access record: {}", e))
257            })?;
258            map.insert(id, record);
259        }
260
261        Ok(map)
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_open_creates_table() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        // Should not error on second open (CREATE TABLE and migration are idempotent)
        drop(tracker);
        let _tracker2 = AccessTracker::open(tmp.path()).unwrap();
    }

    #[test]
    fn test_record_and_get_access() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item2", created).unwrap();

        let records = tracker.get_accesses(&["item1", "item2", "item3"]).unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records["item1"].access_count, 2);
        assert_eq!(records["item1"].created_at, created);
        assert_eq!(records["item2"].access_count, 1);
        assert!(!records.contains_key("item3"));
    }

    #[test]
    fn test_record_access_keeps_original_created_at() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        tracker.record_access("item1", 100).unwrap();
        // A later created_at must not overwrite the one stored on first insert.
        tracker.record_access("item1", 200).unwrap();

        let records = tracker.get_accesses(&["item1"]).unwrap();
        assert_eq!(records["item1"].created_at, 100);
        assert_eq!(records["item1"].access_count, 2);
        assert!(records["item1"].last_accessed_at > 0);
    }

    #[test]
    fn test_get_accesses_empty() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        let records = tracker.get_accesses(&[]).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_batch_lookups_empty_input() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        assert!(tracker.get_validation_counts(&[]).unwrap().is_empty());
        assert!(tracker.get_decay_data(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_record_validation_on_new_item() {
        // Fix #5: validation should be recorded on the new (replacement) item
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        // Record validation on a new item (not pre-existing)
        tracker.record_validation("new-item", created).unwrap();
        tracker.record_validation("new-item", created).unwrap();

        let count = tracker.get_validation_count("new-item").unwrap();
        assert_eq!(
            count, 2,
            "Validation count should be 2 after two record_validation calls"
        );

        // Item that was never validated should have 0
        let count = tracker.get_validation_count("other-item").unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_get_validation_counts() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_validation("item-a", created).unwrap();
        tracker.record_validation("item-a", created).unwrap();
        tracker.record_validation("item-b", created).unwrap();

        let counts = tracker
            .get_validation_counts(&["item-a", "item-b", "item-c"])
            .unwrap();

        // Batch results should match individual lookups
        assert_eq!(counts.get("item-a").copied().unwrap_or(0), 2);
        assert_eq!(counts.get("item-b").copied().unwrap_or(0), 1);
        // item-c has no record, should not be in map
        assert!(!counts.contains_key("item-c"));

        // Verify consistency with single-item method
        assert_eq!(
            counts.get("item-a").copied().unwrap_or(0),
            tracker.get_validation_count("item-a").unwrap()
        );
    }

    #[test]
    fn test_get_decay_data_combines_access_and_validation() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_access("item-a", created).unwrap();
        tracker.record_access("item-a", created).unwrap();
        tracker.record_validation("item-a", created).unwrap();
        // Validation-only item: its row is created with zero accesses.
        tracker.record_validation("item-b", created).unwrap();

        let data = tracker
            .get_decay_data(&["item-a", "item-b", "item-c"])
            .unwrap();

        assert_eq!(data.len(), 2);
        assert_eq!(data["item-a"].access_count, 2);
        assert_eq!(data["item-a"].validation_count, 1);
        assert_eq!(data["item-a"].created_at, created);
        assert_eq!(data["item-b"].access_count, 0);
        assert_eq!(data["item-b"].validation_count, 1);
        assert_eq!(data["item-b"].created_at, created);
        assert!(data["item-b"].last_accessed_at > 0);
        assert!(!data.contains_key("item-c"));
    }
}