1use std::sync::Arc;
29use std::time::Duration;
30
31use tokio_util::sync::CancellationToken;
32use zeph_llm::any::AnyProvider;
33use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
34
35pub use zeph_config::memory::OpticalForgettingConfig;
36
37use crate::error::MemoryError;
38use crate::store::SqliteStore;
39
/// How much of a message's original text is still stored.
///
/// Persisted as text in the `messages.content_fidelity` column using the exact strings
/// returned by [`ContentFidelity::as_str`]. The sweep moves a message one step at a time:
/// `Full` → `Compressed` → `SummaryOnly`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContentFidelity {
    /// Original message text, untouched.
    Full,
    /// An LLM-condensed version has been written to `compressed_content`.
    Compressed,
    /// `content` has been replaced by a one-sentence summary and `compressed_content` cleared.
    SummaryOnly,
}

impl ContentFidelity {
    /// Canonical column value for this fidelity level (the variant name verbatim).
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            ContentFidelity::Full => "Full",
            ContentFidelity::Compressed => "Compressed",
            ContentFidelity::SummaryOnly => "SummaryOnly",
        }
    }
}

impl std::fmt::Display for ContentFidelity {
    /// Writes the same text as [`ContentFidelity::as_str`], ignoring width/fill flags.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(Self::as_str(*self))
    }
}

impl std::str::FromStr for ContentFidelity {
    type Err = String;

    /// Parses a column value; matching is exact and case-sensitive.
    fn from_str(raw: &str) -> Result<Self, Self::Err> {
        [Self::Full, Self::Compressed, Self::SummaryOnly]
            .iter()
            .copied()
            .find(|variant| variant.as_str() == raw)
            .ok_or_else(|| format!("unknown content_fidelity: {raw}"))
    }
}
87
/// Per-sweep counters returned by [`run_optical_forgetting_sweep`].
///
/// Plain `Copy` data so callers can freely copy, log and compare sweep outcomes.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct OpticalForgettingResult {
    /// Messages moved from `Full` to `Compressed` during the sweep.
    pub compressed: u32,
    /// Messages moved from `Compressed` to `SummaryOnly` during the sweep.
    pub summarized: u32,
    /// Candidates left unchanged because their LLM step did not succeed; they remain
    /// eligible for a later sweep.
    pub skipped: u32,
}
100
101pub async fn start_optical_forgetting_loop(
110 store: Arc<SqliteStore>,
111 provider: AnyProvider,
112 config: OpticalForgettingConfig,
113 forgetting_floor: f32,
114 cancel: CancellationToken,
115) {
116 if !config.enabled {
117 tracing::debug!("optical forgetting disabled (optical_forgetting.enabled = false)");
118 return;
119 }
120
121 let provider = Arc::new(provider);
122 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
123 ticker.tick().await; loop {
126 tokio::select! {
127 () = cancel.cancelled() => {
128 tracing::debug!("optical forgetting loop shutting down");
129 return;
130 }
131 _ = ticker.tick() => {}
132 }
133
134 tracing::debug!("optical_forgetting: starting sweep");
135 let start = std::time::Instant::now();
136
137 match run_optical_forgetting_sweep(&store, &provider, &config, forgetting_floor).await {
138 Ok(r) => {
139 tracing::info!(
140 compressed = r.compressed,
141 summarized = r.summarized,
142 skipped = r.skipped,
143 elapsed_ms = start.elapsed().as_millis(),
144 "optical_forgetting: sweep complete"
145 );
146 }
147 Err(e) => {
148 tracing::warn!(
149 error = %e,
150 elapsed_ms = start.elapsed().as_millis(),
151 "optical_forgetting: sweep failed, will retry"
152 );
153 }
154 }
155 }
156}
157
158#[tracing::instrument(name = "memory.optical_forgetting", skip_all)]
172pub async fn run_optical_forgetting_sweep(
173 store: &SqliteStore,
174 provider: &Arc<AnyProvider>,
175 config: &OpticalForgettingConfig,
176 forgetting_floor: f32,
177) -> Result<OpticalForgettingResult, MemoryError> {
178 let mut result = OpticalForgettingResult::default();
179
180 let full_candidates = fetch_full_candidates(store, config, forgetting_floor).await?;
182 for (msg_id, content) in full_candidates {
183 match compress_content(provider, &content).await {
184 Ok(compressed) => {
185 store_compressed(store, msg_id, &compressed).await?;
186 result.compressed += 1;
187 tracing::debug!(msg_id, "optical_forgetting: Full → Compressed");
188 }
189 Err(e) => {
190 tracing::warn!(error = %e, msg_id, "optical_forgetting: compression failed, skipping");
191 result.skipped += 1;
192 }
193 }
194 }
195
196 let compressed_candidates =
198 fetch_compressed_candidates(store, config, forgetting_floor).await?;
199 for (msg_id, compressed_content) in compressed_candidates {
200 match summarize_content(provider, &compressed_content).await {
201 Ok(summary) => {
202 store_summary_only(store, msg_id, &summary).await?;
203 result.summarized += 1;
204 tracing::debug!(msg_id, "optical_forgetting: Compressed → SummaryOnly");
205 }
206 Err(e) => {
207 tracing::warn!(error = %e, msg_id, "optical_forgetting: summarization failed, skipping");
208 result.skipped += 1;
209 }
210 }
211 }
212
213 Ok(result)
214}
215
216async fn fetch_full_candidates(
222 store: &SqliteStore,
223 config: &OpticalForgettingConfig,
224 forgetting_floor: f32,
225) -> Result<Vec<(i64, String)>, MemoryError> {
226 let rows = sqlx::query_as::<_, (i64, String)>(
234 "SELECT id, content FROM messages
235 WHERE content_fidelity = 'Full'
236 AND deleted_at IS NULL
237 AND (importance_score IS NULL OR importance_score >= ?)
238 AND id <= (SELECT COALESCE(MAX(id), 0) - ? FROM messages)
239 ORDER BY id ASC
240 LIMIT ?",
241 )
242 .bind(forgetting_floor)
243 .bind(i64::from(config.compress_after_turns))
244 .bind(i64::try_from(config.sweep_batch_size).unwrap_or(i64::MAX))
245 .fetch_all(store.pool())
246 .await?;
247
248 Ok(rows)
249}
250
251async fn fetch_compressed_candidates(
253 store: &SqliteStore,
254 config: &OpticalForgettingConfig,
255 forgetting_floor: f32,
256) -> Result<Vec<(i64, String)>, MemoryError> {
257 let rows = sqlx::query_as::<_, (i64, Option<String>)>(
258 "SELECT id, compressed_content FROM messages
259 WHERE content_fidelity = 'Compressed'
260 AND deleted_at IS NULL
261 AND (importance_score IS NULL OR importance_score >= ?)
262 AND id <= (SELECT COALESCE(MAX(id), 0) - ? FROM messages)
263 ORDER BY id ASC
264 LIMIT ?",
265 )
266 .bind(forgetting_floor)
267 .bind(i64::from(config.summarize_after_turns))
268 .bind(i64::try_from(config.sweep_batch_size).unwrap_or(i64::MAX))
269 .fetch_all(store.pool())
270 .await?;
271
272 Ok(rows
273 .into_iter()
274 .filter_map(|(id, content)| content.map(|c| (id, c)))
275 .collect())
276}
277
278async fn store_compressed(
280 store: &SqliteStore,
281 msg_id: i64,
282 compressed: &str,
283) -> Result<(), MemoryError> {
284 sqlx::query(
285 "UPDATE messages
286 SET content_fidelity = 'Compressed', compressed_content = ?
287 WHERE id = ?",
288 )
289 .bind(compressed)
290 .bind(msg_id)
291 .execute(store.pool())
292 .await?;
293 Ok(())
294}
295
296async fn store_summary_only(
298 store: &SqliteStore,
299 msg_id: i64,
300 summary: &str,
301) -> Result<(), MemoryError> {
302 sqlx::query(
303 "UPDATE messages
304 SET content_fidelity = 'SummaryOnly', content = ?, compressed_content = NULL
305 WHERE id = ?",
306 )
307 .bind(summary)
308 .bind(msg_id)
309 .execute(store.pool())
310 .await?;
311 Ok(())
312}
313
314#[tracing::instrument(name = "memory.optical_forgetting.compress", skip_all, err)]
318async fn compress_content(
319 provider: &Arc<AnyProvider>,
320 content: &str,
321) -> Result<String, MemoryError> {
322 let cleaned = zeph_common::sanitize::strip_control_chars_preserve_whitespace(content);
323 let snippet = cleaned.chars().take(2000).collect::<String>();
324 let messages = vec![
325 Message {
326 role: Role::System,
327 content: "You compress conversation messages into concise summaries that preserve \
328 all key facts, decisions, and action items. Output only the summary text, \
329 no preamble."
330 .to_owned(),
331 parts: vec![],
332 metadata: MessageMetadata::default(),
333 },
334 Message {
335 role: Role::User,
336 content: format!("Compress this message:\n\n{snippet}"),
337 parts: vec![],
338 metadata: MessageMetadata::default(),
339 },
340 ];
341
342 let raw = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
343 .await
344 .map_err(|_| MemoryError::Timeout("optical_forgetting: compress timed out".into()))?
345 .map_err(MemoryError::Llm)?;
346
347 Ok(raw.trim().to_owned())
348}
349
350#[tracing::instrument(name = "memory.optical_forgetting.summarize", skip_all, err)]
352async fn summarize_content(
353 provider: &Arc<AnyProvider>,
354 content: &str,
355) -> Result<String, MemoryError> {
356 let cleaned = zeph_common::sanitize::strip_control_chars_preserve_whitespace(content);
357 let snippet = cleaned.chars().take(1000).collect::<String>();
358 let messages = vec![
359 Message {
360 role: Role::System,
361 content: "You distill summaries into single sentences that capture the essential \
362 fact or outcome. Output only the one-sentence summary, no preamble."
363 .to_owned(),
364 parts: vec![],
365 metadata: MessageMetadata::default(),
366 },
367 Message {
368 role: Role::User,
369 content: format!("Distill into one sentence:\n\n{snippet}"),
370 parts: vec![],
371 metadata: MessageMetadata::default(),
372 },
373 ];
374
375 let raw = tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages))
376 .await
377 .map_err(|_| MemoryError::Timeout("optical_forgetting: summarize timed out".into()))?
378 .map_err(MemoryError::Llm)?;
379
380 Ok(raw.trim().to_owned())
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::providers::ProviderName;

    // Every variant survives `as_str` -> `parse`, and `Display` prints the same text as `as_str`.
    #[test]
    fn content_fidelity_round_trip() {
        for fidelity in [
            ContentFidelity::Full,
            ContentFidelity::Compressed,
            ContentFidelity::SummaryOnly,
        ] {
            let s = fidelity.as_str();
            let parsed: ContentFidelity = s.parse().expect("should parse");
            assert_eq!(parsed, fidelity);
            assert_eq!(format!("{fidelity}"), s);
        }
    }

    // Strings other than the three canonical names must be rejected by `FromStr`.
    #[test]
    fn content_fidelity_unknown_string_errors() {
        assert!("unknown".parse::<ContentFidelity>().is_err());
    }

    // Pins the defaults of the re-exported `OpticalForgettingConfig` (defined in
    // `zeph_config`), so a change there is noticed here. Note the feature is off by default.
    #[test]
    fn optical_forgetting_config_defaults() {
        let cfg = OpticalForgettingConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.compress_after_turns, 100);
        assert_eq!(cfg.summarize_after_turns, 500);
        assert_eq!(cfg.sweep_interval_secs, 3600);
        assert_eq!(cfg.sweep_batch_size, 50);
    }

    // A fresh result starts with every counter at zero.
    #[test]
    fn optical_forgetting_result_default() {
        let r = OpticalForgettingResult::default();
        assert_eq!(r.compressed, 0);
        assert_eq!(r.summarized, 0);
        assert_eq!(r.skipped, 0);
    }

    // With a single saved message, the eligibility bound `MAX(id) - 100` (and `- 500` for
    // phase 2) lies below that message's id, so neither phase selects anything.
    // NOTE(review): assumes the in-memory store starts empty, so this is the only row.
    #[tokio::test]
    async fn sweep_skips_when_no_candidates_old_enough() {
        use std::sync::Arc;

        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::store::SqliteStore;

        let store = Arc::new(
            SqliteStore::new(":memory:")
                .await
                .expect("SqliteStore::new"),
        );
        let provider = Arc::new(AnyProvider::Mock(MockProvider::default()));

        let cid = store.create_conversation().await.expect("conversation");
        store
            .save_message(cid, "user", "hello")
            .await
            .expect("save_message");

        let config = OpticalForgettingConfig {
            enabled: true,
            compress_after_turns: 100,
            summarize_after_turns: 500,
            sweep_interval_secs: 3600,
            sweep_batch_size: 50,
            compress_provider: ProviderName::default(),
        };
        let result = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep");

        assert_eq!(
            result.compressed, 0,
            "no message should be compressed when not old enough"
        );
        assert_eq!(result.summarized, 0);
    }

    // `compress_after_turns: 0` makes every non-deleted `Full` message eligible
    // (`id <= MAX(id) - 0`), so at least one compression is expected.
    // NOTE(review): only one scripted mock response is supplied for two candidates. What the
    // mock returns on the second call is not visible here, hence `>= 1` rather than `== 2`.
    #[tokio::test]
    async fn sweep_compresses_eligible_full_message() {
        use std::sync::Arc;

        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::store::SqliteStore;

        let store = Arc::new(
            SqliteStore::new(":memory:")
                .await
                .expect("SqliteStore::new"),
        );
        let mock = MockProvider::with_responses(vec!["compressed summary".to_owned()]);
        let provider = Arc::new(AnyProvider::Mock(mock));

        let cid = store.create_conversation().await.expect("conversation");
        store
            .save_message(cid, "user", "first message")
            .await
            .expect("save_message 1");
        store
            .save_message(cid, "user", "second message")
            .await
            .expect("save_message 2");

        let config = OpticalForgettingConfig {
            enabled: true,
            compress_after_turns: 0,
            summarize_after_turns: 500,
            sweep_interval_secs: 3600,
            sweep_batch_size: 50,
            compress_provider: ProviderName::default(),
        };
        let result = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep");

        assert!(
            result.compressed >= 1,
            "at least one message must be compressed"
        );
    }

    // NOTE(review): `run_optical_forgetting_sweep` never reads `config.enabled` (only
    // `start_optical_forgetting_loop` does), and this store has no messages. The assertions
    // therefore hold whatever the flag is set to, so this test does not actually exercise
    // "disabled" behaviour.
    #[tokio::test]
    async fn sweep_disabled_returns_empty_result() {
        use std::sync::Arc;

        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::store::SqliteStore;

        let store = Arc::new(
            SqliteStore::new(":memory:")
                .await
                .expect("SqliteStore::new"),
        );
        let provider = Arc::new(AnyProvider::Mock(MockProvider::default()));
        let config = OpticalForgettingConfig {
            enabled: false,
            ..Default::default()
        };
        let result = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep with disabled config");
        assert_eq!(result.compressed, 0);
        assert_eq!(result.summarized, 0);
    }
}