// zeph_memory/retrieval_failure_logger.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Async fire-and-forget logger for memory retrieval failure events.
//!
//! [`RetrievalFailureLogger`] owns a bounded mpsc sender. Callers invoke
//! [`RetrievalFailureLogger::log`] on the hot path without blocking. A
//! background task coalesces records into batches and flushes them to `SQLite`.
10use std::time::Duration;
11
12use tokio::sync::mpsc;
13use tracing::Instrument as _;
14
15use crate::store::SqliteStore;
16use crate::store::retrieval_failures::RetrievalFailureRecord;
17
/// Maximum Unicode chars of `query_text` retained per record (bounds in-channel memory, INV-3).
const QUERY_TEXT_MAX_CHARS: usize = 512;
/// Maximum Unicode chars of `error_context` retained per record (bounds in-channel memory, INV-3).
const ERROR_CONTEXT_MAX_CHARS: usize = 256;
/// How often to check for a cleanup opportunity (every N flushes).
const CLEANUP_FLUSH_INTERVAL: u32 = 500;
22
/// Async background writer that batches retrieval failure records to `SQLite`.
///
/// Construct with [`RetrievalFailureLogger::new`] and call [`RetrievalFailureLogger::log`]
/// from the recall hot path. Records are sent via a bounded channel; if the channel is
/// full the record is silently dropped (zero hot-path latency, per INV-1).
///
/// Fields are `Option` so that [`shutdown`](Self::shutdown) can take them without a
/// move-out-of-`Drop` conflict, and so the `Drop` impl can abort any task not yet drained.
/// `tx` is declared before `handle` to ensure the channel is closed before the handle is
/// dropped, which allows the background task to exit cleanly when `Drop` fires.
pub struct RetrievalFailureLogger {
    // tx MUST be declared before handle — drop order closes the channel before the handle.
    tx: Option<mpsc::Sender<RetrievalFailureRecord>>,
    // Join handle for the spawned writer task; awaited by `shutdown`, aborted by `Drop`.
    handle: Option<tokio::task::JoinHandle<()>>,
}
38
39impl RetrievalFailureLogger {
40    /// Spawn the background writer task and return a logger handle.
41    ///
42    /// `batch_size` records are flushed at once, or after `flush_interval` elapses,
43    /// whichever comes first. Old records are purged every `CLEANUP_FLUSH_INTERVAL`
44    /// batch flushes according to `retention_days`.
45    #[must_use]
46    pub fn new(
47        sqlite: SqliteStore,
48        channel_capacity: usize,
49        batch_size: usize,
50        flush_interval: Duration,
51        retention_days: u32,
52    ) -> Self {
53        let (tx, rx) = mpsc::channel(channel_capacity);
54        let handle = tokio::spawn(writer_task(
55            sqlite,
56            rx,
57            batch_size,
58            flush_interval,
59            retention_days,
60        ));
61        Self {
62            tx: Some(tx),
63            handle: Some(handle),
64        }
65    }
66
67    /// Queue a retrieval failure record for async persistence.
68    ///
69    /// Both `query_text` (512 chars) and `error_context` (256 chars) are truncated
70    /// before enqueueing to bound in-channel memory usage (INV-3). If the channel is
71    /// full the record is dropped and a debug message is emitted (INV-1).
72    pub fn log(&self, mut record: RetrievalFailureRecord) {
73        let _span = tracing::debug_span!("memory.retrieval_failure.log").entered();
74        if record.query_text.chars().count() > QUERY_TEXT_MAX_CHARS {
75            record.query_text = record
76                .query_text
77                .chars()
78                .take(QUERY_TEXT_MAX_CHARS)
79                .collect();
80        }
81        if let Some(ref mut ctx) = record.error_context
82            && ctx.chars().count() > ERROR_CONTEXT_MAX_CHARS
83        {
84            *ctx = ctx.chars().take(ERROR_CONTEXT_MAX_CHARS).collect();
85        }
86        if let Some(tx) = &self.tx
87            && tx.try_send(record).is_err()
88        {
89            tracing::debug!("retrieval_failure_logger: channel full, dropping record");
90        }
91    }
92
93    /// Shut down the background writer, draining any queued records.
94    ///
95    /// Closes the sender and waits for the background task to complete. Drop is
96    /// best-effort only; call this method for a clean drain on process exit.
97    pub async fn shutdown(mut self) {
98        drop(self.tx.take());
99        if let Some(handle) = self.handle.take() {
100            let _ = handle.await;
101        }
102    }
103}
104
105impl Drop for RetrievalFailureLogger {
106    /// Abort the background writer task on drop.
107    ///
108    /// For a clean drain (flushing queued records) call [`RetrievalFailureLogger::shutdown`]
109    /// explicitly before dropping. This impl ensures the task is not silently detached.
110    fn drop(&mut self) {
111        if let Some(handle) = &self.handle {
112            handle.abort();
113        }
114    }
115}
116
/// Background task: receive records from `rx` and persist them in batches.
///
/// Runs until the sender side of `rx` is dropped. Each outer-loop iteration arms a
/// fresh flush deadline; the inner loop accumulates records until either the batch
/// reaches `batch_size` or the deadline fires, then flushes. When the channel
/// closes, any remaining records are flushed and the task returns.
async fn writer_task(
    sqlite: SqliteStore,
    mut rx: mpsc::Receiver<RetrievalFailureRecord>,
    batch_size: usize,
    flush_interval: Duration,
    retention_days: u32,
) {
    let mut batch: Vec<RetrievalFailureRecord> = Vec::with_capacity(batch_size);
    // Counts completed flushes; cleanup runs every CLEANUP_FLUSH_INTERVAL of them.
    let mut flush_counter: u32 = 0;

    loop {
        // Collect up to `batch_size` records or until the flush interval elapses.
        let deadline = tokio::time::sleep(flush_interval);
        tokio::pin!(deadline);

        loop {
            tokio::select! {
                // `biased` — always poll the channel before the timer, so a busy
                // channel is drained eagerly and batches fill as fast as possible.
                biased;
                msg = rx.recv() => {
                    if let Some(record) = msg {
                        batch.push(record);
                        if batch.len() >= batch_size {
                            // Batch full — flush immediately, don't wait for the timer.
                            break;
                        }
                    } else {
                        // Sender dropped — drain remaining and exit.
                        flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
                        return;
                    }
                }
                // Timer fired — flush whatever accumulated (possibly nothing).
                () = &mut deadline => break,
            }
        }

        flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
    }
}
154
155async fn flush_batch(
156    sqlite: &SqliteStore,
157    batch: &mut Vec<RetrievalFailureRecord>,
158    flush_counter: &mut u32,
159    retention_days: u32,
160) {
161    if batch.is_empty() {
162        return;
163    }
164    let count = batch.len();
165    tracing::debug!(count, "retrieval_failure_logger: flushing batch");
166    let span = tracing::info_span!("memory.retrieval_failure.flush", count);
167    let result = sqlite
168        .record_retrieval_failures_batch(batch)
169        .instrument(span)
170        .await;
171    if let Err(e) = result {
172        tracing::warn!("retrieval_failure_logger: batch write failed: {e:#}");
173    }
174    batch.clear();
175
176    *flush_counter = flush_counter.wrapping_add(1);
177    if (*flush_counter).is_multiple_of(CLEANUP_FLUSH_INTERVAL)
178        && let Err(e) = sqlite.purge_old_retrieval_failures(retention_days).await
179    {
180        tracing::debug!("retrieval_failure_logger: cleanup failed: {e:#}");
181    }
182}
183
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::store::SqliteStore;
    use crate::store::retrieval_failures::{RetrievalFailureRecord, RetrievalFailureType};

    /// A minimal `no_hit` failure for the semantic strategy.
    fn no_hit_record() -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            conversation_id: None,
            turn_index: 0,
            failure_type: RetrievalFailureType::NoHit,
            retrieval_strategy: String::from("semantic"),
            query_text: String::from("hello world"),
            query_len: 11,
            top_score: None,
            confidence_threshold: None,
            result_count: 0,
            latency_ms: 5,
            edge_types: None,
            error_context: None,
        }
    }

    /// A `low_confidence` failure carrying the given score/threshold pair.
    fn low_confidence_record(score: f32, threshold: f32) -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            failure_type: RetrievalFailureType::LowConfidence,
            retrieval_strategy: String::from("semantic"),
            query_text: String::from("low confidence query"),
            query_len: 20,
            top_score: Some(score),
            confidence_threshold: Some(threshold),
            result_count: 3,
            latency_ms: 10,
            ..no_hit_record()
        }
    }

    #[tokio::test]
    async fn no_hit_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(no_hit_record());
        logger.shutdown().await;

        let persisted: Vec<(String,)> = sqlx::query_as(
            "SELECT failure_type FROM memory_retrieval_failures WHERE failure_type = 'no_hit'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(persisted.len(), 1, "no_hit record must be persisted");
    }

    #[tokio::test]
    async fn low_confidence_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(low_confidence_record(0.3, 0.7));
        logger.shutdown().await;

        let persisted: Vec<(String, f32, f32)> = sqlx::query_as(
            "SELECT failure_type, top_score, confidence_threshold \
             FROM memory_retrieval_failures WHERE failure_type = 'low_confidence'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(persisted.len(), 1, "low_confidence record must be persisted");
        let (_, top_score, threshold) = &persisted[0];
        assert!((*top_score - 0.3_f32).abs() < 1e-5, "top_score must match");
        assert!(
            (*threshold - 0.7_f32).abs() < 1e-5,
            "confidence_threshold must match"
        );
    }

    #[tokio::test]
    async fn log_does_not_block_when_channel_is_full() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // capacity = 1 so the second send will be dropped
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 1, 16, Duration::from_secs(60), 90);
        // First log fills the channel (capacity 1).
        logger.log(no_hit_record());
        // Second log must not block — try_send drops the record silently.
        let started = std::time::Instant::now();
        logger.log(no_hit_record());
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_millis(100),
            "log() must be non-blocking even when channel is full, elapsed={elapsed:?}"
        );
        logger.shutdown().await;
    }

    #[tokio::test]
    async fn query_text_truncated_to_512_chars() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        // 1000-char query must be cut down to QUERY_TEXT_MAX_CHARS before storage.
        let oversized = RetrievalFailureRecord {
            query_text: "x".repeat(1000),
            query_len: 1000,
            ..no_hit_record()
        };
        logger.log(oversized);
        logger.shutdown().await;

        let persisted: Vec<(String,)> =
            sqlx::query_as("SELECT query_text FROM memory_retrieval_failures")
                .fetch_all(sqlite.pool())
                .await
                .unwrap();
        assert_eq!(persisted.len(), 1);
        assert_eq!(
            persisted[0].0.chars().count(),
            512,
            "query_text must be truncated to 512 chars"
        );
    }

    #[tokio::test]
    async fn logger_disabled_when_option_is_none() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // No logger constructed — simulate the disabled path via Option<RetrievalFailureLogger>.
        let logger: Option<RetrievalFailureLogger> = None;
        if let Some(active) = logger.as_ref() {
            active.log(no_hit_record());
        }
        // Nothing written to the store.
        let counts: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(
            counts[0].0, 0,
            "no records must be written when logger is None"
        );
    }

    #[tokio::test]
    async fn multiple_records_batch_flushed() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        // Five no-hit records plus one low-confidence record — all must land in one batch.
        for _ in 0..5 {
            logger.log(no_hit_record());
        }
        logger.log(low_confidence_record(0.2, 0.8));
        logger.shutdown().await;

        let counts: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(counts[0].0, 6, "all 6 records must be persisted in batch");
    }
}
346}