1use std::sync::Arc;
35use std::time::Duration;
36
37use tokio::task::JoinHandle;
38use tokio_util::sync::CancellationToken;
39
40use crate::error::MemoryError;
41use crate::store::SqliteStore;
42
43pub use zeph_common::config::memory::ForgettingConfig;
44
/// Per-sweep counters reported by [`run_forgetting_sweep`].
///
/// All counters are zero when the sweep is skipped because the configuration failed
/// validation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ForgettingResult {
    /// Number of messages whose importance score was decayed during the sweep.
    pub downscaled: u32,
    /// Number of messages whose importance score was restored ("replayed") rather than left
    /// decayed.
    pub replayed: u32,
    /// Number of messages pruned by the sweep (presumably those whose score ended up below
    /// `forgetting_floor` — confirm against `SqliteStore::run_forgetting_sweep_tx`).
    pub pruned: u32,
}
57
58#[must_use]
70pub fn start_forgetting_loop(
71 store: Arc<SqliteStore>,
72 config: ForgettingConfig,
73 cancel: CancellationToken,
74) -> JoinHandle<()> {
75 tokio::spawn(async move {
76 if !config.enabled {
77 tracing::debug!("forgetting sweep disabled (forgetting.enabled = false)");
78 return;
79 }
80
81 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
82 ticker.tick().await;
84
85 loop {
86 tokio::select! {
87 () = cancel.cancelled() => {
88 tracing::debug!("forgetting loop shutting down");
89 return;
90 }
91 _ = ticker.tick() => {}
92 }
93
94 tracing::debug!("forgetting: starting sweep");
95 let start = std::time::Instant::now();
96
97 match run_forgetting_sweep(&store, &config).await {
98 Ok(r) => {
99 tracing::info!(
100 downscaled = r.downscaled,
101 replayed = r.replayed,
102 pruned = r.pruned,
103 elapsed_ms = start.elapsed().as_millis(),
104 "forgetting: sweep complete"
105 );
106 }
107 Err(e) => {
108 tracing::warn!(
109 error = %e,
110 elapsed_ms = start.elapsed().as_millis(),
111 "forgetting: sweep failed, will retry"
112 );
113 }
114 }
115 }
116 })
117}
118
119#[cfg_attr(
136 feature = "profiling",
137 tracing::instrument(name = "memory.forgetting", skip_all)
138)]
139pub async fn run_forgetting_sweep(
140 store: &SqliteStore,
141 config: &ForgettingConfig,
142) -> Result<ForgettingResult, MemoryError> {
143 if config.decay_rate <= 0.0 || config.decay_rate >= 1.0 {
144 tracing::warn!(
145 decay_rate = config.decay_rate,
146 "forgetting: decay_rate must be in (0.0, 1.0); skipping sweep"
147 );
148 return Ok(ForgettingResult::default());
149 }
150 if config.forgetting_floor < 0.0 || config.forgetting_floor >= 1.0 {
151 tracing::warn!(
152 forgetting_floor = config.forgetting_floor,
153 "forgetting: forgetting_floor must be in [0.0, 1.0); skipping sweep"
154 );
155 return Ok(ForgettingResult::default());
156 }
157 if config.sweep_interval_secs < 60 {
158 tracing::warn!(
159 sweep_interval_secs = config.sweep_interval_secs,
160 "forgetting: sweep_interval_secs must be >= 60; skipping sweep"
161 );
162 return Ok(ForgettingResult::default());
163 }
164 store.run_forgetting_sweep_tx(config).await
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::SqliteStore;
    use zeph_common::config::memory::ForgettingConfig;

    /// Opens a fresh in-memory store so each test starts from an empty database.
    async fn make_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.expect("SqliteStore::new")
    }

    /// Baseline configuration shared by every test.
    fn default_config() -> ForgettingConfig {
        ForgettingConfig {
            enabled: true,
            decay_rate: 0.1,
            forgetting_floor: 0.05,
            sweep_interval_secs: 7200,
            sweep_batch_size: 500,
            replay_window_hours: 24,
            replay_min_access_count: 3,
            protect_recent_hours: 24,
            protect_min_access_count: 3,
        }
    }

    /// Config with recency protection, access-count protection and replay all effectively
    /// disabled. Tests can then observe one mechanism at a time by overriding a single field.
    fn isolated_config() -> ForgettingConfig {
        ForgettingConfig {
            decay_rate: 0.1,
            forgetting_floor: 0.01,
            protect_recent_hours: 0,
            protect_min_access_count: 999,
            replay_min_access_count: 999,
            replay_window_hours: 0,
            ..default_config()
        }
    }

    #[tokio::test]
    async fn sweep_on_empty_db_is_noop() {
        let store = make_store().await;
        let ForgettingResult {
            downscaled,
            replayed,
            pruned,
        } = run_forgetting_sweep(&store, &default_config())
            .await
            .expect("sweep");
        assert_eq!((downscaled, replayed, pruned), (0, 0, 0));
    }

    #[tokio::test]
    async fn downscaling_reduces_importance_score() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let message = store
            .save_message(conversation, "user", "hello world")
            .await
            .expect("save_message");
        store
            .set_importance_score(message, 0.8)
            .await
            .expect("set score");

        run_forgetting_sweep(&store, &isolated_config())
            .await
            .expect("sweep");

        let importance = store
            .get_importance_score(message)
            .await
            .expect("get score")
            .expect("score exists");
        // 0.8 * (1 - 0.1) = 0.72, consistent with one multiplicative decay step.
        assert!(
            (importance - 0.72_f64).abs() < 1e-5,
            "expected ~0.72, got {importance}"
        );
    }

    #[tokio::test]
    async fn low_score_message_is_pruned() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let message = store
            .save_message(conversation, "user", "stale memory")
            .await
            .expect("save");
        store
            .set_importance_score(message, 0.04)
            .await
            .expect("set score");

        let config = ForgettingConfig {
            forgetting_floor: 0.05,
            ..isolated_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");

        assert_eq!(outcome.pruned, 1, "low-score message must be pruned");
    }

    #[tokio::test]
    async fn high_access_message_is_protected_from_pruning() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let message = store
            .save_message(conversation, "user", "frequently accessed")
            .await
            .expect("save");
        store
            .set_importance_score(message, 0.02)
            .await
            .expect("set score");
        // Reach exactly the protection threshold configured below.
        for _ in 0..3 {
            store
                .batch_increment_access_count(&[message])
                .await
                .expect("increment");
        }

        let config = ForgettingConfig {
            forgetting_floor: 0.05,
            protect_min_access_count: 3,
            ..isolated_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");

        assert_eq!(outcome.pruned, 0, "high-access message must be protected");
    }

    #[tokio::test]
    async fn recently_accessed_message_is_replayed() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let message = store
            .save_message(conversation, "user", "recently accessed memory")
            .await
            .expect("save");
        store
            .set_importance_score(message, 0.5)
            .await
            .expect("set score");
        store
            .batch_increment_access_count(&[message])
            .await
            .expect("access");

        // Only the time window is enabled; the access-count replay path stays at 999.
        let config = ForgettingConfig {
            replay_window_hours: 1,
            ..isolated_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");
        assert_eq!(
            outcome.replayed, 1,
            "recently accessed message must be replayed"
        );

        let importance = store
            .get_importance_score(message)
            .await
            .expect("get score")
            .expect("score exists");
        assert!(
            (importance - 0.5_f64).abs() < 1e-4,
            "replayed score must be restored to ~0.5, got {importance}"
        );
    }

    #[tokio::test]
    async fn consolidated_messages_are_not_downscaled() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let message = store
            .save_message(conversation, "user", "consolidated msg")
            .await
            .expect("save");
        store
            .set_importance_score(message, 0.8)
            .await
            .expect("set score");
        store
            .mark_messages_consolidated(&[message.0])
            .await
            .expect("mark consolidated");

        let outcome = run_forgetting_sweep(&store, &isolated_config())
            .await
            .expect("sweep");

        assert_eq!(outcome.downscaled, 0);
        assert_eq!(outcome.pruned, 0);
    }
}