Skip to main content

zeph_memory/five_signal/
access_frequency.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5
6use sqlx::Row as _;
7use zeph_db::DbPool;
8
9use crate::types::MessageId;
10
/// Access count at which the frequency signal saturates.
///
/// Raw counts are clamped to this value before log-normalization. A fact accessed this
/// many times (or more) therefore scores exactly `1.0`, and no score can exceed `1.0`.
/// This keeps one extremely hot fact from stretching the scale.
const MAX_ACCESS_COUNT: f64 = 1e4;
14
15/// Per-turn access frequency aggregator backed by `fact_access_log`.
16///
17/// Loads raw access counts for a candidate set in a single `GROUP BY` query per turn.
18/// Normalized values are `log(1 + count) / log(1 + MAX_ACCESS_COUNT)` ∈ `[0.0, 1.0]`.
19pub struct AccessFrequencyCache {
20    pool: DbPool,
21}
22
23impl AccessFrequencyCache {
24    /// Create a new cache backed by the given pool.
25    #[must_use]
26    pub fn new(pool: DbPool) -> Self {
27        Self { pool }
28    }
29
30    /// Load and normalize access counts for `fact_ids` within `session_id`.
31    ///
32    /// Issues a single SQL `GROUP BY` query indexed by `(session_id, accessed_at DESC)`.
33    /// Returns a map of `fact_id → normalized_score ∈ [0.0, 1.0]`.
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if the database query fails.
38    #[tracing::instrument(
39        name = "memory.five_signal.access_frequency.load",
40        skip(self, fact_ids),
41        fields(fact_count = fact_ids.len())
42    )]
43    pub async fn load_for_candidates(
44        &self,
45        session_id: &str,
46        fact_ids: &[MessageId],
47    ) -> Result<HashMap<MessageId, f64>, crate::error::MemoryError> {
48        tracing::debug!("five_signal: loading access frequencies");
49
50        if fact_ids.is_empty() {
51            return Ok(HashMap::new());
52        }
53
54        let ids: Vec<i64> = fact_ids.iter().map(|id| id.0).collect();
55
56        // sqlx does not support binding Vec<i64> with IN directly for all backends;
57        // build the query with placeholders manually.
58        let placeholders: String = ids
59            .iter()
60            .enumerate()
61            .map(|(i, _)| format!("?{}", i + 2))
62            .collect::<Vec<_>>()
63            .join(", ");
64
65        let sql = format!(
66            "SELECT fact_id, COUNT(*) as cnt FROM fact_access_log \
67             WHERE session_id = ?1 AND fact_id IN ({placeholders}) \
68             GROUP BY fact_id"
69        );
70
71        let mut q = sqlx::query(&sql).bind(session_id);
72        for id in &ids {
73            q = q.bind(id);
74        }
75
76        let rows = q
77            .fetch_all(&self.pool)
78            .await
79            .map_err(|e| crate::error::MemoryError::Db(e.into()))?;
80
81        let counts: HashMap<i64, i64> = rows
82            .iter()
83            .map(|row| (row.get::<i64, _>("fact_id"), row.get::<i64, _>("cnt")))
84            .collect();
85
86        let normalized = fact_ids
87            .iter()
88            .map(|id| {
89                #[expect(clippy::cast_precision_loss)]
90                let raw = *counts.get(&id.0).unwrap_or(&0) as f64;
91                let score =
92                    (1.0_f64 + raw.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln();
93                (*id, score)
94            })
95            .collect();
96
97        Ok(normalized)
98    }
99
100    /// Record a fact access event in `fact_access_log`.
101    ///
102    /// Failures are logged as `WARN` and do not propagate — access logging is non-critical.
103    #[tracing::instrument(
104        name = "memory.five_signal.access_frequency.log",
105        skip(self, fact_type, session_id),
106        fields(fact_id = fact_id.0)
107    )]
108    pub async fn log_access(&self, fact_id: MessageId, fact_type: &str, session_id: &str) {
109        tracing::debug!("five_signal: logging access");
110
111        let accessed_at = std::time::SystemTime::now()
112            .duration_since(std::time::UNIX_EPOCH)
113            .map_or(0, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
114
115        let res = sqlx::query(
116            "INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
117             VALUES (?1, ?2, ?3, ?4)",
118        )
119        .bind(fact_id.0)
120        .bind(fact_type)
121        .bind(session_id)
122        .bind(accessed_at)
123        .execute(&self.pool)
124        .await;
125
126        if let Err(e) = res {
127            tracing::warn!(
128                fact_id = fact_id.0,
129                error = %e,
130                "five_signal: failed to log fact access (non-fatal)"
131            );
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    async fn test_pool() -> DbPool {
        crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SQLite failed")
            .pool()
            .clone()
    }

    /// Insert `times` access rows for `fact_id` in `session`, all stamped at time 0.
    async fn insert_accesses(pool: &DbPool, fact_id: i64, session: &str, times: usize) {
        for _ in 0..times {
            sqlx::query(
                "INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
                 VALUES (?1, 'episodic', ?2, 0)",
            )
            .bind(fact_id)
            .bind(session)
            .execute(pool)
            .await
            .unwrap();
        }
    }

    /// Reference copy of the normalization formula, used by the pure-math tests below.
    fn expected_score(raw: f64) -> f64 {
        (1.0 + raw.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln()
    }

    #[tokio::test]
    async fn load_for_candidates_empty_returns_empty() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool);
        let result = cache.load_for_candidates("s1", &[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn load_for_candidates_no_rows_gives_zero_score() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool);
        let ids = vec![MessageId(1), MessageId(2)];
        let scores = cache.load_for_candidates("s1", &ids).await.unwrap();
        assert_eq!(scores.len(), 2);
        assert!(scores[&MessageId(1)] < f64::EPSILON);
        assert!(scores[&MessageId(2)] < f64::EPSILON);
    }

    #[tokio::test]
    async fn load_for_candidates_higher_count_gives_higher_score() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        let session = "test-session";

        // 1 access for fact 10, 5 accesses for fact 20.
        insert_accesses(&pool, 10, session, 1).await;
        insert_accesses(&pool, 20, session, 5).await;

        let ids = vec![MessageId(10), MessageId(20)];
        let scores = cache.load_for_candidates(session, &ids).await.unwrap();

        let s10 = scores[&MessageId(10)];
        let s20 = scores[&MessageId(20)];
        assert!(
            s20 > s10,
            "higher access count must yield higher score: {s20} vs {s10}"
        );
        assert!(s10 > 0.0, "score for fact with 1 access must be > 0");
        assert!(s20 <= 1.0, "score must be capped at 1.0");
    }

    #[tokio::test]
    async fn load_for_candidates_ignores_other_sessions() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());

        insert_accesses(&pool, 99, "other-session", 1).await;

        let ids = vec![MessageId(99)];
        let scores = cache.load_for_candidates("my-session", &ids).await.unwrap();
        assert!(
            scores[&MessageId(99)] < f64::EPSILON,
            "score must be 0 for different session"
        );
    }

    #[tokio::test]
    async fn load_for_candidates_duplicate_ids_collapse() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        insert_accesses(&pool, 7, "dup-session", 3).await;

        let ids = vec![MessageId(7), MessageId(7), MessageId(8)];
        let scores = cache.load_for_candidates("dup-session", &ids).await.unwrap();
        assert_eq!(
            scores.len(),
            2,
            "duplicate candidate ids must map to a single entry"
        );
        assert!(scores[&MessageId(7)] > 0.0);
        assert!(scores[&MessageId(8)] < f64::EPSILON);
    }

    #[tokio::test]
    async fn load_for_candidates_large_candidate_set() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        insert_accesses(&pool, 1_500, "big-session", 2).await;

        // More ids than the legacy SQLite host-parameter limit (999).
        let ids: Vec<MessageId> = (0..2_000_i64).map(MessageId).collect();
        let scores = cache.load_for_candidates("big-session", &ids).await.unwrap();
        assert_eq!(scores.len(), ids.len());
        assert!(scores[&MessageId(1_500)] > 0.0);
        assert!(scores[&MessageId(0)] < f64::EPSILON);
    }

    #[test]
    fn normalization_zero_count() {
        assert!(expected_score(0.0).abs() < 1e-9, "zero access → score 0.0");
    }

    #[test]
    fn normalization_max_count() {
        assert!(
            (expected_score(MAX_ACCESS_COUNT) - 1.0).abs() < 1e-9,
            "max access → score 1.0"
        );
    }

    #[test]
    fn normalization_overflow_clamped() {
        assert!(
            (expected_score(MAX_ACCESS_COUNT * 2.0) - 1.0).abs() < 1e-9,
            "overflow is clamped to 1.0"
        );
    }

    #[test]
    fn normalization_monotone() {
        assert!(expected_score(100.0) > expected_score(10.0));
    }
}