1use std::sync::Arc;
29use std::time::Duration;
30
31use tokio_util::sync::CancellationToken;
32use zeph_llm::any::AnyProvider;
33use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
34
35pub use zeph_config::memory::OpticalForgettingConfig;
36
37use crate::error::MemoryError;
38use crate::store::SqliteStore;
39
/// Fidelity level recorded for a stored message in the `content_fidelity` column.
///
/// The sweep in this module moves a message `Full` → `Compressed` → `SummaryOnly`.
/// The string form of each variant (see [`ContentFidelity::as_str`]) is the variant
/// name, which is also the literal the sweep queries compare against.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ContentFidelity {
    /// The message has not been processed by the sweep; `content` is what was stored.
    Full,
    /// A compressed rendition has been written to `compressed_content`; `content` is
    /// left as it was.
    Compressed,
    /// `content` has been replaced by a summary and `compressed_content` set to NULL.
    SummaryOnly,
}
63
64impl ContentFidelity {
65 #[must_use]
67 pub fn as_str(self) -> &'static str {
68 match self {
69 Self::Full => "Full",
70 Self::Compressed => "Compressed",
71 Self::SummaryOnly => "SummaryOnly",
72 }
73 }
74}
75
76impl std::fmt::Display for ContentFidelity {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.write_str(self.as_str())
79 }
80}
81
82impl std::str::FromStr for ContentFidelity {
83 type Err = String;
84 fn from_str(s: &str) -> Result<Self, Self::Err> {
85 match s {
86 "Full" => Ok(Self::Full),
87 "Compressed" => Ok(Self::Compressed),
88 "SummaryOnly" => Ok(Self::SummaryOnly),
89 other => Err(format!("unknown content_fidelity: {other}")),
90 }
91 }
92}
93
/// Counters describing what a single optical-forgetting sweep did.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Default)]
pub struct OpticalForgettingResult {
    /// Messages moved from `Full` to `Compressed` during the sweep.
    pub compressed: u32,
    /// Messages moved from `Compressed` to `SummaryOnly` during the sweep.
    pub summarized: u32,
    /// Candidates left unchanged because compression or summarization did not succeed.
    pub skipped: u32,
}
106
107pub async fn start_optical_forgetting_loop(
116 store: Arc<SqliteStore>,
117 provider: AnyProvider,
118 config: OpticalForgettingConfig,
119 forgetting_floor: f32,
120 cancel: CancellationToken,
121) {
122 if !config.enabled {
123 tracing::debug!("optical forgetting disabled (optical_forgetting.enabled = false)");
124 return;
125 }
126
127 let provider = Arc::new(provider);
128 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
129 ticker.tick().await; loop {
132 tokio::select! {
133 () = cancel.cancelled() => {
134 tracing::debug!("optical forgetting loop shutting down");
135 return;
136 }
137 _ = ticker.tick() => {}
138 }
139
140 tracing::debug!("optical_forgetting: starting sweep");
141 let start = std::time::Instant::now();
142
143 match run_optical_forgetting_sweep(&store, &provider, &config, forgetting_floor).await {
144 Ok(r) => {
145 tracing::info!(
146 compressed = r.compressed,
147 summarized = r.summarized,
148 skipped = r.skipped,
149 elapsed_ms = start.elapsed().as_millis(),
150 "optical_forgetting: sweep complete"
151 );
152 }
153 Err(e) => {
154 tracing::warn!(
155 error = %e,
156 elapsed_ms = start.elapsed().as_millis(),
157 "optical_forgetting: sweep failed, will retry"
158 );
159 }
160 }
161 }
162}
163
164#[tracing::instrument(name = "memory.optical_forgetting", skip_all)]
178pub async fn run_optical_forgetting_sweep(
179 store: &SqliteStore,
180 provider: &Arc<AnyProvider>,
181 config: &OpticalForgettingConfig,
182 forgetting_floor: f32,
183) -> Result<OpticalForgettingResult, MemoryError> {
184 let mut result = OpticalForgettingResult::default();
185
186 let full_candidates = fetch_full_candidates(store, config, forgetting_floor).await?;
188 for (msg_id, content) in full_candidates {
189 match compress_content(provider, &content).await {
190 Ok(compressed) => {
191 store_compressed(store, msg_id, &compressed).await?;
192 result.compressed += 1;
193 tracing::debug!(msg_id, "optical_forgetting: Full → Compressed");
194 }
195 Err(e) => {
196 tracing::warn!(error = %e, msg_id, "optical_forgetting: compression failed, skipping");
197 result.skipped += 1;
198 }
199 }
200 }
201
202 let compressed_candidates =
204 fetch_compressed_candidates(store, config, forgetting_floor).await?;
205 for (msg_id, compressed_content) in compressed_candidates {
206 match summarize_content(provider, &compressed_content).await {
207 Ok(summary) => {
208 store_summary_only(store, msg_id, &summary).await?;
209 result.summarized += 1;
210 tracing::debug!(msg_id, "optical_forgetting: Compressed → SummaryOnly");
211 }
212 Err(e) => {
213 tracing::warn!(error = %e, msg_id, "optical_forgetting: summarization failed, skipping");
214 result.skipped += 1;
215 }
216 }
217 }
218
219 Ok(result)
220}
221
222async fn fetch_full_candidates(
228 store: &SqliteStore,
229 config: &OpticalForgettingConfig,
230 forgetting_floor: f32,
231) -> Result<Vec<(i64, String)>, MemoryError> {
232 let rows = sqlx::query_as::<_, (i64, String)>(
240 "SELECT id, content FROM messages
241 WHERE content_fidelity = 'Full'
242 AND deleted_at IS NULL
243 AND (importance_score IS NULL OR importance_score >= ?)
244 AND id <= (SELECT COALESCE(MAX(id), 0) - ? FROM messages)
245 ORDER BY id ASC
246 LIMIT ?",
247 )
248 .bind(forgetting_floor)
249 .bind(i64::from(config.compress_after_turns))
250 .bind(i64::try_from(config.sweep_batch_size).unwrap_or(i64::MAX))
251 .fetch_all(store.pool())
252 .await?;
253
254 Ok(rows)
255}
256
257async fn fetch_compressed_candidates(
259 store: &SqliteStore,
260 config: &OpticalForgettingConfig,
261 forgetting_floor: f32,
262) -> Result<Vec<(i64, String)>, MemoryError> {
263 let rows = sqlx::query_as::<_, (i64, Option<String>)>(
264 "SELECT id, compressed_content FROM messages
265 WHERE content_fidelity = 'Compressed'
266 AND deleted_at IS NULL
267 AND (importance_score IS NULL OR importance_score >= ?)
268 AND id <= (SELECT COALESCE(MAX(id), 0) - ? FROM messages)
269 ORDER BY id ASC
270 LIMIT ?",
271 )
272 .bind(forgetting_floor)
273 .bind(i64::from(config.summarize_after_turns))
274 .bind(i64::try_from(config.sweep_batch_size).unwrap_or(i64::MAX))
275 .fetch_all(store.pool())
276 .await?;
277
278 Ok(rows
279 .into_iter()
280 .filter_map(|(id, content)| content.map(|c| (id, c)))
281 .collect())
282}
283
284async fn store_compressed(
286 store: &SqliteStore,
287 msg_id: i64,
288 compressed: &str,
289) -> Result<(), MemoryError> {
290 sqlx::query(
291 "UPDATE messages
292 SET content_fidelity = 'Compressed', compressed_content = ?
293 WHERE id = ?",
294 )
295 .bind(compressed)
296 .bind(msg_id)
297 .execute(store.pool())
298 .await?;
299 Ok(())
300}
301
302async fn store_summary_only(
304 store: &SqliteStore,
305 msg_id: i64,
306 summary: &str,
307) -> Result<(), MemoryError> {
308 sqlx::query(
309 "UPDATE messages
310 SET content_fidelity = 'SummaryOnly', content = ?, compressed_content = NULL
311 WHERE id = ?",
312 )
313 .bind(summary)
314 .bind(msg_id)
315 .execute(store.pool())
316 .await?;
317 Ok(())
318}
319
320#[tracing::instrument(name = "memory.optical_forgetting.compress", skip_all, err)]
324async fn compress_content(
325 provider: &Arc<AnyProvider>,
326 content: &str,
327) -> Result<String, MemoryError> {
328 let cleaned = zeph_common::sanitize::strip_control_chars_preserve_whitespace(content);
329 let snippet = cleaned.chars().take(2000).collect::<String>();
330 let messages = vec![
331 Message {
332 role: Role::System,
333 content: "You compress conversation messages into concise summaries that preserve \
334 all key facts, decisions, and action items. Output only the summary text, \
335 no preamble."
336 .to_owned(),
337 parts: vec![],
338 metadata: MessageMetadata::default(),
339 },
340 Message {
341 role: Role::User,
342 content: format!("Compress this message:\n\n{snippet}"),
343 parts: vec![],
344 metadata: MessageMetadata::default(),
345 },
346 ];
347
348 let raw = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
349 .await
350 .map_err(|_| MemoryError::Timeout("optical_forgetting: compress timed out".into()))?
351 .map_err(MemoryError::Llm)?;
352
353 Ok(raw.trim().to_owned())
354}
355
356#[tracing::instrument(name = "memory.optical_forgetting.summarize", skip_all, err)]
358async fn summarize_content(
359 provider: &Arc<AnyProvider>,
360 content: &str,
361) -> Result<String, MemoryError> {
362 let cleaned = zeph_common::sanitize::strip_control_chars_preserve_whitespace(content);
363 let snippet = cleaned.chars().take(1000).collect::<String>();
364 let messages = vec![
365 Message {
366 role: Role::System,
367 content: "You distill summaries into single sentences that capture the essential \
368 fact or outcome. Output only the one-sentence summary, no preamble."
369 .to_owned(),
370 parts: vec![],
371 metadata: MessageMetadata::default(),
372 },
373 Message {
374 role: Role::User,
375 content: format!("Distill into one sentence:\n\n{snippet}"),
376 parts: vec![],
377 metadata: MessageMetadata::default(),
378 },
379 ];
380
381 let raw = tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages))
382 .await
383 .map_err(|_| MemoryError::Timeout("optical_forgetting: summarize timed out".into()))?
384 .map_err(MemoryError::Llm)?;
385
386 Ok(raw.trim().to_owned())
387}
388
#[cfg(test)]
mod tests {
    use zeph_config::providers::ProviderName;
    use zeph_llm::mock::MockProvider;

    use super::*;

    /// Opens a fresh in-memory store for a single test.
    async fn in_memory_store() -> Arc<SqliteStore> {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        Arc::new(store)
    }

    /// Wraps a mock LLM in the shared provider handle the sweep takes.
    fn shared_provider(mock: MockProvider) -> Arc<AnyProvider> {
        Arc::new(AnyProvider::Mock(mock))
    }

    /// Every variant survives `as_str` → `parse` and prints the same text via `Display`.
    #[test]
    fn content_fidelity_round_trip() {
        let all = [
            ContentFidelity::Full,
            ContentFidelity::Compressed,
            ContentFidelity::SummaryOnly,
        ];
        for fidelity in all {
            let text = fidelity.as_str();
            let reparsed: ContentFidelity = text.parse().expect("should parse");
            assert_eq!(reparsed, fidelity);
            assert_eq!(format!("{fidelity}"), text);
        }
    }

    /// Text that is not a variant name is rejected.
    #[test]
    fn content_fidelity_unknown_string_errors() {
        let parsed = "unknown".parse::<ContentFidelity>();
        assert!(parsed.is_err());
    }

    /// Pins the default configuration values.
    #[test]
    fn optical_forgetting_config_defaults() {
        let defaults = OpticalForgettingConfig::default();
        assert!(!defaults.enabled);
        assert_eq!(defaults.compress_after_turns, 100);
        assert_eq!(defaults.summarize_after_turns, 500);
        assert_eq!(defaults.sweep_interval_secs, 3600);
        assert_eq!(defaults.sweep_batch_size, 50);
    }

    /// A default result has all counters at zero.
    #[test]
    fn optical_forgetting_result_default() {
        let counters = OpticalForgettingResult::default();
        assert_eq!(counters.compressed, 0);
        assert_eq!(counters.summarized, 0);
        assert_eq!(counters.skipped, 0);
    }

    /// A single message cannot be 100 ids behind the newest one, so nothing is swept.
    #[tokio::test]
    async fn sweep_skips_when_no_candidates_old_enough() {
        let store = in_memory_store().await;
        let provider = shared_provider(MockProvider::default());

        let cid = store.create_conversation().await.expect("conversation");
        store
            .save_message(cid, "user", "hello")
            .await
            .expect("save_message");

        let config = OpticalForgettingConfig {
            enabled: true,
            compress_after_turns: 100,
            summarize_after_turns: 500,
            sweep_interval_secs: 3600,
            sweep_batch_size: 50,
            compress_provider: ProviderName::default(),
        };

        let outcome = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep");

        assert_eq!(
            outcome.compressed, 0,
            "no message should be compressed when not old enough"
        );
        assert_eq!(outcome.summarized, 0);
    }

    /// With a zero age threshold, stored `Full` messages are eligible immediately.
    #[tokio::test]
    async fn sweep_compresses_eligible_full_message() {
        let store = in_memory_store().await;
        let provider = shared_provider(MockProvider::with_responses(vec![
            "compressed summary".to_owned(),
        ]));

        let cid = store.create_conversation().await.expect("conversation");
        store
            .save_message(cid, "user", "first message")
            .await
            .expect("save_message 1");
        store
            .save_message(cid, "user", "second message")
            .await
            .expect("save_message 2");

        let config = OpticalForgettingConfig {
            enabled: true,
            compress_after_turns: 0,
            summarize_after_turns: 500,
            sweep_interval_secs: 3600,
            sweep_batch_size: 50,
            compress_provider: ProviderName::default(),
        };

        let outcome = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep");

        assert!(
            outcome.compressed >= 1,
            "at least one message must be compressed"
        );
    }

    /// Sweeping an empty store with a disabled config reports no work.
    ///
    /// The sweep function never reads `enabled`; the counters are zero here because
    /// the store holds no messages.
    #[tokio::test]
    async fn sweep_disabled_returns_empty_result() {
        let store = in_memory_store().await;
        let provider = shared_provider(MockProvider::default());
        let config = OpticalForgettingConfig {
            enabled: false,
            ..Default::default()
        };

        let outcome = run_optical_forgetting_sweep(&store, &provider, &config, 0.0)
            .await
            .expect("sweep with disabled config");

        assert_eq!(outcome.compressed, 0);
        assert_eq!(outcome.summarized, 0);
    }
}