// sediment/access.rs

//! Access tracking for memory decay scoring.
//!
//! Uses a SQLite sidecar database to track access counts and timestamps,
//! enabling freshness and frequency-based scoring without modifying LanceDB.

use std::collections::HashMap;
use std::path::Path;

use rusqlite::{Connection, OptionalExtension, params};

use crate::error::{Result, SedimentError};

/// Record of access history for a single item.
///
/// Holds the `access_count`, `last_accessed_at` and `created_at` values read
/// from one row of the `access_log` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessRecord {
    /// Number of accesses recorded for the item.
    pub access_count: u32,
    /// Unix timestamp (seconds, UTC) of the most recent recorded access or
    /// validation.
    pub last_accessed_at: i64,
    /// Creation timestamp supplied by the caller when the row was first
    /// inserted (presumably Unix seconds as well — confirm against callers).
    pub created_at: i64,
}

21/// Tracks item access history in SQLite for decay scoring.
22pub struct AccessTracker {
23    conn: Connection,
24}
25
26impl AccessTracker {
27    /// Open or create the access tracking database.
28    pub fn open(path: &Path) -> Result<Self> {
29        let conn = Connection::open(path).map_err(|e| {
30            SedimentError::Database(format!("Failed to open access database: {}", e))
31        })?;
32
33        if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;") {
34            tracing::warn!("Failed to set SQLite PRAGMAs (access): {}", e);
35        }
36
37        conn.execute_batch(
38            "CREATE TABLE IF NOT EXISTS access_log (
39                item_id TEXT PRIMARY KEY,
40                access_count INTEGER NOT NULL DEFAULT 0,
41                last_accessed_at INTEGER NOT NULL,
42                created_at INTEGER NOT NULL
43            );",
44        )
45        .map_err(|e| {
46            SedimentError::Database(format!("Failed to create access_log table: {}", e))
47        })?;
48
49        // Idempotent schema migration: add validation_count column
50        if let Err(e) = conn.execute_batch(
51            "ALTER TABLE access_log ADD COLUMN validation_count INTEGER NOT NULL DEFAULT 0;",
52        ) {
53            let msg = e.to_string();
54            if !msg.contains("duplicate column") {
55                tracing::warn!("access_log migration unexpected error: {}", msg);
56            }
57        }
58
59        Ok(Self { conn })
60    }
61
62    /// Record an access for an item. If no record exists, creates one with the given created_at.
63    pub fn record_access(&self, item_id: &str, created_at: i64) -> Result<()> {
64        let now = chrono::Utc::now().timestamp();
65        self.conn
66            .execute(
67                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at)
68                 VALUES (?1, 1, ?2, ?3)
69                 ON CONFLICT(item_id) DO UPDATE SET
70                     access_count = access_count + 1,
71                     last_accessed_at = ?2",
72                params![item_id, now, created_at],
73            )
74            .map_err(|e| SedimentError::Database(format!("Failed to record access: {}", e)))?;
75        Ok(())
76    }
77
78    /// Record a validation (replace/confirm) for an item.
79    pub fn record_validation(&self, item_id: &str, created_at: i64) -> Result<()> {
80        let now = chrono::Utc::now().timestamp();
81        self.conn
82            .execute(
83                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at, validation_count)
84                 VALUES (?1, 0, ?2, ?3, 1)
85                 ON CONFLICT(item_id) DO UPDATE SET
86                     validation_count = validation_count + 1,
87                     last_accessed_at = ?2",
88                params![item_id, now, created_at],
89            )
90            .map_err(|e| {
91                SedimentError::Database(format!("Failed to record validation: {}", e))
92            })?;
93        Ok(())
94    }
95
96    /// Get validation count for an item.
97    pub fn get_validation_count(&self, item_id: &str) -> Result<u32> {
98        let count: u32 = self
99            .conn
100            .query_row(
101                "SELECT COALESCE(validation_count, 0) FROM access_log WHERE item_id = ?1",
102                params![item_id],
103                |row| row.get(0),
104            )
105            .unwrap_or(0);
106        Ok(count)
107    }
108
109    /// Get validation counts for multiple items in a single query.
110    pub fn get_validation_counts(&self, item_ids: &[&str]) -> Result<HashMap<String, u32>> {
111        if item_ids.is_empty() {
112            return Ok(HashMap::new());
113        }
114
115        let placeholders: Vec<String> = item_ids
116            .iter()
117            .enumerate()
118            .map(|(i, _)| format!("?{}", i + 1))
119            .collect();
120        let sql = format!(
121            "SELECT item_id, COALESCE(validation_count, 0) FROM access_log WHERE item_id IN ({})",
122            placeholders.join(", ")
123        );
124
125        let mut stmt = self
126            .conn
127            .prepare(&sql)
128            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
129
130        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
131            .iter()
132            .map(|id| id as &dyn rusqlite::types::ToSql)
133            .collect();
134
135        let rows = stmt
136            .query_map(params.as_slice(), |row| {
137                Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?))
138            })
139            .map_err(|e| {
140                SedimentError::Database(format!("Failed to query validation counts: {}", e))
141            })?;
142
143        let mut map = HashMap::new();
144        for row in rows {
145            let (id, count) = row.map_err(|e| {
146                SedimentError::Database(format!("Failed to read validation count: {}", e))
147            })?;
148            map.insert(id, count);
149        }
150
151        Ok(map)
152    }
153
154    /// Get access records for a batch of item IDs.
155    pub fn get_accesses(&self, item_ids: &[&str]) -> Result<HashMap<String, AccessRecord>> {
156        if item_ids.is_empty() {
157            return Ok(HashMap::new());
158        }
159
160        let placeholders: Vec<String> = item_ids
161            .iter()
162            .enumerate()
163            .map(|(i, _)| format!("?{}", i + 1))
164            .collect();
165        let sql = format!(
166            "SELECT item_id, access_count, last_accessed_at, created_at FROM access_log WHERE item_id IN ({})",
167            placeholders.join(", ")
168        );
169
170        let mut stmt = self
171            .conn
172            .prepare(&sql)
173            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
174
175        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
176            .iter()
177            .map(|id| id as &dyn rusqlite::types::ToSql)
178            .collect();
179
180        let rows = stmt
181            .query_map(params.as_slice(), |row| {
182                Ok((
183                    row.get::<_, String>(0)?,
184                    AccessRecord {
185                        access_count: row.get::<_, u32>(1)?,
186                        last_accessed_at: row.get::<_, i64>(2)?,
187                        created_at: row.get::<_, i64>(3)?,
188                    },
189                ))
190            })
191            .map_err(|e| SedimentError::Database(format!("Failed to query accesses: {}", e)))?;
192
193        let mut map = HashMap::new();
194        for row in rows {
195            let (id, record) = row.map_err(|e| {
196                SedimentError::Database(format!("Failed to read access record: {}", e))
197            })?;
198            map.insert(id, record);
199        }
200
201        Ok(map)
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Creation timestamp shared by every fixture row.
    const CREATED: i64 = 1700000000;

    /// Open a tracker backed by a fresh temporary file.
    ///
    /// The file handle is returned alongside the tracker so the caller keeps
    /// the backing file alive for as long as the tracker is in use.
    fn fresh_tracker() -> (NamedTempFile, AccessTracker) {
        let db_file = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(db_file.path()).unwrap();
        (db_file, tracker)
    }

    #[test]
    fn test_open_creates_table() {
        let db_file = NamedTempFile::new().unwrap();
        {
            // First open creates the schema; the tracker is closed at scope end.
            let _first = AccessTracker::open(db_file.path()).unwrap();
        }
        // Should not error on second open
        let _second = AccessTracker::open(db_file.path()).unwrap();
    }

    #[test]
    fn test_record_and_get_access() {
        let (_db_file, tracker) = fresh_tracker();

        // item1 is accessed twice, item2 once, item3 never.
        for id in ["item1", "item1", "item2"] {
            tracker.record_access(id, CREATED).unwrap();
        }

        let records = tracker.get_accesses(&["item1", "item2", "item3"]).unwrap();

        assert_eq!(records.len(), 2);
        let first = &records["item1"];
        assert_eq!(first.access_count, 2);
        assert_eq!(first.created_at, CREATED);
        assert_eq!(records["item2"].access_count, 1);
        assert!(!records.contains_key("item3"));
    }

    #[test]
    fn test_get_accesses_empty() {
        let (_db_file, tracker) = fresh_tracker();
        let no_ids: [&str; 0] = [];
        assert!(tracker.get_accesses(&no_ids).unwrap().is_empty());
    }

    #[test]
    fn test_record_validation_on_new_item() {
        // Fix #5: validation should be recorded on the new (replacement) item
        let (_db_file, tracker) = fresh_tracker();

        // Record validation on a new item (not pre-existing)
        for _ in 0..2 {
            tracker.record_validation("new-item", CREATED).unwrap();
        }

        let validated = tracker.get_validation_count("new-item").unwrap();
        assert_eq!(
            validated, 2,
            "Validation count should be 2 after two record_validation calls"
        );

        // Item that was never validated should have 0
        let never_validated = tracker.get_validation_count("other-item").unwrap();
        assert_eq!(never_validated, 0);
    }

    #[test]
    fn test_get_validation_counts() {
        let (_db_file, tracker) = fresh_tracker();

        // item-a is validated twice, item-b once, item-c never.
        for id in ["item-a", "item-a", "item-b"] {
            tracker.record_validation(id, CREATED).unwrap();
        }

        let counts = tracker
            .get_validation_counts(&["item-a", "item-b", "item-c"])
            .unwrap();
        let count_of = |id: &str| counts.get(id).copied().unwrap_or(0);

        // Batch results should match individual lookups
        assert_eq!(count_of("item-a"), 2);
        assert_eq!(count_of("item-b"), 1);
        // item-c has no record, should not be in map
        assert!(!counts.contains_key("item-c"));

        // Verify consistency with single-item method
        assert_eq!(
            count_of("item-a"),
            tracker.get_validation_count("item-a").unwrap()
        );
    }
}