zeph_memory/five_signal/
access_frequency.rs1use std::collections::HashMap;
5
6use sqlx::Row as _;
7use zeph_db::DbPool;
8
9use crate::types::MessageId;
10
11const MAX_ACCESS_COUNT: f64 = 10_000.0;
14
15pub struct AccessFrequencyCache {
20 pool: DbPool,
21}
22
23impl AccessFrequencyCache {
24 #[must_use]
26 pub fn new(pool: DbPool) -> Self {
27 Self { pool }
28 }
29
30 #[tracing::instrument(
39 name = "memory.five_signal.access_frequency.load",
40 skip(self, fact_ids),
41 fields(fact_count = fact_ids.len())
42 )]
43 pub async fn load_for_candidates(
44 &self,
45 session_id: &str,
46 fact_ids: &[MessageId],
47 ) -> Result<HashMap<MessageId, f64>, crate::error::MemoryError> {
48 tracing::debug!("five_signal: loading access frequencies");
49
50 if fact_ids.is_empty() {
51 return Ok(HashMap::new());
52 }
53
54 let ids: Vec<i64> = fact_ids.iter().map(|id| id.0).collect();
55
56 let placeholders: String = ids
59 .iter()
60 .enumerate()
61 .map(|(i, _)| format!("?{}", i + 2))
62 .collect::<Vec<_>>()
63 .join(", ");
64
65 let sql = format!(
66 "SELECT fact_id, COUNT(*) as cnt FROM fact_access_log \
67 WHERE session_id = ?1 AND fact_id IN ({placeholders}) \
68 GROUP BY fact_id"
69 );
70
71 let mut q = sqlx::query(&sql).bind(session_id);
72 for id in &ids {
73 q = q.bind(id);
74 }
75
76 let rows = q
77 .fetch_all(&self.pool)
78 .await
79 .map_err(|e| crate::error::MemoryError::Db(e.into()))?;
80
81 let counts: HashMap<i64, i64> = rows
82 .iter()
83 .map(|row| (row.get::<i64, _>("fact_id"), row.get::<i64, _>("cnt")))
84 .collect();
85
86 let normalized = fact_ids
87 .iter()
88 .map(|id| {
89 #[expect(clippy::cast_precision_loss)]
90 let raw = *counts.get(&id.0).unwrap_or(&0) as f64;
91 let score =
92 (1.0_f64 + raw.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln();
93 (*id, score)
94 })
95 .collect();
96
97 Ok(normalized)
98 }
99
100 #[tracing::instrument(
104 name = "memory.five_signal.access_frequency.log",
105 skip(self, fact_type, session_id),
106 fields(fact_id = fact_id.0)
107 )]
108 pub async fn log_access(&self, fact_id: MessageId, fact_type: &str, session_id: &str) {
109 tracing::debug!("five_signal: logging access");
110
111 let accessed_at = std::time::SystemTime::now()
112 .duration_since(std::time::UNIX_EPOCH)
113 .map_or(0, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
114
115 let res = sqlx::query(
116 "INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
117 VALUES (?1, ?2, ?3, ?4)",
118 )
119 .bind(fact_id.0)
120 .bind(fact_type)
121 .bind(session_id)
122 .bind(accessed_at)
123 .execute(&self.pool)
124 .await;
125
126 if let Err(e) = res {
127 tracing::warn!(
128 fact_id = fact_id.0,
129 error = %e,
130 "five_signal: failed to log fact access (non-fatal)"
131 );
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    async fn test_pool() -> DbPool {
        crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SQLite failed")
            .pool()
            .clone()
    }

    /// Reference normalization curve that production scores are checked against.
    fn expected_score(count: f64) -> f64 {
        (1.0 + count.min(MAX_ACCESS_COUNT)).ln() / (1.0 + MAX_ACCESS_COUNT).ln()
    }

    /// Insert one raw `episodic` access row directly, bypassing `log_access`.
    async fn insert_access(pool: &DbPool, fact_id: i64, session_id: &str) {
        sqlx::query(
            "INSERT INTO fact_access_log (fact_id, fact_type, session_id, accessed_at) \
             VALUES (?1, 'episodic', ?2, 0)",
        )
        .bind(fact_id)
        .bind(session_id)
        .execute(pool)
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn load_for_candidates_empty_returns_empty() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool);
        let result = cache.load_for_candidates("s1", &[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn load_for_candidates_no_rows_gives_zero_score() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool);
        let ids = [MessageId(1), MessageId(2)];
        let scores = cache.load_for_candidates("s1", &ids).await.unwrap();
        assert_eq!(scores.len(), 2);
        assert!(scores[&MessageId(1)].abs() < f64::EPSILON);
        assert!(scores[&MessageId(2)].abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn load_for_candidates_higher_count_gives_higher_score() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());
        let session = "test-session";

        insert_access(&pool, 10, session).await;
        for _ in 0..5_u8 {
            insert_access(&pool, 20, session).await;
        }

        let ids = [MessageId(10), MessageId(20)];
        let scores = cache.load_for_candidates(session, &ids).await.unwrap();

        let s10 = scores[&MessageId(10)];
        let s20 = scores[&MessageId(20)];
        assert!(
            s20 > s10,
            "higher access count must yield higher score: {s20} vs {s10}"
        );
        assert!(s10 > 0.0, "score for fact with 1 access must be > 0");
        assert!(s20 <= 1.0, "score must be capped at 1.0");
        assert!(
            (s10 - expected_score(1.0)).abs() < 1e-9,
            "1 access must follow the log curve: {s10}"
        );
        assert!(
            (s20 - expected_score(5.0)).abs() < 1e-9,
            "5 accesses must follow the log curve: {s20}"
        );
    }

    #[tokio::test]
    async fn load_for_candidates_ignores_other_sessions() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool.clone());

        insert_access(&pool, 99, "other-session").await;

        let ids = [MessageId(99)];
        let scores = cache.load_for_candidates("my-session", &ids).await.unwrap();
        assert!(
            scores[&MessageId(99)].abs() < f64::EPSILON,
            "score must be 0 for different session"
        );
    }

    #[tokio::test]
    async fn log_access_is_counted_by_load_for_candidates() {
        let pool = test_pool().await;
        let cache = AccessFrequencyCache::new(pool);
        let id = MessageId(7);

        cache.log_access(id, "episodic", "s1").await;
        cache.log_access(id, "episodic", "s1").await;

        let scores = cache.load_for_candidates("s1", &[id]).await.unwrap();
        assert!(
            (scores[&id] - expected_score(2.0)).abs() < 1e-9,
            "two logged accesses must be counted: {}",
            scores[&id]
        );
    }

    #[test]
    fn normalization_zero_count() {
        assert!(expected_score(0.0).abs() < 1e-9, "zero access → score 0.0");
    }

    #[test]
    fn normalization_max_count() {
        assert!(
            (expected_score(MAX_ACCESS_COUNT) - 1.0).abs() < 1e-9,
            "max access → score 1.0"
        );
    }

    #[test]
    fn normalization_overflow_clamped() {
        assert!(
            (expected_score(MAX_ACCESS_COUNT * 2.0) - 1.0).abs() < 1e-9,
            "overflow is clamped to 1.0"
        );
    }

    #[test]
    fn normalization_monotone() {
        assert!(expected_score(100.0) > expected_score(10.0));
    }
}