Skip to main content

sediment/
access.rs

1//! Access tracking for memory decay scoring.
2//!
3//! Uses a SQLite sidecar database to track access counts and timestamps,
4//! enabling freshness and frequency-based scoring without modifying LanceDB.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use rusqlite::{Connection, params};
10
11use crate::error::{Result, SedimentError};
12
/// Record of access history for a single item, as stored in the `access_log` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessRecord {
    /// Number of recorded accesses. Validations do not increment this counter.
    pub access_count: u32,
    /// Unix timestamp (seconds) of the most recent access or validation.
    pub last_accessed_at: i64,
    /// Creation timestamp supplied by the caller when the item was first recorded;
    /// later accesses never overwrite it.
    pub created_at: i64,
}
20
21/// Tracks item access history in SQLite for decay scoring.
22pub struct AccessTracker {
23    conn: Connection,
24}
25
26impl AccessTracker {
27    /// Open or create the access tracking database.
28    pub fn open(path: &Path) -> Result<Self> {
29        let conn = Connection::open(path).map_err(|e| {
30            SedimentError::Database(format!("Failed to open access database: {}", e))
31        })?;
32
33        conn.execute_batch(
34            "CREATE TABLE IF NOT EXISTS access_log (
35                item_id TEXT PRIMARY KEY,
36                access_count INTEGER NOT NULL DEFAULT 0,
37                last_accessed_at INTEGER NOT NULL,
38                created_at INTEGER NOT NULL
39            );",
40        )
41        .map_err(|e| {
42            SedimentError::Database(format!("Failed to create access_log table: {}", e))
43        })?;
44
45        // Idempotent schema migration: add validation_count column
46        let _ = conn.execute_batch(
47            "ALTER TABLE access_log ADD COLUMN validation_count INTEGER NOT NULL DEFAULT 0;",
48        );
49
50        Ok(Self { conn })
51    }
52
53    /// Record an access for an item. If no record exists, creates one with the given created_at.
54    pub fn record_access(&self, item_id: &str, created_at: i64) -> Result<()> {
55        let now = chrono::Utc::now().timestamp();
56        self.conn
57            .execute(
58                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at)
59                 VALUES (?1, 1, ?2, ?3)
60                 ON CONFLICT(item_id) DO UPDATE SET
61                     access_count = access_count + 1,
62                     last_accessed_at = ?2",
63                params![item_id, now, created_at],
64            )
65            .map_err(|e| SedimentError::Database(format!("Failed to record access: {}", e)))?;
66        Ok(())
67    }
68
69    /// Record a validation (replace/confirm) for an item.
70    pub fn record_validation(&self, item_id: &str, created_at: i64) -> Result<()> {
71        let now = chrono::Utc::now().timestamp();
72        self.conn
73            .execute(
74                "INSERT INTO access_log (item_id, access_count, last_accessed_at, created_at, validation_count)
75                 VALUES (?1, 0, ?2, ?3, 1)
76                 ON CONFLICT(item_id) DO UPDATE SET
77                     validation_count = validation_count + 1,
78                     last_accessed_at = ?2",
79                params![item_id, now, created_at],
80            )
81            .map_err(|e| {
82                SedimentError::Database(format!("Failed to record validation: {}", e))
83            })?;
84        Ok(())
85    }
86
87    /// Get validation count for an item.
88    pub fn get_validation_count(&self, item_id: &str) -> Result<u32> {
89        let count: u32 = self
90            .conn
91            .query_row(
92                "SELECT COALESCE(validation_count, 0) FROM access_log WHERE item_id = ?1",
93                params![item_id],
94                |row| row.get(0),
95            )
96            .unwrap_or(0);
97        Ok(count)
98    }
99
100    /// Get access records for a batch of item IDs.
101    pub fn get_accesses(&self, item_ids: &[&str]) -> Result<HashMap<String, AccessRecord>> {
102        if item_ids.is_empty() {
103            return Ok(HashMap::new());
104        }
105
106        let placeholders: Vec<String> = item_ids
107            .iter()
108            .enumerate()
109            .map(|(i, _)| format!("?{}", i + 1))
110            .collect();
111        let sql = format!(
112            "SELECT item_id, access_count, last_accessed_at, created_at FROM access_log WHERE item_id IN ({})",
113            placeholders.join(", ")
114        );
115
116        let mut stmt = self
117            .conn
118            .prepare(&sql)
119            .map_err(|e| SedimentError::Database(format!("Failed to prepare query: {}", e)))?;
120
121        let params: Vec<&dyn rusqlite::types::ToSql> = item_ids
122            .iter()
123            .map(|id| id as &dyn rusqlite::types::ToSql)
124            .collect();
125
126        let rows = stmt
127            .query_map(params.as_slice(), |row| {
128                Ok((
129                    row.get::<_, String>(0)?,
130                    AccessRecord {
131                        access_count: row.get::<_, u32>(1)?,
132                        last_accessed_at: row.get::<_, i64>(2)?,
133                        created_at: row.get::<_, i64>(3)?,
134                    },
135                ))
136            })
137            .map_err(|e| SedimentError::Database(format!("Failed to query accesses: {}", e)))?;
138
139        let mut map = HashMap::new();
140        for row in rows {
141            let (id, record) = row.map_err(|e| {
142                SedimentError::Database(format!("Failed to read access record: {}", e))
143            })?;
144            map.insert(id, record);
145        }
146
147        Ok(map)
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_open_creates_table() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        // Should not error on second open (migration must be idempotent).
        drop(tracker);
        let _tracker2 = AccessTracker::open(tmp.path()).unwrap();
    }

    #[test]
    fn test_open_migrates_legacy_schema() {
        let tmp = NamedTempFile::new().unwrap();
        {
            // Simulate a database created before validation_count existed.
            let conn = rusqlite::Connection::open(tmp.path()).unwrap();
            conn.execute_batch(
                "CREATE TABLE access_log (
                    item_id TEXT PRIMARY KEY,
                    access_count INTEGER NOT NULL DEFAULT 0,
                    last_accessed_at INTEGER NOT NULL,
                    created_at INTEGER NOT NULL
                );
                INSERT INTO access_log VALUES ('legacy', 3, 1700000100, 1700000000);",
            )
            .unwrap();
        }

        let tracker = AccessTracker::open(tmp.path()).unwrap();
        assert_eq!(tracker.get_validation_count("legacy").unwrap(), 0);
        tracker.record_validation("legacy", 1700000000).unwrap();
        assert_eq!(tracker.get_validation_count("legacy").unwrap(), 1);

        let records = tracker.get_accesses(&["legacy"]).unwrap();
        assert_eq!(records["legacy"].access_count, 3);
        assert_eq!(records["legacy"].created_at, 1700000000);
    }

    #[test]
    fn test_record_and_get_access() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();

        let created = 1700000000i64;
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item1", created).unwrap();
        tracker.record_access("item2", created).unwrap();

        let records = tracker.get_accesses(&["item1", "item2", "item3"]).unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records["item1"].access_count, 2);
        assert_eq!(records["item1"].created_at, created);
        assert_eq!(records["item2"].access_count, 1);
        assert!(!records.contains_key("item3"));
    }

    #[test]
    fn test_record_validation() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        let created = 1700000000i64;

        // Unknown items have never been validated.
        assert_eq!(tracker.get_validation_count("item1").unwrap(), 0);

        tracker.record_validation("item1", created).unwrap();
        tracker.record_validation("item1", created).unwrap();
        assert_eq!(tracker.get_validation_count("item1").unwrap(), 2);

        // Validations create the row but do not count as accesses.
        let records = tracker.get_accesses(&["item1"]).unwrap();
        assert_eq!(records["item1"].access_count, 0);
        assert_eq!(records["item1"].created_at, created);

        // Accesses do not change the validation count.
        tracker.record_access("item1", created).unwrap();
        assert_eq!(tracker.get_validation_count("item1").unwrap(), 2);
    }

    #[test]
    fn test_get_accesses_empty() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        let records = tracker.get_accesses(&[]).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_get_accesses_large_batch() {
        let tmp = NamedTempFile::new().unwrap();
        let tracker = AccessTracker::open(tmp.path()).unwrap();
        let created = 1700000000i64;
        tracker.record_access("item0", created).unwrap();
        tracker.record_access("item1499", created).unwrap();

        // More ids than SQLite's historical 999-parameter limit.
        let ids: Vec<String> = (0..1500).map(|i| format!("item{}", i)).collect();
        let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();

        let records = tracker.get_accesses(&id_refs).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records["item0"].access_count, 1);
        assert_eq!(records["item1499"].access_count, 1);
    }
}