Skip to main content

zeph_memory/
reasoning.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `ReasoningBank`: distilled reasoning strategy memory (#3342).
5//!
6//! After each completed agent turn a three-stage async pipeline runs off the hot path:
7//!
8//! 1. **Self-judge** ([`run_self_judge`]) — a fast LLM evaluates success/failure and
9//!    extracts the key reasoning steps.
10//! 2. **Distillation** ([`distill_strategy`]) — a strategy summary (≤ 3 sentences) is
11//!    generated from the reasoning chain, capturing the transferable principle.
12//! 3. **Storage** ([`ReasoningMemory::insert`]) — the summary is written to `SQLite`
13//!    and, when Qdrant is available, embedded and indexed for vector retrieval.
14//!
15//! At context-build time [`ReasoningMemory::retrieve_by_embedding`] fetches top-k
16//! strategies by embedding similarity. The caller (in `zeph-context`) calls
17//! [`ReasoningMemory::mark_used`] only for strategies actually injected into the prompt,
18//! after budget truncation (C4 split from architect plan).
19//!
20//! # LRU eviction
21//!
22//! [`ReasoningMemory::evict_lru`] protects rows with `use_count > HOT_STRATEGY_USE_COUNT`
23//! (default 10) from normal eviction. When all rows are hot and the table exceeds
24//! `2 × store_limit`, a forced eviction pass deletes the oldest rows unconditionally
25//! and emits a `warn!` so operators can tune `store_limit` upward.
26//!
27//! # LRU eviction race note
28//!
29//! Two concurrent turns may race on the count check in `evict_lru`. Either both evict
30//! (over-eviction by at most `top_k` rows) or neither. This is acceptable for MVP —
31//! the table remains bounded.
32
33use std::str::FromStr;
34use std::time::Duration;
35
36use serde::Deserialize;
37use tokio::time::timeout;
38use zeph_db::{ActiveDialect, DbPool, placeholder_list};
39use zeph_llm::any::AnyProvider;
40use zeph_llm::provider::{LlmProvider as _, Message, Role};
41
42use crate::error::MemoryError;
43use crate::vector_store::VectorStore;
44
/// Use-count threshold that separates "cold" from "hot" strategies during LRU eviction.
///
/// Rows with `use_count <= HOT_STRATEGY_USE_COUNT` are cold and eligible for normal
/// eviction. Rows with `use_count > HOT_STRATEGY_USE_COUNT` are hot: they are skipped
/// by the cold-eviction pass and only removed when the table exceeds `2 × store_limit`.
const HOT_STRATEGY_USE_COUNT: i64 = 10;
50
/// Maximum ids per `SQLite` `WHERE id IN (...)` bind list (`SQLite` variable limit is 999).
///
/// `mark_used` and `fetch_by_ids` split their id slices into chunks of this size so a
/// single statement never binds more parameters than this.
const MAX_IDS_PER_QUERY: usize = 490;
53
/// System prompt for the self-judge LLM step.
///
/// Instructs the LLM to evaluate success/failure and extract the reasoning chain
/// as structured JSON matching [`SelfJudgeOutcome`].
///
/// The JSON keys named in the last line of the prompt (`success`, `reasoning_chain`,
/// `task_hint`) are the field names of [`SelfJudgeOutcome`]; keep the two in sync.
// The leading `\` only suppresses the newline right after the opening quote; every
// other line break below is part of the prompt text.
const SELF_JUDGE_SYSTEM: &str = "\
You are a task outcome evaluator. Given an agent turn transcript, analyze the conversation and determine:
1. Did the agent successfully complete the user's request? (true/false)
2. Extract the key reasoning steps the agent took (reasoning chain).
3. Summarize the task in one sentence (task hint).

Respond ONLY with valid JSON, no markdown fences, no prose:
{\"success\": bool, \"reasoning_chain\": \"string\", \"task_hint\": \"string\"}";
66
/// System prompt for the distillation LLM step.
///
/// Instructs the LLM to compress a reasoning chain into a short, generalizable strategy
/// (at most 3 sentences, plain text only).
// Each trailing `\` joins the source lines, so the prompt is a single paragraph with
// no embedded newlines.
const DISTILL_SYSTEM: &str = "\
You are a strategy distiller. Given a reasoning chain from an agent turn, distill it into \
a short generalizable strategy (at most 3 sentences) that could help an agent facing a similar \
task. Focus on the transferable principle, not the specific instance. \
Respond with the strategy text only — no headers, no lists, no markdown.";
75
/// Result label attached to a reasoning strategy: did the source turn succeed or fail?
///
/// Persisted in a `TEXT NOT NULL` column as either `"success"` or `"failure"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The agent completed the task.
    Success,
    /// The agent did not complete the task.
    Failure,
}

impl Outcome {
    /// Canonical lowercase spelling of this outcome, as stored in the database.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Success => "success",
            Self::Failure => "failure",
        }
    }
}
97
/// Error type declared as the `FromStr::Err` of [`Outcome`].
///
/// Wraps the unrecognised input string and renders as `unknown outcome: <input>`.
///
/// NOTE(review): the `FromStr` impl for [`Outcome`] in this file maps unknown input to
/// `Outcome::Failure` and never constructs this error, so callers will not currently
/// observe it from `str::parse`.
#[derive(Debug, thiserror::Error)]
#[error("unknown outcome: {0}")]
pub struct OutcomeParseError(String);
102
103impl FromStr for Outcome {
104    type Err = OutcomeParseError;
105
106    fn from_str(s: &str) -> Result<Self, Self::Err> {
107        match s {
108            "success" => Ok(Outcome::Success),
109            "failure" => Ok(Outcome::Failure),
110            other => {
111                tracing::warn!(
112                    value = other,
113                    "reasoning: unknown outcome, defaulting to Failure"
114                );
115                Ok(Outcome::Failure)
116            }
117        }
118    }
119}
120
/// A distilled reasoning strategy row from the `reasoning_strategies` table.
///
/// Constructed after a successful self-judge + distillation pipeline run.
/// Persisted in `SQLite` and (when Qdrant is available) indexed as a vector embedding.
///
/// Note: [`ReasoningMemory::insert`] binds only `id`, `summary`, `outcome` and
/// `task_hint`. The timestamp columns, `use_count` and `embedded_at` are written by the
/// SQL statement itself, so the values of those fields on an instance passed to `insert`
/// are not persisted.
#[derive(Debug, Clone)]
pub struct ReasoningStrategy {
    /// UUID v4 primary key.
    pub id: String,
    /// Distilled strategy summary (≤ 3 sentences, ≤ 512 chars).
    pub summary: String,
    /// Whether the agent succeeded or failed on the source turn.
    pub outcome: Outcome,
    /// One-sentence description of the task that produced this strategy.
    pub task_hint: String,
    /// Unix timestamp (seconds) when this strategy was created.
    pub created_at: i64,
    /// Unix timestamp (seconds) of the last retrieval.
    ///
    /// Initialised at insert time and refreshed by [`ReasoningMemory::mark_used`];
    /// eviction orders rows by this column, oldest first.
    pub last_used_at: i64,
    /// Number of times this strategy has been injected into context.
    ///
    /// Incremented by [`ReasoningMemory::mark_used`].
    pub use_count: i64,
    /// Unix timestamp (seconds) when the Qdrant embedding was created.
    ///
    /// `None` means this row has not been embedded yet (Qdrant was unavailable at insert time).
    pub embedded_at: Option<i64>,
}
146
/// Parsed response from the self-judge LLM call.
///
/// Deserialized from the LLM JSON response in [`run_self_judge`].
/// The `success` field drives [`Outcome`] selection; `reasoning_chain` and `task_hint`
/// are forwarded to the distillation step.
///
/// The field names double as the JSON keys requested by the self-judge system prompt,
/// so renaming a field changes the expected wire format.
#[derive(Debug, Deserialize)]
pub struct SelfJudgeOutcome {
    /// Whether the agent successfully completed the task.
    pub success: bool,
    /// Key reasoning steps the agent took, as free-form text.
    pub reasoning_chain: String,
    /// One-sentence summary of the task.
    pub task_hint: String,
}
161
/// SQLite-backed store for distilled reasoning strategies.
///
/// Attach to [`crate::semantic::SemanticMemory`] via `with_reasoning`.
/// All write operations are best-effort: `SQLite` errors are propagated as
/// [`MemoryError`], Qdrant failures are logged and silently ignored.
pub struct ReasoningMemory {
    /// Database pool used for every `reasoning_strategies` query issued by this store.
    pool: DbPool,
    /// Optional vector store for embedding-similarity retrieval.
    ///
    /// `None` when Qdrant is unavailable; falls back to returning empty results.
    vector_store: Option<std::sync::Arc<dyn VectorStore>>,
}
174
/// Qdrant collection name used for reasoning-strategy embeddings.
///
/// Passed to the vector store by both `ReasoningMemory::insert` (upsert) and
/// `ReasoningMemory::retrieve_by_embedding` (search).
pub const REASONING_COLLECTION: &str = "reasoning_strategies";
177
178impl ReasoningMemory {
179    /// Create a new `ReasoningMemory` backed by the given `SQLite` pool.
180    ///
181    /// Pass `vector_store = Some(arc)` to enable embedding-similarity retrieval via Qdrant.
182    /// When `None`, [`Self::retrieve_by_embedding`] always returns an empty vec.
183    ///
184    /// # Examples
185    ///
186    /// ```no_run
187    /// use zeph_memory::reasoning::ReasoningMemory;
188    ///
189    /// async fn demo(pool: zeph_db::DbPool) {
190    ///     let memory = ReasoningMemory::new(pool, None);
191    /// }
192    /// ```
193    #[must_use]
194    pub fn new(pool: DbPool, vector_store: Option<std::sync::Arc<dyn VectorStore>>) -> Self {
195        Self { pool, vector_store }
196    }
197
198    /// Insert a new strategy into `SQLite`.
199    ///
200    /// When a `vector_store` is configured, the strategy is also upserted into
201    /// the Qdrant `reasoning_strategies` collection using the provided `embedding`.
202    /// Qdrant failures are logged at `warn` level and do not fail the insert.
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if the `SQLite` insert fails.
207    #[tracing::instrument(name = "memory.reasoning.insert", skip(self, embedding), fields(id = %strategy.id))]
208    pub async fn insert(
209        &self,
210        strategy: &ReasoningStrategy,
211        embedding: Vec<f32>,
212    ) -> Result<(), MemoryError> {
213        let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
214        let raw = format!(
215            "INSERT INTO reasoning_strategies \
216             (id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at) \
217             VALUES (?, ?, ?, ?, {epoch_now}, {epoch_now}, 0, NULL) \
218             ON CONFLICT (id) DO UPDATE SET \
219               summary = EXCLUDED.summary, \
220               outcome = EXCLUDED.outcome, \
221               task_hint = EXCLUDED.task_hint, \
222               last_used_at = EXCLUDED.last_used_at, \
223               embedded_at = EXCLUDED.embedded_at"
224        );
225        let sql = zeph_db::rewrite_placeholders(&raw);
226        zeph_db::query(&sql)
227            .bind(&strategy.id)
228            .bind(&strategy.summary)
229            .bind(strategy.outcome.as_str())
230            .bind(&strategy.task_hint)
231            .execute(&self.pool)
232            .await?;
233
234        // Qdrant upsert — best effort: SQLite row already written.
235        if let Some(ref vs) = self.vector_store {
236            let point = crate::vector_store::VectorPoint {
237                id: strategy.id.clone(),
238                vector: embedding,
239                payload: std::collections::HashMap::from([
240                    (
241                        "outcome".to_owned(),
242                        serde_json::Value::String(strategy.outcome.as_str().to_owned()),
243                    ),
244                    (
245                        "task_hint".to_owned(),
246                        serde_json::Value::String(strategy.task_hint.clone()),
247                    ),
248                ]),
249            };
250            if let Err(e) = vs.upsert(REASONING_COLLECTION, vec![point]).await {
251                tracing::warn!(error = %e, id = %strategy.id, "reasoning: Qdrant upsert failed — SQLite-only mode");
252            } else {
253                // Mark embedded_at on success.
254                let update_sql = zeph_db::rewrite_placeholders(&format!(
255                    "UPDATE reasoning_strategies SET embedded_at = {epoch_now} WHERE id = ?"
256                ));
257                if let Err(e) = zeph_db::query(&update_sql)
258                    .bind(&strategy.id)
259                    .execute(&self.pool)
260                    .await
261                {
262                    tracing::warn!(error = %e, "reasoning: failed to set embedded_at");
263                }
264            }
265        }
266
267        tracing::debug!(id = %strategy.id, outcome = strategy.outcome.as_str(), "reasoning: strategy inserted");
268        Ok(())
269    }
270
271    /// Retrieve up to `top_k` strategies by embedding similarity.
272    ///
273    /// This method is **pure** — it does not update `use_count` or `last_used_at`.
274    /// Call [`Self::mark_used`] with the ids of strategies actually injected into the
275    /// prompt (after budget truncation) to maintain accurate retrieval bookkeeping.
276    ///
277    /// Returns an empty vec when no vector store is configured.
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if the Qdrant search or `SQLite` fetch fails.
282    #[tracing::instrument(
283        name = "memory.reasoning.retrieve_by_embedding",
284        skip(self, embedding),
285        fields(top_k)
286    )]
287    pub async fn retrieve_by_embedding(
288        &self,
289        embedding: &[f32],
290        top_k: u64,
291    ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
292        let Some(ref vs) = self.vector_store else {
293            return Ok(Vec::new());
294        };
295
296        let scored = vs
297            .search(REASONING_COLLECTION, embedding.to_vec(), top_k, None)
298            .await?;
299
300        if scored.is_empty() {
301            return Ok(Vec::new());
302        }
303
304        let ids: Vec<String> = scored.into_iter().map(|p| p.id).collect();
305        self.fetch_by_ids(&ids).await
306    }
307
308    /// Increment `use_count` and update `last_used_at` for each id in the list.
309    ///
310    /// Safe to call with an empty slice — no SQL is issued.
311    /// The list is chunked into batches of [`MAX_IDS_PER_QUERY`] to respect `SQLite`'s
312    /// variable limit.
313    ///
314    /// # Errors
315    ///
316    /// Returns an error if the database update fails.
317    #[tracing::instrument(name = "memory.reasoning.mark_used", skip(self), fields(n = ids.len()))]
318    pub async fn mark_used(&self, ids: &[String]) -> Result<(), MemoryError> {
319        if ids.is_empty() {
320            return Ok(());
321        }
322
323        let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
324        for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
325            let ph = placeholder_list(1, chunk.len());
326            // Note: placeholder_list already generates ?1,?2,... (SQLite) or $1,$2,... (postgres).
327            // Do NOT call rewrite_placeholders here — that would corrupt ?1 into $11.
328            let sql = format!(
329                "UPDATE reasoning_strategies \
330                 SET use_count = use_count + 1, last_used_at = {epoch_now} \
331                 WHERE id IN ({ph})"
332            );
333            let mut q = zeph_db::query(&sql);
334            for id in chunk {
335                q = q.bind(id.as_str());
336            }
337            q.execute(&self.pool).await?;
338        }
339
340        Ok(())
341    }
342
343    /// Evict strategies when the table exceeds `store_limit`.
344    ///
345    /// **Normal path**: delete rows with `use_count <= HOT_STRATEGY_USE_COUNT`, oldest
346    /// first, until the table returns to `store_limit`.
347    ///
348    /// **Saturation path**: when the normal path deletes nothing AND the table exceeds
349    /// `2 × store_limit`, bypass hot-row protection and delete oldest rows regardless of
350    /// `use_count`. Emits a `warn!` with the eviction count so operators can tune
351    /// `store_limit` upward or lower the hot threshold.
352    ///
353    /// Returns the number of rows deleted.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if any database operation fails.
358    #[tracing::instrument(name = "memory.reasoning.evict_lru", skip(self), fields(store_limit))]
359    pub async fn evict_lru(&self, store_limit: usize) -> Result<usize, MemoryError> {
360        let count = self.count().await?;
361        if count <= store_limit {
362            return Ok(0);
363        }
364
365        let over_by = count - store_limit;
366        let deleted_cold = self.delete_oldest_cold(over_by).await?;
367        if deleted_cold > 0 {
368            // Also delete from Qdrant best-effort (ids not tracked here — full resync on recovery).
369            tracing::debug!(
370                deleted = deleted_cold,
371                count,
372                "reasoning: evicted cold strategies"
373            );
374            return Ok(deleted_cold);
375        }
376
377        // All rows over limit are hot. Check hard ceiling.
378        let hard_ceiling = store_limit.saturating_mul(2);
379        if count <= hard_ceiling {
380            tracing::debug!(
381                count,
382                store_limit,
383                "reasoning: hot saturation — growth allowed under 2x ceiling"
384            );
385            return Ok(0);
386        }
387
388        // Hard ceiling breached: force-evict oldest rows unconditionally.
389        let forced = count - store_limit;
390        let deleted_forced = self.delete_oldest_unconditional(forced).await?;
391        tracing::warn!(
392            deleted = deleted_forced,
393            count,
394            hard_ceiling,
395            "reasoning: hard-ceiling eviction — evicted hot strategies; consider raising store_limit"
396        );
397
398        Ok(deleted_forced)
399    }
400
401    /// Return the total number of rows in `reasoning_strategies`.
402    ///
403    /// # Errors
404    ///
405    /// Returns an error if the database query fails.
406    pub async fn count(&self) -> Result<usize, MemoryError> {
407        let row: (i64,) = zeph_db::query_as("SELECT COUNT(*) FROM reasoning_strategies")
408            .fetch_one(&self.pool)
409            .await?;
410        Ok(usize::try_from(row.0.max(0)).unwrap_or(0))
411    }
412
413    // ── private helpers ───────────────────────────────────────────────────────
414
415    /// Fetch strategy rows by their ids in a single `WHERE id IN (...)` query.
416    pub(crate) async fn fetch_by_ids(
417        &self,
418        ids: &[String],
419    ) -> Result<Vec<ReasoningStrategy>, MemoryError> {
420        if ids.is_empty() {
421            return Ok(Vec::new());
422        }
423
424        let mut strategies = Vec::with_capacity(ids.len());
425        for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
426            let ph = placeholder_list(1, chunk.len());
427            // Note: placeholder_list generates DB-specific ?N/$N syntax — do NOT rewite.
428            let sql = format!(
429                "SELECT id, summary, outcome, task_hint, created_at, last_used_at, use_count, embedded_at \
430                 FROM reasoning_strategies WHERE id IN ({ph})"
431            );
432            let mut q = zeph_db::query_as::<
433                _,
434                (String, String, String, String, i64, i64, i64, Option<i64>),
435            >(&sql);
436            for id in chunk {
437                q = q.bind(id.as_str());
438            }
439            let rows = q.fetch_all(&self.pool).await?;
440            for (
441                id,
442                summary,
443                outcome_str,
444                task_hint,
445                created_at,
446                last_used_at,
447                use_count,
448                embedded_at,
449            ) in rows
450            {
451                let outcome = Outcome::from_str(&outcome_str).unwrap_or(Outcome::Failure);
452                strategies.push(ReasoningStrategy {
453                    id,
454                    summary,
455                    outcome,
456                    task_hint,
457                    created_at,
458                    last_used_at,
459                    use_count,
460                    embedded_at,
461                });
462            }
463        }
464
465        Ok(strategies)
466    }
467
468    /// Delete up to `n` cold rows (`use_count <= HOT_STRATEGY_USE_COUNT`), oldest first.
469    ///
470    /// Returns the number of deleted rows.
471    async fn delete_oldest_cold(&self, n: usize) -> Result<usize, MemoryError> {
472        let limit = i64::try_from(n).unwrap_or(i64::MAX);
473        // Use plain `?` + rewrite_placeholders so postgres gets `$1`.
474        let raw = format!(
475            "DELETE FROM reasoning_strategies \
476             WHERE id IN ( \
477               SELECT id FROM reasoning_strategies \
478               WHERE use_count <= {HOT_STRATEGY_USE_COUNT} \
479               ORDER BY last_used_at ASC LIMIT ? \
480             )"
481        );
482        let sql = zeph_db::rewrite_placeholders(&raw);
483        let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
484        Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
485    }
486
487    /// Delete up to `n` rows unconditionally (oldest by `last_used_at`).
488    ///
489    /// Used only for the hard-ceiling saturation path.
490    async fn delete_oldest_unconditional(&self, n: usize) -> Result<usize, MemoryError> {
491        let limit = i64::try_from(n).unwrap_or(i64::MAX);
492        let raw = "DELETE FROM reasoning_strategies \
493                   WHERE id IN ( \
494                     SELECT id FROM reasoning_strategies \
495                     ORDER BY last_used_at ASC LIMIT ? \
496                   )";
497        let sql = zeph_db::rewrite_placeholders(raw);
498        let result = zeph_db::query(&sql).bind(limit).execute(&self.pool).await?;
499        Ok(usize::try_from(result.rows_affected()).unwrap_or(0))
500    }
501}
502
503// ── Free functions ────────────────────────────────────────────────────────────
504
505/// Run the self-judge step against a turn's message tail.
506///
507/// Sends the last `messages` slice to the LLM with the self-judge system prompt and
508/// attempts to parse the JSON response into a [`SelfJudgeOutcome`].
509///
510/// Returns `None` on parse failure, timeout, or LLM error — never propagates errors.
511/// Callers should log the `None` case at most at `debug` level.
512///
513/// # Examples
514///
515/// ```no_run
516/// use std::time::Duration;
517/// use zeph_llm::any::AnyProvider;
518/// use zeph_memory::reasoning::run_self_judge;
519///
520/// async fn demo(provider: AnyProvider, messages: &[zeph_llm::provider::Message]) {
521///     let outcome = run_self_judge(&provider, messages, Duration::from_secs(10)).await;
522///     if let Some(o) = outcome {
523///         println!("success={}, hint={}", o.success, o.task_hint);
524///     }
525/// }
526/// ```
527#[tracing::instrument(name = "memory.reasoning.self_judge", skip(provider, messages), fields(n = messages.len()))]
528pub async fn run_self_judge(
529    provider: &AnyProvider,
530    messages: &[Message],
531    extraction_timeout: Duration,
532) -> Option<SelfJudgeOutcome> {
533    if messages.is_empty() {
534        return None;
535    }
536
537    let user_prompt = build_transcript_prompt(messages);
538
539    let llm_messages = [
540        Message::from_legacy(Role::System, SELF_JUDGE_SYSTEM),
541        Message::from_legacy(Role::User, user_prompt),
542    ];
543
544    let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
545        Ok(Ok(text)) => text,
546        Ok(Err(e)) => {
547            tracing::warn!(error = %e, "reasoning: self-judge LLM call failed");
548            return None;
549        }
550        Err(_) => {
551            tracing::warn!("reasoning: self-judge timed out");
552            return None;
553        }
554    };
555
556    parse_self_judge_response(&response)
557}
558
559/// Run the distillation step.
560///
561/// Sends the reasoning chain and outcome label to the LLM and trims the response to
562/// at most 3 sentences and 512 characters.
563///
564/// Returns `None` on LLM error, timeout, or empty response.
565///
566/// # Examples
567///
568/// ```no_run
569/// use std::time::Duration;
570/// use zeph_llm::any::AnyProvider;
571/// use zeph_memory::reasoning::{Outcome, distill_strategy};
572///
573/// async fn demo(provider: AnyProvider) {
574///     let summary = distill_strategy(&provider, Outcome::Success, "tried X, worked", Duration::from_secs(10)).await;
575///     println!("{:?}", summary);
576/// }
577/// ```
578#[tracing::instrument(name = "memory.reasoning.distill", skip(provider, reasoning_chain))]
579pub async fn distill_strategy(
580    provider: &AnyProvider,
581    outcome: Outcome,
582    reasoning_chain: &str,
583    distill_timeout: Duration,
584) -> Option<String> {
585    if reasoning_chain.is_empty() {
586        return None;
587    }
588
589    let user_prompt = format!(
590        "Outcome: {}\n\nReasoning chain:\n{reasoning_chain}",
591        outcome.as_str()
592    );
593
594    let llm_messages = [
595        Message::from_legacy(Role::System, DISTILL_SYSTEM),
596        Message::from_legacy(Role::User, user_prompt),
597    ];
598
599    let response = match timeout(distill_timeout, provider.chat(&llm_messages)).await {
600        Ok(Ok(text)) => text,
601        Ok(Err(e)) => {
602            tracing::warn!(error = %e, "reasoning: distillation LLM call failed");
603            return None;
604        }
605        Err(_) => {
606            tracing::warn!("reasoning: distillation timed out");
607            return None;
608        }
609    };
610
611    let trimmed = trim_to_three_sentences(&response);
612    if trimmed.is_empty() {
613        None
614    } else {
615        Some(trimmed)
616    }
617}
618
/// Configuration for the [`process_turn`] extraction pipeline.
///
/// Groups timeout and limit parameters that rarely change between turns.
/// The struct is `Copy`, so it is passed to [`process_turn`] by value.
#[derive(Debug, Clone, Copy)]
pub struct ProcessTurnConfig {
    /// Maximum rows to retain in the `reasoning_strategies` table.
    ///
    /// Forwarded to [`ReasoningMemory::evict_lru`] after an insert.
    pub store_limit: usize,
    /// Timeout for the self-judge LLM call.
    pub extraction_timeout: Duration,
    /// Timeout for the distillation LLM call.
    pub distill_timeout: Duration,
    /// Maximum number of recent messages sliced from the turn history before passing
    /// to the self-judge evaluator. Narrowing the window prevents digest/recap messages
    /// from prior sessions from confusing the classifier. Default: `2`.
    pub self_judge_window: usize,
    /// Minimum character count in the last assistant message to trigger self-judge.
    /// Short or trivial responses (greetings, one-word answers) are skipped. Default: `50`.
    pub min_assistant_chars: usize,
}
638
639/// Run the full extraction pipeline for a single turn.
640///
641/// Calls [`run_self_judge`], then [`distill_strategy`], then inserts the result.
642/// `evict_lru` is called when the table exceeds `store_limit`. All errors are
643/// logged at `warn` level and the function returns `Ok(())` so callers never
644/// propagate pipeline failures.
645///
646/// # Errors
647///
648/// Returns an error if the embedding call fails, but not if self-judge or distillation fails.
649#[tracing::instrument(name = "memory.reasoning.process_turn", skip_all)]
650pub async fn process_turn(
651    memory: &ReasoningMemory,
652    extract_provider: &AnyProvider,
653    distill_provider: &AnyProvider,
654    embed_provider: &AnyProvider,
655    messages: &[Message],
656    cfg: ProcessTurnConfig,
657) -> Result<(), MemoryError> {
658    let ProcessTurnConfig {
659        store_limit,
660        extraction_timeout,
661        distill_timeout,
662        self_judge_window,
663        min_assistant_chars,
664    } = cfg;
665
666    // Narrow the message window to reduce noise from session digests and welcome-back
667    // messages that span prior sessions, which can confuse the self-judge classifier.
668    let judge_messages = if messages.len() > self_judge_window {
669        &messages[messages.len() - self_judge_window..]
670    } else {
671        messages
672    };
673
674    // Skip self-judge when the last assistant response is too short to be meaningful.
675    let last_assistant_chars = judge_messages
676        .iter()
677        .rev()
678        .find(|m| m.role == Role::Assistant)
679        .map_or(0, |m| m.content.len());
680    if last_assistant_chars < min_assistant_chars {
681        return Ok(());
682    }
683
684    let Some(outcome) = run_self_judge(extract_provider, judge_messages, extraction_timeout).await
685    else {
686        return Ok(());
687    };
688
689    let outcome_enum = if outcome.success {
690        Outcome::Success
691    } else {
692        Outcome::Failure
693    };
694
695    let Some(summary) = distill_strategy(
696        distill_provider,
697        outcome_enum,
698        &outcome.reasoning_chain,
699        distill_timeout,
700    )
701    .await
702    else {
703        return Ok(());
704    };
705
706    // Embed task_hint + summary for Qdrant retrieval (S2 from architect plan).
707    let embed_input = format!("{}\n{}", outcome.task_hint, summary);
708    let embedding = match tokio::time::timeout(
709        std::time::Duration::from_secs(5),
710        embed_provider.embed(&embed_input),
711    )
712    .await
713    {
714        Ok(Ok(v)) => v,
715        Ok(Err(e)) => {
716            tracing::warn!(error = %e, "reasoning: embedding failed — strategy not stored");
717            return Ok(());
718        }
719        Err(_) => {
720            tracing::warn!("reasoning: embed timed out — strategy not stored");
721            return Ok(());
722        }
723    };
724
725    let id = uuid::Uuid::new_v4().to_string();
726    let strategy = ReasoningStrategy {
727        id,
728        summary,
729        outcome: outcome_enum,
730        task_hint: outcome.task_hint,
731        created_at: 0, // filled by SQL EPOCH_NOW
732        last_used_at: 0,
733        use_count: 0,
734        embedded_at: None,
735    };
736
737    // P2-2: check count before insert to skip the evict_lru SELECT+DELETE when not needed.
738    // If count is already at or above store_limit, evict after insert. Approximate: two
739    // concurrent inserts can both read the same count and both decide to evict — the
740    // evict_lru implementation is idempotent so over-eviction by ≤1 row is acceptable.
741    let count_before = memory.count().await.unwrap_or(0);
742
743    if let Err(e) = memory.insert(&strategy, embedding).await {
744        tracing::warn!(error = %e, "reasoning: insert failed");
745        return Ok(());
746    }
747
748    if count_before >= store_limit
749        && let Err(e) = memory.evict_lru(store_limit).await
750    {
751        tracing::warn!(error = %e, "reasoning: evict_lru failed");
752    }
753
754    Ok(())
755}
756
757// ── private helpers ───────────────────────────────────────────────────────────
758
/// Maximum characters taken from a single message's content in the transcript prompt.
///
/// Prevents unbounded prompt growth when long tool outputs or code blocks are present
/// in the turn history (S-Med2 fix).
///
/// The limit is counted in `char`s (not bytes) and applied per message by
/// `build_transcript_prompt`.
const MAX_TRANSCRIPT_MESSAGE_CHARS: usize = 2000;
764
765/// Build a turn transcript prompt from the message slice.
766///
767/// Each message's content is truncated to [`MAX_TRANSCRIPT_MESSAGE_CHARS`] to bound
768/// the prompt length regardless of tool-output size. Mirrors the
769/// `build_extraction_prompt` format in `trajectory.rs` for consistency.
770fn build_transcript_prompt(messages: &[Message]) -> String {
771    let mut prompt = String::from("Agent turn messages:\n");
772    for (i, msg) in messages.iter().enumerate() {
773        use std::fmt::Write as _;
774        let role = format!("{:?}", msg.role);
775        // Truncate at a char boundary to avoid invalid UTF-8 slices.
776        let content: std::borrow::Cow<str> =
777            if msg.content.chars().count() > MAX_TRANSCRIPT_MESSAGE_CHARS {
778                msg.content
779                    .char_indices()
780                    .nth(MAX_TRANSCRIPT_MESSAGE_CHARS)
781                    .map_or(msg.content.as_str().into(), |(byte_idx, _)| {
782                        msg.content[..byte_idx].into()
783                    })
784            } else {
785                msg.content.as_str().into()
786            };
787        let _ = writeln!(prompt, "[{}] {}: {}", i + 1, role, content);
788    }
789    prompt.push_str("\nEvaluate this turn and return JSON.");
790    prompt
791}
792
793/// Parse the LLM response from the self-judge step into a [`SelfJudgeOutcome`].
794///
795/// Strips markdown code fences, then tries direct parse; on failure, locates the
796/// outermost `{…}` brackets and tries again. Returns `None` on persistent parse failure.
797fn parse_self_judge_response(response: &str) -> Option<SelfJudgeOutcome> {
798    // Strip markdown fences (```json … ```)
799    let stripped = response
800        .trim()
801        .trim_start_matches("```json")
802        .trim_start_matches("```")
803        .trim_end_matches("```")
804        .trim();
805
806    if let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(stripped) {
807        return Some(v);
808    }
809
810    // Try to extract the first `{…}` span.
811    if let (Some(start), Some(end)) = (stripped.find('{'), stripped.rfind('}'))
812        && end > start
813        && let Ok(v) = serde_json::from_str::<SelfJudgeOutcome>(&stripped[start..=end])
814    {
815        return Some(v);
816    }
817
818    tracing::warn!(
819        "reasoning: failed to parse self-judge response (len={}): {:.200}",
820        response.len(),
821        response
822    );
823    None
824}
825
/// Trim text to at most 3 sentences and 512 characters.
///
/// The input is first stripped of leading and trailing whitespace. A sentence ends at a
/// `.`, `!` or `?` that is followed by whitespace or by the end of the string, so
/// punctuation inside a token (`v1.2`, `file.rs`) does not count as a boundary.
///
/// * If a third sentence boundary exists, everything after it is dropped.
/// * If fewer than three boundaries exist, the whole text is kept — including a trailing
///   sentence that lacks terminal punctuation.
///
/// In both cases the result is then capped at 512 characters (Unicode scalar values, not
/// bytes); the cut always lands on a `char` boundary and may fall mid-sentence.
fn trim_to_three_sentences(text: &str) -> String {
    const MAX_CHARS: usize = 512;
    const MAX_SENTENCES: usize = 3;

    let text = text.trim();

    // Byte offset (exclusive) up to which the text is kept. It only moves off the end of
    // the text once the third sentence boundary is found.
    let mut keep_until = text.len();
    let mut sentences_seen = 0usize;
    let mut scan = text.char_indices().peekable();
    while let Some((offset, ch)) = scan.next() {
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }
        let closes_sentence = match scan.peek() {
            Some(&(_, following)) => following.is_whitespace(),
            None => true,
        };
        if closes_sentence {
            sentences_seen += 1;
            if sentences_seen == MAX_SENTENCES {
                keep_until = offset + ch.len_utf8();
                break;
            }
        }
    }

    let kept = &text[..keep_until];
    // Hard cap: `nth(MAX_CHARS)` is the first character beyond the limit, so its byte
    // offset is where the slice must end.
    match kept.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => kept[..cut].to_owned(),
        None => kept.to_owned(),
    }
}
864
/// Unit tests for the reasoning module: `Outcome` string conversion, the private parsing
/// and trimming helpers, `ReasoningMemory` against an in-memory `SQLite` database, and the
/// LLM-backed pipeline stages driven by a mock provider.
#[cfg(test)]
mod tests {
    use super::*;

    // ── Outcome ────────────────────────────────────────────────────────────────

    #[test]
    fn outcome_as_str_round_trip() {
        assert_eq!(Outcome::Success.as_str(), "success");
        assert_eq!(Outcome::Failure.as_str(), "failure");

        // Parsing the string form must give back the original variant.
        assert_eq!(
            Outcome::from_str(Outcome::Success.as_str()).unwrap(),
            Outcome::Success
        );
        assert_eq!(
            Outcome::from_str(Outcome::Failure.as_str()).unwrap(),
            Outcome::Failure
        );
    }

    #[test]
    fn outcome_from_str_success() {
        assert_eq!(Outcome::from_str("success").unwrap(), Outcome::Success);
    }

    #[test]
    fn outcome_from_str_failure() {
        assert_eq!(Outcome::from_str("failure").unwrap(), Outcome::Failure);
    }

    #[test]
    fn outcome_from_str_unknown_defaults_to_failure() {
        // Unknown values silently map to Failure (forward-compatible).
        assert_eq!(Outcome::from_str("partial").unwrap(), Outcome::Failure);
    }

    // ── parse_self_judge_response ─────────────────────────────────────────────

    #[test]
    fn parse_direct_json() {
        let json = r#"{"success":true,"reasoning_chain":"tried X","task_hint":"do Y"}"#;
        let outcome = parse_self_judge_response(json).unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.reasoning_chain, "tried X");
        assert_eq!(outcome.task_hint, "do Y");
    }

    #[test]
    fn parse_json_with_markdown_fences() {
        let response =
            "```json\n{\"success\":false,\"reasoning_chain\":\"r\",\"task_hint\":\"t\"}\n```";
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(!outcome.success);
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        let response = r#"Here is the evaluation: {"success":true,"reasoning_chain":"chain","task_hint":"hint"} — done."#;
        let outcome = parse_self_judge_response(response).unwrap();
        assert!(outcome.success);
    }

    #[test]
    fn parse_invalid_returns_none() {
        let outcome = parse_self_judge_response("not json at all");
        assert!(outcome.is_none());
    }

    // ── trim_to_three_sentences ───────────────────────────────────────────────

    #[test]
    fn trim_three_sentences_short_text() {
        let text = "One. Two. Three.";
        assert_eq!(trim_to_three_sentences(text), "One. Two. Three.");
    }

    #[test]
    fn trim_three_sentences_truncates_at_third() {
        let text = "One. Two. Three. Four. Five.";
        let result = trim_to_three_sentences(text);
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(!result.contains("Four"));
    }

    #[test]
    fn trim_three_sentences_hard_cap() {
        // 600 chars, no sentence boundaries → must be capped at exactly 512 chars.
        // An exact check also catches a regression that returns an empty string.
        let long: String = "x".repeat(600);
        let result = trim_to_three_sentences(&long);
        assert_eq!(result.chars().count(), 512);
    }

    #[test]
    fn trim_three_sentences_empty() {
        assert_eq!(trim_to_three_sentences("   "), "");
    }

    // ── ReasoningMemory (in-memory SQLite) ────────────────────────────────────

    /// Open an in-memory `SQLite` pool and create the `reasoning_strategies` table.
    async fn make_test_pool() -> DbPool {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE reasoning_strategies (
                id           TEXT    PRIMARY KEY NOT NULL,
                summary      TEXT    NOT NULL,
                outcome      TEXT    NOT NULL,
                task_hint    TEXT    NOT NULL,
                created_at   INTEGER NOT NULL DEFAULT (unixepoch('now')),
                last_used_at INTEGER NOT NULL DEFAULT (unixepoch('now')),
                use_count    INTEGER NOT NULL DEFAULT 0,
                embedded_at  INTEGER
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        pool
    }

    /// Build a successful, never-used strategy whose summary and task hint embed `id`.
    fn make_strategy(id: &str) -> ReasoningStrategy {
        ReasoningStrategy {
            id: id.to_owned(),
            summary: format!("Summary for {id}"),
            outcome: Outcome::Success,
            task_hint: format!("Task hint for {id}"),
            created_at: 0,
            last_used_at: 0,
            use_count: 0,
            embedded_at: None,
        }
    }

    #[tokio::test]
    async fn insert_and_fetch_by_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("abc-123");
        mem.insert(&s, vec![]).await.unwrap();

        let rows = mem.fetch_by_ids(&["abc-123".to_owned()]).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, "abc-123");
        assert_eq!(rows[0].outcome, Outcome::Success);
    }

    #[tokio::test]
    async fn mark_used_increments_count() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        let s = make_strategy("mark-1");
        mem.insert(&s, vec![]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();
        mem.mark_used(&["mark-1".to_owned()]).await.unwrap();

        let rows = mem.fetch_by_ids(&["mark-1".to_owned()]).await.unwrap();
        assert_eq!(rows[0].use_count, 2);
    }

    #[tokio::test]
    async fn mark_used_empty_is_noop() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        // Should not panic or error on empty slice.
        mem.mark_used(&[]).await.unwrap();
    }

    #[tokio::test]
    async fn count_returns_correct_total() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..5 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        assert_eq!(mem.count().await.unwrap(), 5);
    }

    #[tokio::test]
    async fn evict_lru_cold_rows() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        // Insert 5 cold rows (use_count = 0 by default).
        for i in 0..5 {
            mem.insert(&make_strategy(&format!("cold-{i}")), vec![])
                .await
                .unwrap();
        }

        // Store limit is 3 → should delete 2 oldest.
        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 2);
        assert_eq!(mem.count().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn evict_lru_respects_hot_rows_under_ceiling() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        // Insert 5 hot rows, i.e. rows with use_count > HOT_STRATEGY_USE_COUNT.
        for i in 0..5 {
            let id = format!("hot-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            // HOT_STRATEGY_USE_COUNT + 1 uses push the row past the hot threshold.
            for _ in 0..=HOT_STRATEGY_USE_COUNT {
                mem.mark_used(std::slice::from_ref(&id)).await.unwrap();
            }
        }

        // store_limit=3, count=5, all hot, 5 < 2*3=6 → under ceiling → no deletion.
        let deleted = mem.evict_lru(3).await.unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(mem.count().await.unwrap(), 5);
    }

    #[tokio::test]
    async fn evict_lru_hard_ceiling_forces_deletion() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        // Insert 7 hot rows. store_limit=3, ceiling=6. 7 > 6 → forced eviction.
        for i in 0..7 {
            let id = format!("hot2-{i}");
            mem.insert(&make_strategy(&id), vec![]).await.unwrap();
            // Make hot.
            for _ in 0..=HOT_STRATEGY_USE_COUNT {
                mem.mark_used(std::slice::from_ref(&id)).await.unwrap();
            }
        }

        let deleted = mem.evict_lru(3).await.unwrap();
        assert!(deleted > 0, "expected forced deletion");
        let remaining = mem.count().await.unwrap();
        assert_eq!(remaining, 3, "should be trimmed to store_limit");
    }

    #[tokio::test]
    async fn evict_lru_no_op_when_under_limit() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        for i in 0..3 {
            mem.insert(&make_strategy(&format!("s{i}")), vec![])
                .await
                .unwrap();
        }

        // store_limit=10 → count(3) ≤ 10 → no deletion.
        let deleted = mem.evict_lru(10).await.unwrap();
        assert_eq!(deleted, 0);
    }

    // ── mark_used chunked path ────────────────────────────────────────────────

    #[tokio::test]
    async fn mark_used_chunked_over_490_ids() {
        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);

        // Insert 500 strategies — exceeds MAX_IDS_PER_QUERY (490) forcing two SQL batches.
        for i in 0..500usize {
            mem.insert(&make_strategy(&format!("chunked-{i}")), vec![])
                .await
                .unwrap();
        }

        let ids: Vec<String> = (0..500usize).map(|i| format!("chunked-{i}")).collect();
        mem.mark_used(&ids).await.unwrap();

        // Spot-check: first and 491st should both have use_count == 1.
        let first = mem.fetch_by_ids(&[ids[0].clone()]).await.unwrap();
        let over_chunk = mem.fetch_by_ids(&[ids[490].clone()]).await.unwrap();
        assert_eq!(first[0].use_count, 1, "first id should have use_count = 1");
        assert_eq!(
            over_chunk[0].use_count, 1,
            "id past the chunk boundary should have use_count = 1"
        );
    }

    // ── run_self_judge malformed response ─────────────────────────────────────

    #[tokio::test]
    async fn run_self_judge_malformed_json_returns_none() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        // with_responses populates the one-shot queue; chat() returns this prose string.
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "This is not JSON at all.".to_string(),
        ]));
        let msgs = vec![Message::from_legacy(Role::User, "hello")];
        let result = run_self_judge(&provider, &msgs, std::time::Duration::from_secs(5)).await;
        assert!(result.is_none(), "malformed LLM response must return None");
    }

    // ── distill_strategy truncation ───────────────────────────────────────────

    #[tokio::test]
    async fn distill_strategy_truncates_to_three_sentences() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let long_response = "One. Two. Three. Four. Five.";
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            long_response.to_string(),
        ]));
        let result = distill_strategy(
            &provider,
            Outcome::Success,
            "chain here",
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.ends_with("Three."), "got: {result}");
        assert!(
            !result.contains("Four"),
            "should not contain 4th sentence: {result}"
        );
    }

    // ── process_turn smoke test ───────────────────────────────────────────────

    #[tokio::test]
    async fn process_turn_with_empty_messages_is_noop() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let pool = make_test_pool().await;
        let mem = ReasoningMemory::new(pool, None);
        // MockProvider returns "{}" which parse_self_judge_response will return None for
        // (missing required fields) → Ok(()) with zero inserts.
        let provider = AnyProvider::Mock(MockProvider::default());
        let cfg = ProcessTurnConfig {
            store_limit: 100,
            extraction_timeout: std::time::Duration::from_secs(1),
            distill_timeout: std::time::Duration::from_secs(1),
            self_judge_window: 2,
            min_assistant_chars: 0,
        };
        let result = process_turn(&mem, &provider, &provider, &provider, &[], cfg).await;
        assert!(
            result.is_ok(),
            "process_turn with empty messages must succeed"
        );
        assert_eq!(
            mem.count().await.unwrap(),
            0,
            "no strategies should be stored"
        );
    }
}