Skip to main content

sediment/
access.rs

//! Access tracking for memory decay scoring.
//!
//! Uses a SQLite sidecar database to track per-item access counts, access
//! timestamps, and validation counts, enabling freshness- and frequency-based
//! scoring without modifying LanceDB.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use rusqlite::{Connection, params};
10
11use crate::error::{Result, SedimentError};
12
/// Combined access and validation data for decay scoring.
///
/// One value corresponds to one row of the `access_log` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecayData {
    /// Number of recorded accesses (validations alone do not increment this).
    pub access_count: u32,
    /// Unix timestamp (seconds) of the most recent access or validation.
    pub last_accessed_at: i64,
    /// Creation time supplied by the caller when the row was first inserted;
    /// presumably Unix seconds like `last_accessed_at` — confirm with callers.
    pub created_at: i64,
    /// Number of recorded validations (replace/confirm events).
    pub validation_count: u32,
}
21
22/// Tracks item access history in SQLite for decay scoring.
23pub struct AccessTracker {
24    conn: Connection,
25}
26
27impl AccessTracker {
28    /// Open or create the access tracking database.
29    pub fn open(path: &Path) -> Result<Self> {
30        let conn = Connection::open(path).map_err(|e| {
31            SedimentError::Database(format!("Failed to open access database: {}", e))
32        })?;
33
34        if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;") {
35            tracing::warn!("Failed to set SQLite PRAGMAs (access): {}", e);
36        }
37
38        conn.execute_batch(
39            "CREATE TABLE IF NOT EXISTS access_log (
40                item_id TEXT PRIMARY KEY,
41                access_count INTEGER NOT NULL DEFAULT 0,
42                last_accessed_at INTEGER NOT NULL,
43                created_at INTEGER NOT NULL
44            );",
45        )
46        .map_err(|e| {
47            SedimentError::Database(format!("Failed to create access_log table: {}", e))
48        })?;
49
50        // Idempotent schema migration: add validation_count column
51        if let Err(e) = conn.execute_batch(
52            "ALTER TABLE access_log ADD COLUMN validation_count INTEGER NOT NULL DEFAULT 0;",
53        ) {
54            let msg = e.to_string();
55            if !msg.contains("duplicate column") {
56                tracing::warn!("access_log migration unexpected error: {}", msg);
57            }
58        }
59
60        Ok(Self { conn })
61    }
62
63    /// Record an access for an item. If no record exists, creates one with the given created_at.
64    pub fn record_access(&self, item_id: &str, created_at: i64) -> Result<()> {
65        let now = chrono::Utc::now().timestamp();
66        self.conn
67            .execute(
68                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at)
69                 VALUES (?1, 1, ?2, ?3)
70                 ON CONFLICT(item_id) DO UPDATE SET
71                     access_count = access_count + 1,
72                     last_accessed_at = ?2",
73                params![item_id, now, created_at],
74            )
75            .map_err(|e| SedimentError::Database(format!("Failed to record access: {}", e)))?;
76        Ok(())
77    }
78
79    /// Record a validation (replace/confirm) for an item.
80    pub fn record_validation(&self, item_id: &str, created_at: i64) -> Result<()> {
81        let now = chrono::Utc::now().timestamp();
82        self.conn
83            .execute(
84                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at, validation_count)
85                 VALUES (?1, 0, ?2, ?3, 1)
86                 ON CONFLICT(item_id) DO UPDATE SET
87                     validation_count = validation_count + 1,
88                     last_accessed_at = ?2",
89                params![item_id, now, created_at],
90            )
91            .map_err(|e| {
92                SedimentError::Database(format!("Failed to record validation: {}", e))
93            })?;
94        Ok(())
95    }
96
97    /// Get validation counts for multiple items in a single query.
98    pub fn get_validation_counts(&self, item_ids: &[&str]) -> Result<HashMap<String, u32>> {
99        if item_ids.is_empty() {
100            return Ok(HashMap::new());
101        }
102
103        let placeholders: Vec<String> = item_ids
104            .iter()
105            .enumerate()
106            .map(|(i, _)| format!("?{}", i + 1))
107            .collect();
108        let sql = format!(
109            "SELECT item_id, COALESCE(validation_count, 0) FROM access_log WHERE item_id IN ({})",
110            placeholders.join(", ")
111        );
112
113        let mut stmt = self
114            .conn
115            .prepare(&sql)
116            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
117
118        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
119            .iter()
120            .map(|id| id as &dyn rusqlite::types::ToSql)
121            .collect();
122
123        let rows = stmt
124            .query_map(params.as_slice(), |row| {
125                Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?))
126            })
127            .map_err(|e| {
128                SedimentError::Database(format!("Failed to query validation counts: {}", e))
129            })?;
130
131        let mut map = HashMap::new();
132        for row in rows {
133            let (id, count) = row.map_err(|e| {
134                SedimentError::Database(format!("Failed to read validation count: {}", e))
135            })?;
136            map.insert(id, count);
137        }
138
139        Ok(map)
140    }
141
142    /// Get combined access and validation data for decay scoring in a single query.
143    pub fn get_decay_data(&self, item_ids: &[&str]) -> Result<HashMap<String, DecayData>> {
144        if item_ids.is_empty() {
145            return Ok(HashMap::new());
146        }
147
148        let placeholders: Vec<String> = item_ids
149            .iter()
150            .enumerate()
151            .map(|(i, _)| format!("?{}", i + 1))
152            .collect();
153        let sql = format!(
154            "SELECT item_id, access_count, last_accessed_at, created_at, COALESCE(validation_count, 0) FROM access_log WHERE item_id IN ({})",
155            placeholders.join(", ")
156        );
157
158        let mut stmt = self
159            .conn
160            .prepare(&sql)
161            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
162
163        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
164            .iter()
165            .map(|id| id as &dyn rusqlite::types::ToSql)
166            .collect();
167
168        let rows = stmt
169            .query_map(params.as_slice(), |row| {
170                Ok((
171                    row.get::<_, String>(0)?,
172                    DecayData {
173                        access_count: row.get::<_, u32>(1)?,
174                        last_accessed_at: row.get::<_, i64>(2)?,
175                        created_at: row.get::<_, i64>(3)?,
176                        validation_count: row.get::<_, u32>(4)?,
177                    },
178                ))
179            })
180            .map_err(|e| SedimentError::Database(format!("Failed to query decay data: {}", e)))?;
181
182        let mut map = HashMap::new();
183        for row in rows {
184            let (id, data) = row.map_err(|e| {
185                SedimentError::Database(format!("Failed to read decay data: {}", e))
186            })?;
187            map.insert(id, data);
188        }
189
190        Ok(map)
191    }
192}
193
#[cfg(test)]
/// Record of access history for a single item (test-only view of `access_log`
/// without the validation column).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessRecord {
    /// Number of recorded accesses.
    pub access_count: u32,
    /// Unix timestamp (seconds) of the most recent access or validation.
    pub last_accessed_at: i64,
    /// Creation time supplied by the caller when the row was first inserted.
    pub created_at: i64,
}
202
#[cfg(test)]
impl AccessTracker {
    /// Get the validation count for a single item.
    ///
    /// Returns `Ok(0)` when the item has no `access_log` row.
    ///
    /// # Errors
    ///
    /// Returns [`SedimentError::Database`] for any other SQLite failure
    /// (prepare/step error, type conversion), instead of masking it as 0.
    pub fn get_validation_count(&self, item_id: &str) -> Result<u32> {
        match self.conn.query_row(
            "SELECT COALESCE(validation_count, 0) FROM access_log WHERE item_id = ?1",
            params![item_id],
            |row| row.get::<_, u32>(0),
        ) {
            Ok(count) => Ok(count),
            // No row simply means the item was never accessed or validated.
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(0),
            Err(e) => Err(SedimentError::Database(format!(
                "Failed to read validation count: {}",
                e
            ))),
        }
    }

    /// Get access records for a batch of item IDs.
    ///
    /// Items with no `access_log` row are absent from the returned map; an empty
    /// `item_ids` returns an empty map without querying.
    pub fn get_accesses(&self, item_ids: &[&str]) -> Result<HashMap<String, AccessRecord>> {
        if item_ids.is_empty() {
            return Ok(HashMap::new());
        }

        let placeholders: Vec<String> = (1..=item_ids.len()).map(|i| format!("?{}", i)).collect();
        let sql = format!(
            "SELECT item_id, access_count, last_accessed_at, created_at FROM access_log WHERE item_id IN ({})",
            placeholders.join(", ")
        );

        let mut stmt = self
            .conn
            .prepare(&sql)
            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;

        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
            .iter()
            .map(|id| id as &dyn rusqlite::types::ToSql)
            .collect();

        let rows = stmt
            .query_map(params.as_slice(), |row| {
                Ok((
                    row.get::<_, String>(0)?,
                    AccessRecord {
                        access_count: row.get::<_, u32>(1)?,
                        last_accessed_at: row.get::<_, i64>(2)?,
                        created_at: row.get::<_, i64>(3)?,
                    },
                ))
            })
            .map_err(|e| SedimentError::Database(format!("Failed to query accesses: {}", e)))?;

        let mut map = HashMap::with_capacity(item_ids.len());
        for row in rows {
            let (id, record) = row.map_err(|e| {
                SedimentError::Database(format!("Failed to read access record: {}", e))
            })?;
            map.insert(id, record);
        }

        Ok(map)
    }
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_open_creates_table() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        // Should not error on second open (migration must be idempotent)
        drop(tracker);
        let _tracker2 = AccessTracker::open(tmp.path()).unwrap();
    }

    #[test]
    fn test_record_and_get_access() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item2", created).unwrap();

        let records = tracker.get_accesses(&["item1", "item2", "item3"]).unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records["item1"].access_count, 2);
        assert_eq!(records["item1"].created_at, created);
        assert_eq!(records["item2"].access_count, 1);
        assert!(!records.contains_key("item3"));
    }

    #[test]
    fn test_get_accesses_empty() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        let records = tracker.get_accesses(&[]).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_record_validation_on_new_item() {
        // Fix #5: validation should be recorded on the new (replacement) item
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        // Record validation on a new item (not pre-existing)
        tracker.record_validation("new-item", created).unwrap();
        tracker.record_validation("new-item", created).unwrap();

        let count = tracker.get_validation_count("new-item").unwrap();
        assert_eq!(
            count, 2,
            "Validation count should be 2 after two record_validation calls"
        );

        // Item that was never validated should have 0
        let count = tracker.get_validation_count("other-item").unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_get_validation_counts() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_validation("item-a", created).unwrap();
        tracker.record_validation("item-a", created).unwrap();
        tracker.record_validation("item-b", created).unwrap();

        let counts = tracker
            .get_validation_counts(&["item-a", "item-b", "item-c"])
            .unwrap();

        // Batch results should match individual lookups
        assert_eq!(counts.get("item-a").copied().unwrap_or(0), 2);
        assert_eq!(counts.get("item-b").copied().unwrap_or(0), 1);
        // item-c has no record, should not be in map
        assert!(!counts.contains_key("item-c"));

        // Verify consistency with single-item method
        assert_eq!(
            counts.get("item-a").copied().unwrap_or(0),
            tracker.get_validation_count("item-a").unwrap()
        );
    }

    #[test]
    fn test_get_decay_data() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item1", created).unwrap();
        tracker.record_validation("item1", created).unwrap();
        tracker.record_validation("item2", created).unwrap();

        let data = tracker
            .get_decay_data(&["item1", "item2", "item3"])
            .unwrap();

        assert_eq!(data.len(), 2);
        assert_eq!(data["item1"].access_count, 2);
        assert_eq!(data["item1"].validation_count, 1);
        assert_eq!(data["item1"].created_at, created);
        assert!(data["item1"].last_accessed_at >= created);
        // A validation-only item has no recorded accesses.
        assert_eq!(data["item2"].access_count, 0);
        assert_eq!(data["item2"].validation_count, 1);
        assert!(!data.contains_key("item3"));
    }

    #[test]
    fn test_batch_lookups_empty() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        assert!(tracker.get_decay_data(&[]).unwrap().is_empty());
        assert!(tracker.get_validation_counts(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_batch_lookups_with_many_ids() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_validation("first", created).unwrap();
        tracker.record_validation("last", created).unwrap();

        // More IDs than SQLite's historical 999 bound-parameter limit, with the
        // real items at both ends so they land in different query batches.
        let filler: Vec<String> = (0..1500).map(|i| format!("missing-{}", i)).collect();
        let mut ids: Vec<&str> = vec!["first"];
        ids.extend(filler.iter().map(String::as_str));
        ids.push("last");

        let counts = tracker.get_validation_counts(&ids).unwrap();
        assert_eq!(counts.len(), 2);
        assert_eq!(counts["first"], 1);
        assert_eq!(counts["last"], 1);

        let decay = tracker.get_decay_data(&ids).unwrap();
        assert_eq!(decay.len(), 2);
        assert!(decay.contains_key("first"));
        assert!(decay.contains_key("last"));
    }
}