1use std::time::Duration;
11
12use tokio::sync::mpsc;
13use tracing::Instrument as _;
14
15use crate::store::SqliteStore;
16use crate::store::retrieval_failures::RetrievalFailureRecord;
17
/// Maximum number of `char`s of `query_text` kept per record; longer text is truncated.
const QUERY_TEXT_MAX_CHARS: usize = 512;
/// Maximum number of `char`s of `error_context` kept per record; longer text is truncated.
const ERROR_CONTEXT_MAX_CHARS: usize = 256;
/// The retention purge runs once every this many non-empty batch flushes.
const CLEANUP_FLUSH_INTERVAL: u32 = 500;
22
23pub struct RetrievalFailureLogger {
34 tx: Option<mpsc::Sender<RetrievalFailureRecord>>,
36 handle: Option<tokio::task::JoinHandle<()>>,
37}
38
39impl RetrievalFailureLogger {
40 #[must_use]
46 pub fn new(
47 sqlite: SqliteStore,
48 channel_capacity: usize,
49 batch_size: usize,
50 flush_interval: Duration,
51 retention_days: u32,
52 ) -> Self {
53 let (tx, rx) = mpsc::channel(channel_capacity);
54 let handle = tokio::spawn(writer_task(
55 sqlite,
56 rx,
57 batch_size,
58 flush_interval,
59 retention_days,
60 ));
61 Self {
62 tx: Some(tx),
63 handle: Some(handle),
64 }
65 }
66
67 pub fn log(&self, mut record: RetrievalFailureRecord) {
73 let _span = tracing::debug_span!("memory.retrieval_failure.log").entered();
74 if record.query_text.chars().count() > QUERY_TEXT_MAX_CHARS {
75 record.query_text = record
76 .query_text
77 .chars()
78 .take(QUERY_TEXT_MAX_CHARS)
79 .collect();
80 }
81 if let Some(ref mut ctx) = record.error_context
82 && ctx.chars().count() > ERROR_CONTEXT_MAX_CHARS
83 {
84 *ctx = ctx.chars().take(ERROR_CONTEXT_MAX_CHARS).collect();
85 }
86 if let Some(tx) = &self.tx
87 && tx.try_send(record).is_err()
88 {
89 tracing::debug!("retrieval_failure_logger: channel full, dropping record");
90 }
91 }
92
93 pub async fn shutdown(mut self) {
98 drop(self.tx.take());
99 if let Some(handle) = self.handle.take() {
100 let _ = handle.await;
101 }
102 }
103}
104
105impl Drop for RetrievalFailureLogger {
106 fn drop(&mut self) {
111 if let Some(handle) = &self.handle {
112 handle.abort();
113 }
114 }
115}
116
117async fn writer_task(
118 sqlite: SqliteStore,
119 mut rx: mpsc::Receiver<RetrievalFailureRecord>,
120 batch_size: usize,
121 flush_interval: Duration,
122 retention_days: u32,
123) {
124 let mut batch: Vec<RetrievalFailureRecord> = Vec::with_capacity(batch_size);
125 let mut flush_counter: u32 = 0;
126
127 loop {
128 let deadline = tokio::time::sleep(flush_interval);
130 tokio::pin!(deadline);
131
132 loop {
133 tokio::select! {
134 biased;
135 msg = rx.recv() => {
136 if let Some(record) = msg {
137 batch.push(record);
138 if batch.len() >= batch_size {
139 break;
140 }
141 } else {
142 flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
144 return;
145 }
146 }
147 () = &mut deadline => break,
148 }
149 }
150
151 flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
152 }
153}
154
155async fn flush_batch(
156 sqlite: &SqliteStore,
157 batch: &mut Vec<RetrievalFailureRecord>,
158 flush_counter: &mut u32,
159 retention_days: u32,
160) {
161 if batch.is_empty() {
162 return;
163 }
164 let count = batch.len();
165 tracing::debug!(count, "retrieval_failure_logger: flushing batch");
166 let span = tracing::info_span!("memory.retrieval_failure.flush", count);
167 let result = sqlite
168 .record_retrieval_failures_batch(batch)
169 .instrument(span)
170 .await;
171 if let Err(e) = result {
172 tracing::warn!("retrieval_failure_logger: batch write failed: {e:#}");
173 }
174 batch.clear();
175
176 *flush_counter = flush_counter.wrapping_add(1);
177 if (*flush_counter).is_multiple_of(CLEANUP_FLUSH_INTERVAL)
178 && let Err(e) = sqlite.purge_old_retrieval_failures(retention_days).await
179 {
180 tracing::debug!("retrieval_failure_logger: cleanup failed: {e:#}");
181 }
182}
183
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::store::SqliteStore;
    use crate::store::retrieval_failures::{RetrievalFailureRecord, RetrievalFailureType};

    fn no_hit_record() -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            conversation_id: None,
            turn_index: 0,
            failure_type: RetrievalFailureType::NoHit,
            retrieval_strategy: "semantic".into(),
            query_text: "hello world".into(),
            query_len: 11,
            top_score: None,
            confidence_threshold: None,
            result_count: 0,
            latency_ms: 5,
            edge_types: None,
            error_context: None,
        }
    }

    fn low_confidence_record(score: f32, threshold: f32) -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            conversation_id: None,
            turn_index: 0,
            failure_type: RetrievalFailureType::LowConfidence,
            retrieval_strategy: "semantic".into(),
            query_text: "low confidence query".into(),
            query_len: 20,
            top_score: Some(score),
            confidence_threshold: Some(threshold),
            result_count: 3,
            latency_ms: 10,
            edge_types: None,
            error_context: None,
        }
    }

    #[tokio::test]
    async fn no_hit_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(no_hit_record());
        logger.shutdown().await;

        let rows: Vec<(String,)> = sqlx::query_as(
            "SELECT failure_type FROM memory_retrieval_failures WHERE failure_type = 'no_hit'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(rows.len(), 1, "no_hit record must be persisted");
    }

    #[tokio::test]
    async fn low_confidence_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(low_confidence_record(0.3, 0.7));
        logger.shutdown().await;

        let rows: Vec<(String, f32, f32)> = sqlx::query_as(
            "SELECT failure_type, top_score, confidence_threshold \
             FROM memory_retrieval_failures WHERE failure_type = 'low_confidence'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(rows.len(), 1, "low_confidence record must be persisted");
        let (_, top_score, threshold) = &rows[0];
        assert!((*top_score - 0.3_f32).abs() < 1e-5, "top_score must match");
        assert!(
            (*threshold - 0.7_f32).abs() < 1e-5,
            "confidence_threshold must match"
        );
    }

    #[tokio::test]
    async fn log_does_not_block_when_channel_is_full() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger = RetrievalFailureLogger::new(sqlite.clone(), 1, 16, Duration::from_mins(1), 90);
        // Presumably fills the single channel slot: on the default current-thread
        // test runtime the writer task has not run yet.
        logger.log(no_hit_record());

        let start = std::time::Instant::now();
        // This send has no room and must be dropped rather than waited on.
        logger.log(no_hit_record());
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(100),
            "log() must be non-blocking even when channel is full, elapsed={elapsed:?}"
        );
        logger.shutdown().await;
    }

    #[tokio::test]
    async fn query_text_truncated_to_512_chars() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        let long_query = "x".repeat(1000);
        let mut record = no_hit_record();
        record.query_text = long_query;
        record.query_len = 1000;
        logger.log(record);
        logger.shutdown().await;

        let rows: Vec<(String,)> =
            sqlx::query_as("SELECT query_text FROM memory_retrieval_failures")
                .fetch_all(sqlite.pool())
                .await
                .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0].0.chars().count(),
            512,
            "query_text must be truncated to 512 chars"
        );
    }

    #[tokio::test]
    async fn query_text_truncation_respects_multibyte_chars() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        // 'é' is two bytes in UTF-8: a byte-based cut would keep too few chars
        // or split a character.
        let mut record = no_hit_record();
        record.query_text = "é".repeat(600);
        record.query_len = 600;
        logger.log(record);
        logger.shutdown().await;

        let rows: Vec<(String,)> =
            sqlx::query_as("SELECT query_text FROM memory_retrieval_failures")
                .fetch_all(sqlite.pool())
                .await
                .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0].0,
            "é".repeat(QUERY_TEXT_MAX_CHARS),
            "multi-byte query_text must be truncated by chars, not bytes"
        );
    }

    #[tokio::test]
    async fn logger_disabled_when_option_is_none() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger: Option<RetrievalFailureLogger> = None;
        if let Some(l) = &logger {
            l.log(no_hit_record());
        }
        let rows: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(
            rows[0].0, 0,
            "no records must be written when logger is None"
        );
    }

    #[tokio::test]
    async fn multiple_records_batch_flushed() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        for _ in 0..5 {
            logger.log(no_hit_record());
        }
        logger.log(low_confidence_record(0.2, 0.8));
        logger.shutdown().await;

        let rows: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(rows[0].0, 6, "all 6 records must be persisted in batch");
    }
}