Skip to main content

zeph_memory/
forgetting.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Forgetting sweep — `SleepGate` (#2397).
5//!
6//! Inspired by sleep-dependent memory consolidation: a background sweep periodically
7//! downscales all non-consolidated message importance scores (synaptic downscaling),
8//! restores recently-accessed messages (selective replay), then prunes messages whose
9//! scores fall below `forgetting_floor` (targeted forgetting).
10//!
11//! # Algorithm
12//!
13//! 1. **Synaptic downscaling** — multiply all active, non-consolidated importance scores
14//!    by `(1.0 - decay_rate)` in a single batch UPDATE.
15//! 2. **Selective replay** — undo the current sweep's decay for messages accessed within
16//!    `replay_window_hours` or with `access_count >= replay_min_access_count`.
17//! 3. **Targeted forgetting** — soft-delete messages below `forgetting_floor` that are
18//!    not protected by recent access or high access count.
19//!
20//! All three phases run inside a single `SQLite` transaction to prevent intermediate state
21//! from being visible to concurrent readers (WAL readers see the pre-transaction snapshot
22//! until commit).
23//!
24//! # Interaction with consolidation
25//!
26//! Forgetting only targets non-consolidated messages (`consolidated = 0`). Consolidation
27//! merge transactions re-check `deleted_at IS NULL` before writing, so messages deleted
28//! by forgetting are safely skipped during the next consolidation sweep.
29//!
30//! # No LLM calls
31//!
32//! Pure SQL arithmetic — no `*_provider` field needed.
33
34use std::sync::Arc;
35use std::time::Duration;
36
37use tokio::task::JoinHandle;
38use tokio_util::sync::CancellationToken;
39
40use crate::error::MemoryError;
41use crate::store::SqliteStore;
42
43pub use zeph_common::config::memory::ForgettingConfig;
44
45// ── Result ────────────────────────────────────────────────────────────────────
46
/// Outcome of a single forgetting sweep.
///
/// All counters are zero when the sweep was skipped because of an invalid configuration
/// (see [`run_forgetting_sweep`]) or when there was nothing to process.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ForgettingResult {
    /// Number of messages whose importance score was downscaled.
    pub downscaled: u32,
    /// Number of messages whose score was restored via selective replay.
    pub replayed: u32,
    /// Number of messages soft-deleted by targeted forgetting.
    pub pruned: u32,
}
57
58// ── Sweep loop ────────────────────────────────────────────────────────────────
59
60/// Start the background forgetting loop (`SleepGate`).
61///
62/// The loop runs every `config.sweep_interval_secs` seconds, independently of the
63/// consolidation loop. Both share the same `SqliteStore` without a lock because `SQLite`
64/// WAL mode handles concurrent writers safely — each sweep runs inside a single
65/// transaction, so consolidation merges always see either the pre-sweep or post-sweep
66/// state, never an intermediate state.
67///
68/// Database errors are logged but do not stop the loop.
69#[must_use]
70pub fn start_forgetting_loop(
71    store: Arc<SqliteStore>,
72    config: ForgettingConfig,
73    cancel: CancellationToken,
74) -> JoinHandle<()> {
75    tokio::spawn(async move {
76        if !config.enabled {
77            tracing::debug!("forgetting sweep disabled (forgetting.enabled = false)");
78            return;
79        }
80
81        let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
82        // Skip the first immediate tick to avoid running at startup.
83        ticker.tick().await;
84
85        loop {
86            tokio::select! {
87                () = cancel.cancelled() => {
88                    tracing::debug!("forgetting loop shutting down");
89                    return;
90                }
91                _ = ticker.tick() => {}
92            }
93
94            tracing::debug!("forgetting: starting sweep");
95            let start = std::time::Instant::now();
96
97            match run_forgetting_sweep(&store, &config).await {
98                Ok(r) => {
99                    tracing::info!(
100                        downscaled = r.downscaled,
101                        replayed = r.replayed,
102                        pruned = r.pruned,
103                        elapsed_ms = start.elapsed().as_millis(),
104                        "forgetting: sweep complete"
105                    );
106                }
107                Err(e) => {
108                    tracing::warn!(
109                        error = %e,
110                        elapsed_ms = start.elapsed().as_millis(),
111                        "forgetting: sweep failed, will retry"
112                    );
113                }
114            }
115        }
116    })
117}
118
119// ── Sweep implementation ──────────────────────────────────────────────────────
120
121/// Execute one full forgetting sweep (`SleepGate`).
122///
123/// All three phases run inside a single transaction to prevent intermediate state
124/// from being visible to concurrent readers.
125///
126/// Returns early (no-op) if `config` contains out-of-range values, logging a warning.
127/// Valid ranges:
128/// - `decay_rate` in (0.0, 1.0) exclusive
129/// - `forgetting_floor` in [0.0, 1.0) exclusive upper bound
130/// - `sweep_interval_secs >= 60`
131///
132/// # Errors
133///
134/// Returns an error if any database operation fails.
135pub async fn run_forgetting_sweep(
136    store: &SqliteStore,
137    config: &ForgettingConfig,
138) -> Result<ForgettingResult, MemoryError> {
139    if config.decay_rate <= 0.0 || config.decay_rate >= 1.0 {
140        tracing::warn!(
141            decay_rate = config.decay_rate,
142            "forgetting: decay_rate must be in (0.0, 1.0); skipping sweep"
143        );
144        return Ok(ForgettingResult::default());
145    }
146    if config.forgetting_floor < 0.0 || config.forgetting_floor >= 1.0 {
147        tracing::warn!(
148            forgetting_floor = config.forgetting_floor,
149            "forgetting: forgetting_floor must be in [0.0, 1.0); skipping sweep"
150        );
151        return Ok(ForgettingResult::default());
152    }
153    if config.sweep_interval_secs < 60 {
154        tracing::warn!(
155            sweep_interval_secs = config.sweep_interval_secs,
156            "forgetting: sweep_interval_secs must be >= 60; skipping sweep"
157        );
158        return Ok(ForgettingResult::default());
159    }
160    store.run_forgetting_sweep_tx(config).await
161}
162
163// ── Tests ─────────────────────────────────────────────────────────────────────
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::SqliteStore;
    use zeph_common::config::memory::ForgettingConfig;

    /// Fresh in-memory store; each test gets its own isolated database.
    async fn make_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.expect("SqliteStore::new")
    }

    /// Baseline valid configuration shared by the tests below.
    fn default_config() -> ForgettingConfig {
        ForgettingConfig {
            enabled: true,
            decay_rate: 0.1,
            forgetting_floor: 0.05,
            sweep_interval_secs: 7200,
            sweep_batch_size: 500,
            replay_window_hours: 24,
            replay_min_access_count: 3,
            protect_recent_hours: 24,
            protect_min_access_count: 3,
        }
    }

    /// `default_config` with recency/access-count protection and replay set to values
    /// that should never trigger. Only decay and the forgetting floor then decide the
    /// outcome. Individual tests override single fields on top of this.
    fn unguarded_config() -> ForgettingConfig {
        ForgettingConfig {
            decay_rate: 0.1,
            protect_recent_hours: 0,
            protect_min_access_count: 999,
            replay_window_hours: 0,
            replay_min_access_count: 999,
            ..default_config()
        }
    }

    #[tokio::test]
    async fn sweep_on_empty_db_is_noop() {
        let store = make_store().await;
        let outcome = run_forgetting_sweep(&store, &default_config())
            .await
            .expect("sweep");
        assert_eq!(
            (outcome.downscaled, outcome.replayed, outcome.pruned),
            (0, 0, 0)
        );
    }

    #[tokio::test]
    async fn downscaling_reduces_importance_score() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let msg = store
            .save_message(conversation, "user", "hello world")
            .await
            .expect("save_message");
        store.set_importance_score(msg, 0.8).await.expect("set score");

        // A very low floor keeps the decayed message from being pruned.
        let config = ForgettingConfig {
            forgetting_floor: 0.01,
            ..unguarded_config()
        };
        run_forgetting_sweep(&store, &config).await.expect("sweep");

        let importance = store
            .get_importance_score(msg)
            .await
            .expect("get score")
            .expect("score exists");
        // One round of decay: 0.8 * (1 - 0.1) = 0.72, compared with a float tolerance.
        let expected = 0.72_f64;
        assert!(
            (importance - expected).abs() < 1e-5,
            "expected ~0.72, got {importance}"
        );
    }

    #[tokio::test]
    async fn low_score_message_is_pruned() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let msg = store
            .save_message(conversation, "user", "stale memory")
            .await
            .expect("save");
        store.set_importance_score(msg, 0.04).await.expect("set score");

        let config = ForgettingConfig {
            forgetting_floor: 0.05,
            ..unguarded_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");
        assert_eq!(outcome.pruned, 1, "low-score message must be pruned");
    }

    #[tokio::test]
    async fn high_access_message_is_protected_from_pruning() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let msg = store
            .save_message(conversation, "user", "frequently accessed")
            .await
            .expect("save");
        store.set_importance_score(msg, 0.02).await.expect("set score");

        // Three accesses reach the protection threshold configured below.
        for _ in 0..3 {
            store
                .batch_increment_access_count(&[msg])
                .await
                .expect("increment");
        }

        let config = ForgettingConfig {
            forgetting_floor: 0.05,
            protect_min_access_count: 3,
            ..unguarded_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");
        assert_eq!(outcome.pruned, 0, "high-access message must be protected");
    }

    #[tokio::test]
    async fn recently_accessed_message_is_replayed() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let msg = store
            .save_message(conversation, "user", "recently accessed memory")
            .await
            .expect("save");
        store.set_importance_score(msg, 0.5).await.expect("set score");
        // A single access records a fresh last-access time.
        store
            .batch_increment_access_count(&[msg])
            .await
            .expect("access");

        let config = ForgettingConfig {
            forgetting_floor: 0.01,
            // A one-hour window covers the access recorded just above. The access-count
            // path stays disabled, so replay is triggered by recency alone.
            replay_window_hours: 1,
            ..unguarded_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");
        assert_eq!(
            outcome.replayed, 1,
            "recently accessed message must be replayed"
        );

        // Decay followed by replay cancels out: 0.5 * 0.9 / 0.9 = 0.5.
        let importance = store
            .get_importance_score(msg)
            .await
            .expect("get score")
            .expect("score exists");
        assert!(
            (importance - 0.5_f64).abs() < 1e-4,
            "replayed score must be restored to ~0.5, got {importance}"
        );
    }

    #[tokio::test]
    async fn consolidated_messages_are_not_downscaled() {
        let store = make_store().await;
        let conversation = store.create_conversation().await.expect("conversation");
        let msg = store
            .save_message(conversation, "user", "consolidated msg")
            .await
            .expect("save");
        store.set_importance_score(msg, 0.8).await.expect("set score");
        store
            .mark_messages_consolidated(&[msg.0])
            .await
            .expect("mark consolidated");

        let config = ForgettingConfig {
            forgetting_floor: 0.01,
            ..unguarded_config()
        };
        let outcome = run_forgetting_sweep(&store, &config).await.expect("sweep");
        // The sweep must skip consolidated messages in every phase.
        assert_eq!(outcome.downscaled, 0);
        assert_eq!(outcome.pruned, 0);
    }
}