1use std::str::FromStr;
34use std::time::Duration;
35
36use serde::Deserialize;
37use tokio::time::timeout;
38use zeph_db::{ActiveDialect, DbPool, placeholder_list};
39use zeph_llm::any::AnyProvider;
40use zeph_llm::provider::{LlmProvider as _, Message, Role};
41
42use crate::error::MemoryError;
43use crate::vector_store::VectorStore;
44
/// Strategies whose `use_count` is above this value are "hot": cold eviction
/// (`delete_oldest_cold`) only removes rows with `use_count <= HOT_STRATEGY_USE_COUNT`.
const HOT_STRATEGY_USE_COUNT: i64 = 10;

/// Maximum number of ids bound into a single `IN (...)` query; longer id lists
/// are processed in chunks of this size.
/// NOTE(review): presumably chosen to stay below SQLite's bound-parameter limit
/// (historically 999) — confirm.
const MAX_IDS_PER_QUERY: usize = 490;

/// System prompt for the self-judge LLM call (`run_self_judge`); the reply is
/// parsed by `parse_self_judge_response` into a `SelfJudgeOutcome`.
const SELF_JUDGE_SYSTEM: &str = "\
You are a task outcome evaluator. Given an agent turn transcript, analyze the conversation and determine:
1. Did the agent successfully complete the user's request? (true/false)
2. Extract the key reasoning steps the agent took (reasoning chain).
3. Summarize the task in one sentence (task hint).

Respond ONLY with valid JSON, no markdown fences, no prose:
{\"success\": bool, \"reasoning_chain\": \"string\", \"task_hint\": \"string\"}";

/// System prompt for distilling a reasoning chain into a short strategy
/// (`distill_strategy`); the reply is additionally capped by
/// `trim_to_three_sentences`.
const DISTILL_SYSTEM: &str = "\
You are a strategy distiller. Given a reasoning chain from an agent turn, distill it into \
a short generalizable strategy (at most 3 sentences) that could help an agent facing a similar \
task. Focus on the transferable principle, not the specific instance. \
Respond with the strategy text only — no headers, no lists, no markdown.";
75
/// Verdict on an agent turn, as produced by the self-judge.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The agent completed the user's request.
    Success,
    /// The agent did not complete the request; also the fallback for
    /// unrecognised stored values.
    Failure,
}

impl Outcome {
    /// Lowercase label used in the `outcome` column and in vector-store payloads.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Failure => "failure",
            Self::Success => "success",
        }
    }
}
97
/// Error type for `Outcome::from_str`, carrying the rejected input.
///
/// NOTE(review): the `FromStr` impl below maps unknown strings to
/// `Outcome::Failure` and never returns this error; it currently exists only to
/// satisfy the trait's associated `Err` type.
#[derive(Debug, thiserror::Error)]
#[error("unknown outcome: {0}")]
pub struct OutcomeParseError(String);
102
103impl FromStr for Outcome {
104 type Err = OutcomeParseError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 match s {
108 "success" => Ok(Outcome::Success),
109 "failure" => Ok(Outcome::Failure),
110 other => {
111 tracing::warn!(
112 value = other,
113 "reasoning: unknown outcome, defaulting to Failure"
114 );
115 Ok(Outcome::Failure)
116 }
117 }
118 }
119}
120
/// A distilled reasoning strategy as stored in the `reasoning_strategies` table.
#[derive(Debug, Clone)]
pub struct ReasoningStrategy {
    /// Primary key; `process_turn` generates a random UUID v4 string.
    pub id: String,
    /// Distilled, generalizable strategy text.
    pub summary: String,
    /// Whether the turn the strategy was distilled from succeeded.
    pub outcome: Outcome,
    /// One-sentence description of the task the strategy applies to.
    pub task_hint: String,
    /// Creation time from the dialect's `EPOCH_NOW` (presumably Unix seconds —
    /// confirm); ignored by `ReasoningMemory::insert`.
    pub created_at: i64,
    /// Time of last use; refreshed on upsert and by `ReasoningMemory::mark_used`.
    pub last_used_at: i64,
    /// How many times the strategy was marked used; drives hot/cold eviction.
    pub use_count: i64,
    /// When the embedding was stored in the vector store, or `None` if it was not.
    pub embedded_at: Option<i64>,
}
146
/// JSON reply expected from the self-judge LLM (see `SELF_JUDGE_SYSTEM`).
#[derive(Debug, Deserialize)]
pub struct SelfJudgeOutcome {
    /// Whether the agent completed the user's request.
    pub success: bool,
    /// Key reasoning steps the agent took; input to strategy distillation.
    pub reasoning_chain: String,
    /// One-sentence summary of the task; stored as the strategy's `task_hint`.
    pub task_hint: String,
}
161
/// Store of reasoning strategies: rows live in the SQL `reasoning_strategies`
/// table, embeddings optionally in a vector store.
pub struct ReasoningMemory {
    // Pool holding the `reasoning_strategies` table.
    pool: DbPool,
    // Optional vector index; `None` means SQL-only mode (no embedding search).
    vector_store: Option<std::sync::Arc<dyn VectorStore>>,
}
174
/// Vector-store collection that strategy embeddings are upserted into and searched in.
pub const REASONING_COLLECTION: &str = "reasoning_strategies";
177
178impl ReasoningMemory {
179 #[must_use]
194 pub fn new(pool: DbPool, vector_store: Option<std::sync::Arc<dyn VectorStore>>) -> Self {
195 Self { pool, vector_store }
196 }
197
198 #[tracing::instrument(name = "memory.reasoning.insert", skip(self, embedding), fields(id = %strategy.id))]
208 pub async fn insert(
209 &self,
210 strategy: &ReasoningStrategy,
211 embedding: Vec<f32>,
212 ) -> Result<(), MemoryError> {
213 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
214 let raw = format!(
215 "INSERT INTO reasoning_strategies \
216 (id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at) \
217 VALUES (?, ?, ?, ?, {epoch_now}, {epoch_now}, 0, NULL) \
218 ON CONFLICT (id) DO UPDATE SET \
219 summary = EXCLUDED.summary, \
220 outcome = EXCLUDED.outcome, \
221 task_hint = EXCLUDED.task_hint, \
222 last_used_at = EXCLUDED.last_used_at, \
223 embedded_at = EXCLUDED.embedded_at"
224 );
225 let sql = zeph_db::rewrite_placeholders(&raw);
226 zeph_db::query(&sql)
227 .bind(&strategy.id)
228 .bind(&strategy.summary)
229 .bind(strategy.outcome.as_str())
230 .bind(&strategy.task_hint)
231 .execute(&self.pool)
232 .await?;
233
234 if let Some(ref vs) = self.vector_store {
236 let point = crate::vector_store::VectorPoint {
237 id: strategy.id.clone(),
238 vector: embedding,
239 payload: std::collections::HashMap::from([
240 (
241 "outcome".to_owned(),
242 serde_json::Value::String(strategy.outcome.as_str().to_owned()),
243 ),
244 (
245 "task_hint".to_owned(),
246 serde_json::Value::String(strategy.task_hint.clone()),
247 ),
248 ]),
249 };
250 if let Err(e) = vs.upsert(REASONING_COLLECTION, vec![point]).await {
251 tracing::warn!(error = %e, id = %strategy.id, "reasoning: Qdrant upsert failed — SQLite-only mode");
252 } else {
253 let update_sql = zeph_db::rewrite_placeholders(&format!(
255 "UPDATE reasoning_strategies SET embedded_at = {epoch_now} WHERE id = ?"
256 ));
257 if let Err(e) = zeph_db::query(&update_sql)
258 .bind(&strategy.id)
259 .execute(&self.pool)
260 .await
261 {
262 tracing::warn!(error = %e, "reasoning: failed to set embedded_at");
263 }
264 }
265 }
266
267 tracing::debug!(id = %strategy.id, outcome = strategy.outcome.as_str(), "reasoning: strategy inserted");
268 Ok(())
269 }
270
271 #[tracing::instrument(
283 name = "memory.reasoning.retrieve_by_embedding",
284 skip(self, embedding),
285 fields(top_k)
286 )]
287 pub async fn retrieve_by_embedding(
288 &self,
289 embedding: &[f32],
290 top_k: u64,
291 ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
292 let Some(ref vs) = self.vector_store else {
293 return Ok(Vec::new());
294 };
295
296 let scored = vs
297 .search(REASONING_COLLECTION, embedding.to_vec(), top_k, None)
298 .await?;
299
300 if scored.is_empty() {
301 return Ok(Vec::new());
302 }
303
304 let ids: Vec<String> = scored.into_iter().map(|p| p.id).collect();
305 self.fetch_by_ids(&ids).await
306 }
307
308 #[tracing::instrument(name = "memory.reasoning.mark_used", skip(self), fields(n = ids.len()))]
318 pub async fn mark_used(&self, ids: &[String]) -> Result<(), MemoryError> {
319 if ids.is_empty() {
320 return Ok(());
321 }
322
323 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
324 for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
325 let ph = placeholder_list(1, chunk.len());
326 let sql = format!(
329 "UPDATE reasoning_strategies \
330 SET use_count = use_count + 1, last_used_at = {epoch_now} \
331 WHERE id IN ({ph})"
332 );
333 let mut q = zeph_db::query(&sql);
334 for id in chunk {
335 q = q.bind(id.as_str());
336 }
337 q.execute(&self.pool).await?;
338 }
339
340 Ok(())
341 }
342
343 #[tracing::instrument(name = "memory.reasoning.evict_lru", skip(self), fields(store_limit))]
359 pub async fn evict_lru(&self, store_limit: usize) -> Result<usize, MemoryError> {
360 let count = self.count().await?;
361 if count <= store_limit {
362 return Ok(0);
363 }
364
365 let over_by = count - store_limit;
366 let deleted_cold = self.delete_oldest_cold(over_by).await?;
367 if deleted_cold > 0 {
368 tracing::debug!(
370 deleted = deleted_cold,
371 count,
372 "reasoning: evicted cold strategies"
373 );
374 return Ok(deleted_cold);
375 }
376
377 let hard_ceiling = store_limit.saturating_mul(2);
379 if count <= hard_ceiling {
380 tracing::debug!(
381 count,
382 store_limit,
383 "reasoning: hot saturation — growth allowed under 2x ceiling"
384 );
385 return Ok(0);
386 }
387
388 let forced = count - store_limit;
390 let deleted_forced = self.delete_oldest_unconditional(forced).await?;
391 tracing::warn!(
392 deleted = deleted_forced,
393 count,
394 hard_ceiling,
395 "reasoning: hard-ceiling eviction — evicted hot strategies; consider raising store_limit"
396 );
397
398 Ok(deleted_forced)
399 }
400
401 pub async fn count(&self) -> Result<usize, MemoryError> {
407 let row: (i64,) = zeph_db::query_as("SELECT COUNT(*) FROM reasoning_strategies")
408 .fetch_one(&self.pool)
409 .await?;
410 Ok(usize::try_from(row.0.max(0)).unwrap_or(0))
411 }
412
413 pub(crate) async fn fetch_by_ids(
417 &self,
418 ids: &[String],
419 ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
420 if ids.is_empty() {
421 return Ok(Vec::new());
422 }
423
424 let mut strategies = Vec::with_capacity(ids.len());
425 for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
426 let ph = placeholder_list(1, chunk.len());
427 let sql = format!(
429 "SELECT id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at \
430 FROM reasoning_strategies WHERE id IN ({ph})"
431 );
432 let mut q = zeph_db::query_as::<
433 _,
434 (String, String, String, String, i64, i64, i64, Option<i64>),
435 >(&sql);
436 for id in chunk {
437 q = q.bind(id.as_str());
438 }
439 let rows = q.fetch_all(&self.pool).await?;
440 for (
441 id,
442 summary,
443 outcome_str,
444 task_hint,
445 created_at,
446 last_used_at,
447 use_count,
448 embedded_at,
449 ) in rows
450 {
451 let outcome = Outcome::from_str(&outcome_str).unwrap_or(Outcome::Failure);
452 strategies.push(ReasoningStrategy {
453 id,
454 summary,
455 outcome,
456 task_hint,
457 created_at,
458 last_used_at,
459 use_count,
460 embedded_at,
461 });
462 }
463 }
464
465 Ok(strategies)
466 }
467
468 async fn delete_oldest_cold(&self, n: usize) -> Result<usize, MemoryError> {
472 let limit = i64::try_from(n).unwrap_or(i64::MAX);
473 let raw = format!(
475 "DELETE FROM reasoning_strategies \
476 WHERE id IN ( \
477 SELECT id FROM reasoning_strategies \
478 WHERE use_count <= {HOT_STRATEGY_USE_COUNT} \
479 ORDER BY last_used_at ASC LIMIT ? \
480 )"
481 );
482 let sql = zeph_db::rewrite_placeholders(&raw);
483 let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
484 Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
485 }
486
487 async fn delete_oldest_unconditional(&self, n: usize) -> Result<usize, MemoryError> {
491 let limit = i64::try_from(n).unwrap_or(i64::MAX);
492 let raw = "DELETE FROM reasoning_strategies \
493 WHERE id IN ( \
494 SELECT id FROM reasoning_strategies \
495 ORDER BY last_used_at ASC LIMIT ? \
496 )";
497 let sql = zeph_db::rewrite_placeholders(raw);
498 let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
499 Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
500 }
501}
502
503#[tracing::instrument(name = "memory.reasoning.self_judge", skip(provider, messages), fields(n = messages.len()))]
528pub async fn run_self_judge(
529 provider: &AnyProvider,
530 messages: &[Message],
531 extraction_timeout: Duration,
532) -> Option<SelfJudgeOutcome> {
533 if messages.is_empty() {
534 return None;
535 }
536
537 let user_prompt = build_transcript_prompt(messages);
538
539 let llm_messages = [
540 Message::from_legacy(Role::System, SELF_JUDGE_SYSTEM),
541 Message::from_legacy(Role::User, user_prompt),
542 ];
543
544 let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
545 Ok(Ok(text)) => text,
546 Ok(Err(e)) => {
547 tracing::warn!(error = %e, "reasoning: self-judge LLM call failed");
548 return None;
549 }
550 Err(_) => {
551 tracing::warn!("reasoning: self-judge timed out");
552 return None;
553 }
554 };
555
556 parse_self_judge_response(&response)
557}
558
559#[tracing::instrument(name = "memory.reasoning.distill", skip(provider, reasoning_chain))]
579pub async fn distill_strategy(
580 provider: &AnyProvider,
581 outcome: Outcome,
582 reasoning_chain: &str,
583 distill_timeout: Duration,
584) -> Option<String> {
585 if reasoning_chain.is_empty() {
586 return None;
587 }
588
589 let user_prompt = format!(
590 "Outcome: {}\n\nReasoning chain:\n{reasoning_chain}",
591 outcome.as_str()
592 );
593
594 let llm_messages = [
595 Message::from_legacy(Role::System, DISTILL_SYSTEM),
596 Message::from_legacy(Role::User, user_prompt),
597 ];
598
599 let response = match timeout(distill_timeout, provider.chat(&llm_messages)).await {
600 Ok(Ok(text)) => text,
601 Ok(Err(e)) => {
602 tracing::warn!(error = %e, "reasoning: distillation LLM call failed");
603 return None;
604 }
605 Err(_) => {
606 tracing::warn!("reasoning: distillation timed out");
607 return None;
608 }
609 };
610
611 let trimmed = trim_to_three_sentences(&response);
612 if trimmed.is_empty() {
613 None
614 } else {
615 Some(trimmed)
616 }
617}
618
/// Tunables for `process_turn`.
#[derive(Debug, Clone, Copy)]
pub struct ProcessTurnConfig {
    /// Target number of stored strategies; passed to `ReasoningMemory::evict_lru`.
    pub store_limit: usize,
    /// Timeout for the self-judge LLM call.
    pub extraction_timeout: Duration,
    /// Timeout for the distillation LLM call.
    pub distill_timeout: Duration,
    /// Number of trailing messages of the turn shown to the self-judge.
    pub self_judge_window: usize,
    /// Minimum length of the last assistant message inside the judge window;
    /// shorter turns are skipped without any LLM call.
    pub min_assistant_chars: usize,
}
638
639#[tracing::instrument(name = "memory.reasoning.process_turn", skip_all)]
650pub async fn process_turn(
651 memory: &ReasoningMemory,
652 extract_provider: &AnyProvider,
653 distill_provider: &AnyProvider,
654 embed_provider: &AnyProvider,
655 messages: &[Message],
656 cfg: ProcessTurnConfig,
657) -> Result<(), MemoryError> {
658 let ProcessTurnConfig {
659 store_limit,
660 extraction_timeout,
661 distill_timeout,
662 self_judge_window,
663 min_assistant_chars,
664 } = cfg;
665
666 let judge_messages = if messages.len() > self_judge_window {
669 &messages[messages.len() - self_judge_window..]
670 } else {
671 messages
672 };
673
674 let last_assistant_chars = judge_messages
676 .iter()
677 .rev()
678 .find(|m| m.role == Role::Assistant)
679 .map_or(0, |m| m.content.len());
680 if last_assistant_chars < min_assistant_chars {
681 return Ok(());
682 }
683
684 let Some(outcome) = run_self_judge(extract_provider, judge_messages, extraction_timeout).await
685 else {
686 return Ok(());
687 };
688
689 let outcome_enum = if outcome.success {
690 Outcome::Success
691 } else {
692 Outcome::Failure
693 };
694
695 let Some(summary) = distill_strategy(
696 distill_provider,
697 outcome_enum,
698 &outcome.reasoning_chain,
699 distill_timeout,
700 )
701 .await
702 else {
703 return Ok(());
704 };
705
706 let embed_input = format!("{}\n{}", outcome.task_hint, summary);
708 let embedding = match tokio::time::timeout(
709 std::time::Duration::from_secs(5),
710 embed_provider.embed(&embed_input),
711 )
712 .await
713 {
714 Ok(Ok(v)) => v,
715 Ok(Err(e)) => {
716 tracing::warn!(error = %e, "reasoning: embedding failed — strategy not stored");
717 return Ok(());
718 }
719 Err(_) => {
720 tracing::warn!("reasoning: embed timed out — strategy not stored");
721 return Ok(());
722 }
723 };
724
725 let id = uuid::Uuid::new_v4().to_string();
726 let strategy = ReasoningStrategy {
727 id,
728 summary,
729 outcome: outcome_enum,
730 task_hint: outcome.task_hint,
731 created_at: 0, last_used_at: 0,
733 use_count: 0,
734 embedded_at: None,
735 };
736
737 let count_before = memory.count().await.unwrap_or(0);
742
743 if let Err(e) = memory.insert(&strategy, embedding).await {
744 tracing::warn!(error = %e, "reasoning: insert failed");
745 return Ok(());
746 }
747
748 if count_before >= store_limit
749 && let Err(e) = memory.evict_lru(store_limit).await
750 {
751 tracing::warn!(error = %e, "reasoning: evict_lru failed");
752 }
753
754 Ok(())
755}
756
757const MAX_TRANSCRIPT_MESSAGE_CHARS: usize = 2000;
764
765fn build_transcript_prompt(messages: &[Message]) -> String {
771 let mut prompt = String::from("Agent turn messages:\n");
772 for (i, msg) in messages.iter().enumerate() {
773 use std::fmt::Write as _;
774 let role = format!("{:?}", msg.role);
775 let content: std::borrow::Cow<str> =
777 if msg.content.chars().count() > MAX_TRANSCRIPT_MESSAGE_CHARS {
778 msg.content
779 .char_indices()
780 .nth(MAX_TRANSCRIPT_MESSAGE_CHARS)
781 .map_or(msg.content.as_str().into(), |(byte_idx, _)| {
782 msg.content[..byte_idx].into()
783 })
784 } else {
785 msg.content.as_str().into()
786 };
787 let _ = writeln!(prompt, "[{}] {}: {}", i + 1, role, content);
788 }
789 prompt.push_str("\nEvaluate this turn and return JSON.");
790 prompt
791}
792
793fn parse_self_judge_response(response: &str) -> Option<SelfJudgeOutcome> {
798 let stripped = response
800 .trim()
801 .trim_start_matches("```json")
802 .trim_start_matches("```")
803 .trim_end_matches("```")
804 .trim();
805
806 if let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(stripped) {
807 return Some(v);
808 }
809
810 if let (Some(start), Some(end)) = (stripped.find('{'), stripped.rfind('}'))
812 && end > start
813 && let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(&stripped[start..=end])
814 {
815 return Some(v);
816 }
817
818 tracing::warn!(
819 "reasoning: failed to parse self-judge response (len={}): {:.200}",
820 response.len(),
821 response
822 );
823 None
824}
825
/// Caps a distilled strategy at three sentences and 512 characters.
///
/// A sentence ends at `.`, `!` or `?` followed by whitespace or the end of the
/// text. The input is trimmed first. Then:
/// - If at least one sentence ends within the first 512 characters, the text
///   is cut right after the last such sentence end, counting at most three.
///   The result therefore never stops mid-sentence, and any unterminated tail
///   is dropped.
/// - Otherwise the text is hard-cut to its first 512 characters.
///
/// Truncation counts `char`s, so multi-byte UTF-8 text is never split inside a
/// code point.
fn trim_to_three_sentences(text: &str) -> String {
    const MAX_CHARS: usize = 512;
    const MAX_SENTENCES: usize = 3;

    let text = text.trim();

    // Byte offset of the hard character cap (the whole text if it is shorter).
    let cap = text
        .char_indices()
        .nth(MAX_CHARS)
        .map_or(text.len(), |(byte_idx, _)| byte_idx);

    let mut last_end: Option<usize> = None;
    let mut sentences = 0;
    let mut chars = text.char_indices().peekable();
    while let Some((byte_idx, ch)) = chars.next() {
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }
        let at_boundary = match chars.peek() {
            None => true,
            Some(&(_, next)) => next.is_whitespace(),
        };
        if !at_boundary {
            continue;
        }
        // The terminators are ASCII, so the sentence ends one byte later.
        let end = byte_idx + 1;
        if end > cap {
            // This and every later sentence end fall beyond the character cap.
            break;
        }
        last_end = Some(end);
        sentences += 1;
        if sentences == MAX_SENTENCES {
            break;
        }
    }

    text[..last_end.unwrap_or(cap)].to_owned()
}
864
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Outcome <-> string mapping ----

    // `as_str` yields the lowercase labels persisted in the `outcome` column.
    #[test]
    fn outcome_as_str_round_trip() {
        assert_eq!(Outcome::Success.as_str(), "success");
        assert_eq!(Outcome::Failure.as_str(), "failure");
    }

    #[test]
    fn outcome_from_str_success() {
        assert_eq!(Outcome::from_str("success").unwrap(), Outcome::Success);
    }

    #[test]
    fn outcome_from_str_failure() {
        assert_eq!(Outcome::from_str("failure").unwrap(), Outcome::Failure);
    }

    // Unknown labels are not an error: they fall back to Failure.
    #[test]
    fn outcome_from_str_unknown_defaults_to_failure() {
        assert_eq!(Outcome::from_str("partial").unwrap(), Outcome::Failure);
    }

    // ---- parse_self_judge_response ----

    // Bare JSON parses directly.
    #[test]
    fn parse_direct_json() {
        let json = r#"{"success":true,"reasoning_chain":"tried X","task_hint":"do Y"}"#;
        let outcome = parse_self_judge_response(json).unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.reasoning_chain, "tried X");
        assert_eq!(outcome.task_hint, "do Y");
    }

    // A json-tagged markdown fence is stripped before parsing.
    #[test]
    fn parse_json_with_markdown_fences() {
        let response =
            "```json\n{\"success\":false,\"reasoning_chain\":\"r\",\"task_hint\":\"t\"}\n```";
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(!outcome.success);
    }

    // JSON surrounded by prose is recovered via the first-`{`/last-`}` fallback.
    #[test]
    fn parse_json_embedded_in_prose() {
        let response = r#"Here is the evaluation: {"success":true,"reasoning_chain":"chain","task_hint":"hint"} — done."#;
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(outcome.success);
    }

    #[test]
    fn parse_invalid_returns_none() {
        let outcome = parse_self_judge_response("not json at all");
        assert!(outcome.is_none());
    }

    // ---- trim_to_three_sentences ----

    #[test]
    fn trim_three_sentences_short_text() {
        let text = "One. Two. Three.";
        assert_eq!(trim_to_three_sentences(text), "One. Two. Three.");
    }

    #[test]
    fn trim_three_sentences_truncates_at_third() {
        let text = "One. Two. Three. Four. Five.";
        let result = trim_to_three_sentences(text);
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(!result.contains("Four"));
    }

    // Text without any sentence terminator is still capped at 512 characters.
    #[test]
    fn trim_three_sentences_hard_cap() {
        let long: String = "x".repeat(600);
        let result = trim_to_three_sentences(&long);
        assert!(result.chars().count() <= 512);
    }

    #[test]
    fn trim_three_sentences_empty() {
        assert_eq!(trim_to_three_sentences(" "), "");
    }

    // ---- ReasoningMemory against an in-memory SQLite database ----

    // Fresh in-memory SQLite pool with the `reasoning_strategies` table.
    async fn make_test_pool() -> DbPool {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE reasoning_strategies (
            id TEXT PRIMARY KEY NOT NULL,
            summary TEXT NOT NULL,
            outcome TEXT NOT NULL,
            task_hint TEXT NOT NULL,
            created_at INTEGER NOT NULL DEFAULT (unixepoch('now')),
            last_used_at INTEGER NOT NULL DEFAULT (unixepoch('now')),
            use_count INTEGER NOT NULL DEFAULT 0,
            embedded_at INTEGER
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        pool
    }

    // Minimal successful strategy with deterministic text derived from `id`.
    fn make_strategy(id: &str) -> ReasoningStrategy {
        ReasoningStrategy {
            id: id.to_owned(),
            summary: format!("Summary for {id}"),
            outcome: Outcome::Success,
            task_hint: format!("Task hint for {id}"),
            created_at: 0,
            last_used_at: 0,
            use_count: 0,
            embedded_at: None,
        }
    }

    #[tokio::test]
    async fn insert_and_fetch_by_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("abc-123");
        mem.insert(&s, vec![]).await.unwrap();

        let rows = mem.fetch_by_ids(&["abc-123".to_owned()]).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, "abc-123");
        assert_eq!(rows[0].outcome, Outcome::Success);
    }

    // Each mark_used call increments use_count by exactly one.
    #[tokio::test]
    async fn mark_used_increments_count() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("mark-1");
        mem.insert(&s, vec![]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();

        let rows = mem.fetch_by_ids(&["mark-1".to_owned()]).await.unwrap();
        assert_eq!(rows[0].use_count, 2);
    }

    #[tokio::test]
    async fn mark_used_empty_is_noop() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        mem.mark_used(&[]).await.unwrap();
    }

    #[tokio::test]
    async fn count_returns_correct_total() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        assert_eq!(mem.count().await.unwrap(), 5);
    }

    // Freshly inserted rows have use_count 0 (all cold), so evicting down to 3
    // removes exactly the 2 surplus rows.
    #[tokio::test]
    async fn evict_lru_cold_rows() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            mem.insert(&make_strategy(&format!("cold-{i}")), vec![])
                .await
                .unwrap();
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 2);
        assert_eq!(mem.count().await.unwrap(), 3);
    }

    // 11 uses per row makes every row hot (> HOT_STRATEGY_USE_COUNT); 5 rows
    // with limit 3 stay under the 2x ceiling (6), so nothing is evicted.
    #[tokio::test]
    async fn evict_lru_respects_hot_rows_under_ceiling() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool.clone(), None);

        for i in 0..5 {
            let id = format!("hot-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            let ids: Vec<String> = (0..11).map(|_| id.clone()).collect();
            for chunk_ids in ids.chunks(1) {
                mem.mark_used(chunk_ids).await.unwrap();
            }
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(mem.count().await.unwrap(), 5);
    }

    // 7 hot rows exceed the 2x ceiling (6) for limit 3, forcing eviction of hot
    // rows back down to the limit.
    #[tokio::test]
    async fn evict_lru_hard_ceiling_forces_deletion() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool.clone(), None);

        for i in 0..7 {
            let id = format!("hot2-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            for _ in 0..=HOT_STRATEGY_USE_COUNT {
                mem.mark_used(std::slice::from_ref(&id)).await.unwrap();
            }
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert!(deleted > 0, "expected forced deletion");
        let remaining = mem.count().await.unwrap();
        assert_eq!(remaining, 3, "should be trimmed to store_limit");
    }

    #[tokio::test]
    async fn evict_lru_no_op_when_under_limit() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..3 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        let deleted = mem.evict_lru(10).await.unwrap();
        assert_eq!(deleted, 0);
    }

    // 500 ids span two MAX_IDS_PER_QUERY chunks (490 + 10); ids in both chunks
    // must be updated.
    #[tokio::test]
    async fn mark_used_chunked_over_490_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..500usize {
            mem.insert(&make_strategy(&format!("chunked-{i}")), vec![])
                .await
                .unwrap();
        }

        let ids: Vec<String> = (0..500usize).map(|i| format!("chunked-{i}")).collect();
        mem.mark_used(&ids).await.unwrap();

        let first = mem.fetch_by_ids(&[ids[0].clone()]).await.unwrap();
        let over_chunk = mem.fetch_by_ids(&[ids[490].clone()]).await.unwrap();
        assert_eq!(first[0].use_count, 1, "first id should have use_count = 1");
        assert_eq!(
            over_chunk[0].use_count, 1,
            "id past the chunk boundary should have use_count = 1"
        );
    }

    // ---- LLM-driven helpers with a mock provider ----

    #[tokio::test]
    async fn run_self_judge_malformed_json_returns_none() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "This is not JSON at all.".to_string(),
        ]));
        let msgs = vec![Message::from_legacy(Role::User, "hello")];
        let result = run_self_judge(&provider, &msgs, std::time::Duration::from_secs(5)).await;
        assert!(result.is_none(), "malformed LLM response must return None");
    }

    #[tokio::test]
    async fn distill_strategy_truncates_to_three_sentences() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let long_response = "One. Two. Three. Four. Five.";
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            long_response.to_string(),
        ]));
        let result = distill_strategy(
            &provider,
            Outcome::Success,
            "chain here",
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(
            !result.contains("Four"),
            "should not contain 4th sentence: {result}"
        );
    }

    // With no messages the pipeline stops before any LLM call and stores nothing.
    #[tokio::test]
    async fn process_turn_with_empty_messages_is_noop() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        let provider = AnyProvider::Mock(MockProvider::default());
        let cfg = ProcessTurnConfig {
            store_limit: 100,
            extraction_timeout: std::time::Duration::from_secs(1),
            distill_timeout: std::time::Duration::from_secs(1),
            self_judge_window: 2,
            min_assistant_chars: 0,
        };
        let result = process_turn(&mem, &provider, &provider, &provider, &[], cfg).await;
        assert!(
            result.is_ok(),
            "process_turn with empty messages must succeed"
        );
        assert_eq!(
            mem.count().await.unwrap(),
            0,
            "no strategies should be stored"
        );
    }
}