1use std::str::FromStr;
34use std::time::Duration;
35
36use serde::Deserialize;
37use tokio::time::timeout;
38use zeph_db::{ActiveDialect, DbPool, placeholder_list};
39use zeph_llm::any::AnyProvider;
40use zeph_llm::provider::{LlmProvider as _, Message, Role};
41
42use crate::error::MemoryError;
43use crate::vector_store::VectorStore;
44
/// Use-count threshold separating "cold" from "hot" strategies.
///
/// Rows whose `use_count` is at or below this value are eligible for the
/// cold-eviction pass in `delete_oldest_cold`.
const HOT_STRATEGY_USE_COUNT: i64 = 10;

/// Maximum number of ids bound into a single `IN (...)` query.
///
/// Longer id lists are split into chunks of this size by `mark_used` and
/// `fetch_by_ids`.
// NOTE(review): presumably chosen to stay below the backend's bind-parameter
// limit — confirm against the supported databases.
const MAX_IDS_PER_QUERY: usize = 490;

/// System prompt for the self-judge LLM call.
///
/// Asks the model to return a JSON object with the keys `success`,
/// `reasoning_chain` and `task_hint`, which is the shape deserialized into
/// `SelfJudgeOutcome`.
const SELF_JUDGE_SYSTEM: &str = "\
You are a task outcome evaluator. Given an agent turn transcript, analyze the conversation and determine:
1. Did the agent successfully complete the user's request? (true/false)
2. Extract the key reasoning steps the agent took (reasoning chain).
3. Summarize the task in one sentence (task hint).

Respond ONLY with valid JSON, no markdown fences, no prose:
{\"success\": bool, \"reasoning_chain\": \"string\", \"task_hint\": \"string\"}";

/// System prompt for the distillation LLM call.
///
/// Asks the model to compress a reasoning chain into a short, generalizable
/// strategy returned as plain text.
const DISTILL_SYSTEM: &str = "\
You are a strategy distiller. Given a reasoning chain from an agent turn, distill it into \
a short generalizable strategy (at most 3 sentences) that could help an agent facing a similar \
task. Focus on the transferable principle, not the specific instance. \
Respond with the strategy text only — no headers, no lists, no markdown.";
75
/// Result of an agent turn as judged by the self-judge step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The agent completed the user's request.
    Success,
    /// The agent did not complete the user's request.
    Failure,
}

impl Outcome {
    /// Returns the lowercase label for this outcome (`"success"` or `"failure"`).
    ///
    /// This is the textual form written to the `outcome` column and to the
    /// vector-store payload.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Success => "success",
            Self::Failure => "failure",
        }
    }
}
97
/// Error type associated with parsing an [`Outcome`] from a string.
///
/// Carries the unrecognized input text and renders as `unknown outcome: <text>`.
#[derive(Debug, thiserror::Error)]
#[error("unknown outcome: {0}")]
pub struct OutcomeParseError(String);
102
103impl FromStr for Outcome {
104 type Err = OutcomeParseError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 match s {
108 "success" => Ok(Outcome::Success),
109 "failure" => Ok(Outcome::Failure),
110 other => {
111 tracing::warn!(
112 value = other,
113 "reasoning: unknown outcome, defaulting to Failure"
114 );
115 Ok(Outcome::Failure)
116 }
117 }
118 }
119}
120
/// A distilled reasoning strategy as stored in the `reasoning_strategies` table.
///
/// The fields mirror the table columns one-to-one.
// NOTE(review): timestamps are presumably Unix epoch seconds (the test schema
// defaults them to `unixepoch('now')`) — confirm for other dialects.
#[derive(Debug, Clone)]
pub struct ReasoningStrategy {
    /// Primary key of the strategy row.
    pub id: String,
    /// Distilled strategy text.
    pub summary: String,
    /// Whether the turn this strategy came from succeeded or failed.
    pub outcome: Outcome,
    /// Short description of the task the strategy applies to.
    pub task_hint: String,
    /// Creation timestamp.
    pub created_at: i64,
    /// Timestamp of the most recent use; eviction orders rows by this value.
    pub last_used_at: i64,
    /// Number of times the strategy has been marked as used.
    pub use_count: i64,
    /// Timestamp of the last successful vector-store upsert, `None` if never embedded.
    pub embedded_at: Option<i64>,
}
146
/// Parsed JSON verdict returned by the self-judge LLM call.
///
/// Field names match the JSON keys requested in the self-judge system prompt.
#[derive(Debug, Deserialize)]
pub struct SelfJudgeOutcome {
    /// `true` when the judge decided the agent completed the request.
    pub success: bool,
    /// Key reasoning steps extracted from the turn.
    pub reasoning_chain: String,
    /// One-sentence summary of the task.
    pub task_hint: String,
}
161
/// Store for reasoning strategies: a SQL table plus an optional vector index.
pub struct ReasoningMemory {
    /// Database pool holding the `reasoning_strategies` table.
    pool: DbPool,
    /// Optional vector store used for similarity search; `None` means SQL-only operation.
    vector_store: Option<std::sync::Arc<dyn VectorStore>>,
}
174
/// Name of the vector-store collection that holds reasoning-strategy embeddings.
pub const REASONING_COLLECTION: &str = "reasoning_strategies";
177
178impl ReasoningMemory {
179 #[must_use]
194 pub fn new(pool: DbPool, vector_store: Option<std::sync::Arc<dyn VectorStore>>) -> Self {
195 Self { pool, vector_store }
196 }
197
198 #[tracing::instrument(name = "memory.reasoning.insert", skip(self, embedding), fields(id = %strategy.id))]
208 pub async fn insert(
209 &self,
210 strategy: &ReasoningStrategy,
211 embedding: Vec<f32>,
212 ) -> Result<(), MemoryError> {
213 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
214 let raw = format!(
215 "INSERT OR REPLACE INTO reasoning_strategies \
216 (id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at) \
217 VALUES (?, ?, ?, ?, {epoch_now}, {epoch_now}, 0, NULL)"
218 );
219 let sql = zeph_db::rewrite_placeholders(&raw);
220 zeph_db::query(&sql)
221 .bind(&strategy.id)
222 .bind(&strategy.summary)
223 .bind(strategy.outcome.as_str())
224 .bind(&strategy.task_hint)
225 .execute(&self.pool)
226 .await?;
227
228 if let Some(ref vs) = self.vector_store {
230 let point = crate::vector_store::VectorPoint {
231 id: strategy.id.clone(),
232 vector: embedding,
233 payload: std::collections::HashMap::from([
234 (
235 "outcome".to_owned(),
236 serde_json::Value::String(strategy.outcome.as_str().to_owned()),
237 ),
238 (
239 "task_hint".to_owned(),
240 serde_json::Value::String(strategy.task_hint.clone()),
241 ),
242 ]),
243 };
244 if let Err(e) = vs.upsert(REASONING_COLLECTION, vec![point]).await {
245 tracing::warn!(error = %e, id = %strategy.id, "reasoning: Qdrant upsert failed — SQLite-only mode");
246 } else {
247 let update_sql = zeph_db::rewrite_placeholders(&format!(
249 "UPDATE reasoning_strategies SET embedded_at = {epoch_now} WHERE id = ?"
250 ));
251 if let Err(e) = zeph_db::query(&update_sql)
252 .bind(&strategy.id)
253 .execute(&self.pool)
254 .await
255 {
256 tracing::warn!(error = %e, "reasoning: failed to set embedded_at");
257 }
258 }
259 }
260
261 tracing::debug!(id = %strategy.id, outcome = strategy.outcome.as_str(), "reasoning: strategy inserted");
262 Ok(())
263 }
264
265 #[tracing::instrument(
277 name = "memory.reasoning.retrieve_by_embedding",
278 skip(self, embedding),
279 fields(top_k)
280 )]
281 pub async fn retrieve_by_embedding(
282 &self,
283 embedding: &[f32],
284 top_k: u64,
285 ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
286 let Some(ref vs) = self.vector_store else {
287 return Ok(Vec::new());
288 };
289
290 let scored = vs
291 .search(REASONING_COLLECTION, embedding.to_vec(), top_k, None)
292 .await?;
293
294 if scored.is_empty() {
295 return Ok(Vec::new());
296 }
297
298 let ids: Vec<String> = scored.into_iter().map(|p| p.id).collect();
299 self.fetch_by_ids(&ids).await
300 }
301
302 #[tracing::instrument(name = "memory.reasoning.mark_used", skip(self), fields(n = ids.len()))]
312 pub async fn mark_used(&self, ids: &[String]) -> Result<(), MemoryError> {
313 if ids.is_empty() {
314 return Ok(());
315 }
316
317 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
318 for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
319 let ph = placeholder_list(1, chunk.len());
320 let sql = format!(
323 "UPDATE reasoning_strategies \
324 SET use_count = use_count + 1, last_used_at = {epoch_now} \
325 WHERE id IN ({ph})"
326 );
327 let mut q = zeph_db::query(&sql);
328 for id in chunk {
329 q = q.bind(id.as_str());
330 }
331 q.execute(&self.pool).await?;
332 }
333
334 Ok(())
335 }
336
337 #[tracing::instrument(name = "memory.reasoning.evict_lru", skip(self), fields(store_limit))]
353 pub async fn evict_lru(&self, store_limit: usize) -> Result<usize, MemoryError> {
354 let count = self.count().await?;
355 if count <= store_limit {
356 return Ok(0);
357 }
358
359 let over_by = count - store_limit;
360 let deleted_cold = self.delete_oldest_cold(over_by).await?;
361 if deleted_cold > 0 {
362 tracing::debug!(
364 deleted = deleted_cold,
365 count,
366 "reasoning: evicted cold strategies"
367 );
368 return Ok(deleted_cold);
369 }
370
371 let hard_ceiling = store_limit.saturating_mul(2);
373 if count <= hard_ceiling {
374 tracing::debug!(
375 count,
376 store_limit,
377 "reasoning: hot saturation — growth allowed under 2x ceiling"
378 );
379 return Ok(0);
380 }
381
382 let forced = count - store_limit;
384 let deleted_forced = self.delete_oldest_unconditional(forced).await?;
385 tracing::warn!(
386 deleted = deleted_forced,
387 count,
388 hard_ceiling,
389 "reasoning: hard-ceiling eviction — evicted hot strategies; consider raising store_limit"
390 );
391
392 Ok(deleted_forced)
393 }
394
395 pub async fn count(&self) -> Result<usize, MemoryError> {
401 let row: (i64,) = zeph_db::query_as("SELECT COUNT(*) FROM reasoning_strategies")
402 .fetch_one(&self.pool)
403 .await?;
404 Ok(usize::try_from(row.0.max(0)).unwrap_or(0))
405 }
406
407 pub(crate) async fn fetch_by_ids(
411 &self,
412 ids: &[String],
413 ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
414 if ids.is_empty() {
415 return Ok(Vec::new());
416 }
417
418 let mut strategies = Vec::with_capacity(ids.len());
419 for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
420 let ph = placeholder_list(1, chunk.len());
421 let sql = format!(
423 "SELECT id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at \
424 FROM reasoning_strategies WHERE id IN ({ph})"
425 );
426 let mut q = zeph_db::query_as::<
427 _,
428 (String, String, String, String, i64, i64, i64, Option<i64>),
429 >(&sql);
430 for id in chunk {
431 q = q.bind(id.as_str());
432 }
433 let rows = q.fetch_all(&self.pool).await?;
434 for (
435 id,
436 summary,
437 outcome_str,
438 task_hint,
439 created_at,
440 last_used_at,
441 use_count,
442 embedded_at,
443 ) in rows
444 {
445 let outcome = Outcome::from_str(&outcome_str).unwrap_or(Outcome::Failure);
446 strategies.push(ReasoningStrategy {
447 id,
448 summary,
449 outcome,
450 task_hint,
451 created_at,
452 last_used_at,
453 use_count,
454 embedded_at,
455 });
456 }
457 }
458
459 Ok(strategies)
460 }
461
462 async fn delete_oldest_cold(&self, n: usize) -> Result<usize, MemoryError> {
466 let limit = i64::try_from(n).unwrap_or(i64::MAX);
467 let raw = format!(
469 "DELETE FROM reasoning_strategies \
470 WHERE id IN ( \
471 SELECT id FROM reasoning_strategies \
472 WHERE use_count <= {HOT_STRATEGY_USE_COUNT} \
473 ORDER BY last_used_at ASC LIMIT ? \
474 )"
475 );
476 let sql = zeph_db::rewrite_placeholders(&raw);
477 let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
478 Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
479 }
480
481 async fn delete_oldest_unconditional(&self, n: usize) -> Result<usize, MemoryError> {
485 let limit = i64::try_from(n).unwrap_or(i64::MAX);
486 let raw = "DELETE FROM reasoning_strategies \
487 WHERE id IN ( \
488 SELECT id FROM reasoning_strategies \
489 ORDER BY last_used_at ASC LIMIT ? \
490 )";
491 let sql = zeph_db::rewrite_placeholders(raw);
492 let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
493 Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
494 }
495}
496
497#[tracing::instrument(name = "memory.reasoning.self_judge", skip(provider, messages), fields(n = messages.len()))]
522pub async fn run_self_judge(
523 provider: &AnyProvider,
524 messages: &[Message],
525 extraction_timeout: Duration,
526) -> Option<SelfJudgeOutcome> {
527 if messages.is_empty() {
528 return None;
529 }
530
531 let user_prompt = build_transcript_prompt(messages);
532
533 let llm_messages = [
534 Message::from_legacy(Role::System, SELF_JUDGE_SYSTEM),
535 Message::from_legacy(Role::User, user_prompt),
536 ];
537
538 let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
539 Ok(Ok(text)) => text,
540 Ok(Err(e)) => {
541 tracing::warn!(error = %e, "reasoning: self-judge LLM call failed");
542 return None;
543 }
544 Err(_) => {
545 tracing::warn!("reasoning: self-judge timed out");
546 return None;
547 }
548 };
549
550 parse_self_judge_response(&response)
551}
552
553#[tracing::instrument(name = "memory.reasoning.distill", skip(provider, reasoning_chain))]
573pub async fn distill_strategy(
574 provider: &AnyProvider,
575 outcome: Outcome,
576 reasoning_chain: &str,
577 distill_timeout: Duration,
578) -> Option<String> {
579 if reasoning_chain.is_empty() {
580 return None;
581 }
582
583 let user_prompt = format!(
584 "Outcome: {}\n\nReasoning chain:\n{reasoning_chain}",
585 outcome.as_str()
586 );
587
588 let llm_messages = [
589 Message::from_legacy(Role::System, DISTILL_SYSTEM),
590 Message::from_legacy(Role::User, user_prompt),
591 ];
592
593 let response = match timeout(distill_timeout, provider.chat(&llm_messages)).await {
594 Ok(Ok(text)) => text,
595 Ok(Err(e)) => {
596 tracing::warn!(error = %e, "reasoning: distillation LLM call failed");
597 return None;
598 }
599 Err(_) => {
600 tracing::warn!("reasoning: distillation timed out");
601 return None;
602 }
603 };
604
605 let trimmed = trim_to_three_sentences(&response);
606 if trimmed.is_empty() {
607 None
608 } else {
609 Some(trimmed)
610 }
611}
612
/// Tunables for [`process_turn`].
#[derive(Debug, Clone, Copy)]
pub struct ProcessTurnConfig {
    /// Target maximum number of stored strategies; passed to the LRU eviction.
    pub store_limit: usize,
    /// Deadline for the self-judge LLM call.
    pub extraction_timeout: Duration,
    /// Deadline for the distillation LLM call.
    pub distill_timeout: Duration,
    /// Number of most recent messages handed to the self-judge step.
    pub self_judge_window: usize,
    /// Minimum length of the latest assistant message in the window; shorter turns are skipped.
    pub min_assistant_chars: usize,
}
632
633#[tracing::instrument(name = "memory.reasoning.process_turn", skip_all)]
644pub async fn process_turn(
645 memory: &ReasoningMemory,
646 extract_provider: &AnyProvider,
647 distill_provider: &AnyProvider,
648 embed_provider: &AnyProvider,
649 messages: &[Message],
650 cfg: ProcessTurnConfig,
651) -> Result<(), MemoryError> {
652 let ProcessTurnConfig {
653 store_limit,
654 extraction_timeout,
655 distill_timeout,
656 self_judge_window,
657 min_assistant_chars,
658 } = cfg;
659
660 let judge_messages = if messages.len() > self_judge_window {
663 &messages[messages.len() - self_judge_window..]
664 } else {
665 messages
666 };
667
668 let last_assistant_chars = judge_messages
670 .iter()
671 .rev()
672 .find(|m| m.role == Role::Assistant)
673 .map_or(0, |m| m.content.len());
674 if last_assistant_chars < min_assistant_chars {
675 return Ok(());
676 }
677
678 let Some(outcome) = run_self_judge(extract_provider, judge_messages, extraction_timeout).await
679 else {
680 return Ok(());
681 };
682
683 let outcome_enum = if outcome.success {
684 Outcome::Success
685 } else {
686 Outcome::Failure
687 };
688
689 let Some(summary) = distill_strategy(
690 distill_provider,
691 outcome_enum,
692 &outcome.reasoning_chain,
693 distill_timeout,
694 )
695 .await
696 else {
697 return Ok(());
698 };
699
700 let embed_input = format!("{}\n{}", outcome.task_hint, summary);
702 let embedding = match embed_provider.embed(&embed_input).await {
703 Ok(v) => v,
704 Err(e) => {
705 tracing::warn!(error = %e, "reasoning: embedding failed — strategy not stored");
706 return Ok(());
707 }
708 };
709
710 let id = uuid::Uuid::new_v4().to_string();
711 let strategy = ReasoningStrategy {
712 id,
713 summary,
714 outcome: outcome_enum,
715 task_hint: outcome.task_hint,
716 created_at: 0, last_used_at: 0,
718 use_count: 0,
719 embedded_at: None,
720 };
721
722 let count_before = memory.count().await.unwrap_or(0);
727
728 if let Err(e) = memory.insert(&strategy, embedding).await {
729 tracing::warn!(error = %e, "reasoning: insert failed");
730 return Ok(());
731 }
732
733 if count_before >= store_limit
734 && let Err(e) = memory.evict_lru(store_limit).await
735 {
736 tracing::warn!(error = %e, "reasoning: evict_lru failed");
737 }
738
739 Ok(())
740}
741
/// Maximum number of characters of each message included in the self-judge transcript.
const MAX_TRANSCRIPT_MESSAGE_CHARS: usize = 2000;
749
750fn build_transcript_prompt(messages: &[Message]) -> String {
756 let mut prompt = String::from("Agent turn messages:\n");
757 for (i, msg) in messages.iter().enumerate() {
758 use std::fmt::Write as _;
759 let role = format!("{:?}", msg.role);
760 let content: std::borrow::Cow<str> =
762 if msg.content.chars().count() > MAX_TRANSCRIPT_MESSAGE_CHARS {
763 msg.content
764 .char_indices()
765 .nth(MAX_TRANSCRIPT_MESSAGE_CHARS)
766 .map_or(msg.content.as_str().into(), |(byte_idx, _)| {
767 msg.content[..byte_idx].into()
768 })
769 } else {
770 msg.content.as_str().into()
771 };
772 let _ = writeln!(prompt, "[{}] {}: {}", i + 1, role, content);
773 }
774 prompt.push_str("\nEvaluate this turn and return JSON.");
775 prompt
776}
777
778fn parse_self_judge_response(response: &str) -> Option<SelfJudgeOutcome> {
783 let stripped = response
785 .trim()
786 .trim_start_matches("```json")
787 .trim_start_matches("```")
788 .trim_end_matches("```")
789 .trim();
790
791 if let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(stripped) {
792 return Some(v);
793 }
794
795 if let (Some(start), Some(end)) = (stripped.find('{'), stripped.rfind('}'))
797 && end > start
798 && let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(&stripped[start..=end])
799 {
800 return Some(v);
801 }
802
803 tracing::warn!(
804 "reasoning: failed to parse self-judge response (len={}): {:.200}",
805 response.len(),
806 response
807 );
808 None
809}
810
/// Trims `text` to at most three sentences and at most 512 characters.
///
/// A sentence ends at `.`, `!` or `?` when that character is followed by
/// whitespace or is the last character of the text. Surrounding whitespace is
/// removed first.
///
/// If the text has fewer than three such boundaries it is kept whole (subject
/// to the character cap), including any trailing text without a terminator.
/// The cap is applied on a character boundary, so the result is always valid
/// UTF-8 and may end mid-sentence.
fn trim_to_three_sentences(text: &str) -> String {
    const MAX_CHARS: usize = 512;
    const MAX_SENTENCES: usize = 3;

    let text = text.trim();

    // Byte offset just past the last kept sentence; stays at the end of the
    // text unless a third sentence boundary is found.
    let mut end = text.len();
    let mut boundaries = 0usize;
    let mut iter = text.char_indices().peekable();
    while let Some((idx, ch)) = iter.next() {
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }
        let at_boundary = iter.peek().map_or(true, |&(_, next)| next.is_whitespace());
        if at_boundary {
            boundaries += 1;
            if boundaries == MAX_SENTENCES {
                end = idx + ch.len_utf8();
                break;
            }
        }
    }

    let kept = &text[..end];
    // Hard cap: the character at index MAX_CHARS exists only when `kept` is too long.
    match kept.char_indices().nth(MAX_CHARS) {
        Some((byte_idx, _)) => kept[..byte_idx].to_owned(),
        None => kept.to_owned(),
    }
}
849
#[cfg(test)]
mod tests {
    use super::*;

    // --- Outcome -----------------------------------------------------------

    /// `as_str` yields the labels stored in the database.
    #[test]
    fn outcome_as_str_round_trip() {
        assert_eq!(Outcome::Success.as_str(), "success");
        assert_eq!(Outcome::Failure.as_str(), "failure");
    }

    #[test]
    fn outcome_from_str_success() {
        assert_eq!(Outcome::from_str("success").unwrap(), Outcome::Success);
    }

    #[test]
    fn outcome_from_str_failure() {
        assert_eq!(Outcome::from_str("failure").unwrap(), Outcome::Failure);
    }

    /// Unknown labels are mapped to `Failure` instead of producing an error.
    #[test]
    fn outcome_from_str_unknown_defaults_to_failure() {
        assert_eq!(Outcome::from_str("partial").unwrap(), Outcome::Failure);
    }

    // --- parse_self_judge_response ----------------------------------------

    #[test]
    fn parse_direct_json() {
        let json = r#"{"success":true,"reasoning_chain":"tried X","task_hint":"do Y"}"#;
        let outcome = parse_self_judge_response(json).unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.reasoning_chain, "tried X");
        assert_eq!(outcome.task_hint, "do Y");
    }

    #[test]
    fn parse_json_with_markdown_fences() {
        let response =
            "```json\n{\"success\":false,\"reasoning_chain\":\"r\",\"task_hint\":\"t\"}\n```";
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(!outcome.success);
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        let response = r#"Here is the evaluation: {"success":true,"reasoning_chain":"chain","task_hint":"hint"} — done."#;
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(outcome.success);
    }

    #[test]
    fn parse_invalid_returns_none() {
        let outcome = parse_self_judge_response("not json at all");
        assert!(outcome.is_none());
    }

    // --- trim_to_three_sentences ------------------------------------------

    #[test]
    fn trim_three_sentences_short_text() {
        let text = "One. Two. Three.";
        assert_eq!(trim_to_three_sentences(text), "One. Two. Three.");
    }

    #[test]
    fn trim_three_sentences_truncates_at_third() {
        let text = "One. Two. Three. Four. Five.";
        let result = trim_to_three_sentences(text);
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(!result.contains("Four"));
    }

    /// Text without any sentence boundary is still capped at 512 characters.
    #[test]
    fn trim_three_sentences_hard_cap() {
        let long: String = "x".repeat(600);
        let result = trim_to_three_sentences(&long);
        assert!(result.chars().count() <= 512);
    }

    #[test]
    fn trim_three_sentences_empty() {
        assert_eq!(trim_to_three_sentences("   "), "");
    }

    // --- database-backed helpers ------------------------------------------

    /// Creates an in-memory SQLite pool containing the `reasoning_strategies` table.
    async fn make_test_pool() -> DbPool {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE reasoning_strategies (
                id TEXT PRIMARY KEY NOT NULL,
                summary TEXT NOT NULL,
                outcome TEXT NOT NULL,
                task_hint TEXT NOT NULL,
                created_at INTEGER NOT NULL DEFAULT (unixepoch('now')),
                last_used_at INTEGER NOT NULL DEFAULT (unixepoch('now')),
                use_count INTEGER NOT NULL DEFAULT 0,
                embedded_at INTEGER
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        pool
    }

    /// Builds a successful strategy whose text fields are derived from `id`.
    fn make_strategy(id: &str) -> ReasoningStrategy {
        ReasoningStrategy {
            id: id.to_owned(),
            summary: format!("Summary for {id}"),
            outcome: Outcome::Success,
            task_hint: format!("Task hint for {id}"),
            created_at: 0,
            last_used_at: 0,
            use_count: 0,
            embedded_at: None,
        }
    }

    #[tokio::test]
    async fn insert_and_fetch_by_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("abc-123");
        mem.insert(&s, vec![]).await.unwrap();

        let rows = mem.fetch_by_ids(&["abc-123".to_owned()]).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, "abc-123");
        assert_eq!(rows[0].outcome, Outcome::Success);
    }

    #[tokio::test]
    async fn mark_used_increments_count() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("mark-1");
        mem.insert(&s, vec![]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();

        let rows = mem.fetch_by_ids(&["mark-1".to_owned()]).await.unwrap();
        assert_eq!(rows[0].use_count, 2);
    }

    #[tokio::test]
    async fn mark_used_empty_is_noop() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        mem.mark_used(&[]).await.unwrap();
    }

    #[tokio::test]
    async fn count_returns_correct_total() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        assert_eq!(mem.count().await.unwrap(), 5);
    }

    /// Five cold rows with a limit of three: the two oldest are evicted.
    #[tokio::test]
    async fn evict_lru_cold_rows() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            mem.insert(&make_strategy(&format!("cold-{i}")), vec![])
                .await
                .unwrap();
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 2);
        assert_eq!(mem.count().await.unwrap(), 3);
    }

    /// Five hot rows with a limit of three stay below the 2x ceiling, so nothing is evicted.
    #[tokio::test]
    async fn evict_lru_respects_hot_rows_under_ceiling() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            let id = format!("hot-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            // One use more than the threshold makes the row hot.
            for _ in 0..=HOT_STRATEGY_USE_COUNT {
                mem.mark_used(std::slice::from_ref(&id)).await.unwrap();
            }
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(mem.count().await.unwrap(), 5);
    }

    /// Seven hot rows with a limit of three exceed the 2x ceiling and force eviction.
    #[tokio::test]
    async fn evict_lru_hard_ceiling_forces_deletion() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..7 {
            let id = format!("hot2-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            for _ in 0..=HOT_STRATEGY_USE_COUNT {
                mem.mark_used(std::slice::from_ref(&id)).await.unwrap();
            }
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert!(deleted > 0, "expected forced deletion");
        let remaining = mem.count().await.unwrap();
        assert_eq!(remaining, 3, "should be trimmed to store_limit");
    }

    #[tokio::test]
    async fn evict_lru_no_op_when_under_limit() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..3 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        let deleted = mem.evict_lru(10).await.unwrap();
        assert_eq!(deleted, 0);
    }

    /// More ids than fit in one query chunk must all be updated.
    #[tokio::test]
    async fn mark_used_chunked_over_490_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..500usize {
            mem.insert(&make_strategy(&format!("chunked-{i}")), vec![])
                .await
                .unwrap();
        }

        let ids: Vec<String> = (0..500usize).map(|i| format!("chunked-{i}")).collect();
        mem.mark_used(&ids).await.unwrap();

        let first = mem.fetch_by_ids(&[ids[0].clone()]).await.unwrap();
        let over_chunk = mem.fetch_by_ids(&[ids[490].clone()]).await.unwrap();
        assert_eq!(first[0].use_count, 1, "first id should have use_count = 1");
        assert_eq!(
            over_chunk[0].use_count, 1,
            "id past the chunk boundary should have use_count = 1"
        );
    }

    // --- LLM-backed pipeline steps (mock provider) -------------------------

    #[tokio::test]
    async fn run_self_judge_malformed_json_returns_none() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "This is not JSON at all.".to_string(),
        ]));
        let msgs = vec![Message::from_legacy(Role::User, "hello")];
        let result = run_self_judge(&provider, &msgs, std::time::Duration::from_secs(5)).await;
        assert!(result.is_none(), "malformed LLM response must return None");
    }

    #[tokio::test]
    async fn distill_strategy_truncates_to_three_sentences() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let long_response = "One. Two. Three. Four. Five.";
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            long_response.to_string(),
        ]));
        let result = distill_strategy(
            &provider,
            Outcome::Success,
            "chain here",
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(
            !result.contains("Four"),
            "should not contain 4th sentence: {result}"
        );
    }

    /// An empty message list must not store anything and must not fail.
    #[tokio::test]
    async fn process_turn_with_empty_messages_is_noop() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        let provider = AnyProvider::Mock(MockProvider::default());
        let cfg = ProcessTurnConfig {
            store_limit: 100,
            extraction_timeout: std::time::Duration::from_secs(1),
            distill_timeout: std::time::Duration::from_secs(1),
            self_judge_window: 2,
            min_assistant_chars: 0,
        };
        let result = process_turn(&mem, &provider, &provider, &provider, &[], cfg).await;
        assert!(
            result.is_ok(),
            "process_turn with empty messages must succeed"
        );
        assert_eq!(
            mem.count().await.unwrap(),
            0,
            "no strategies should be stored"
        );
    }
}