Skip to main content

zeph_memory/semantic/
trajectory.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Trajectory-informed memory extraction (#2498).
5//!
6//! After each agent turn containing tool calls, a fast LLM provider analyzes the turn
7//! and produces procedural (reusable how-to patterns) and episodic (one-off event) entries.
8//! Entries are stored per-conversation so concurrent sessions do not interfere (critic S1).
9//!
10//! Extraction is always fire-and-forget (caller uses `tokio::spawn`) — no latency added to
11//! the response path (critic M3).
12
13use std::time::Duration;
14
15use serde::{Deserialize, Serialize};
16use tokio::time::timeout;
17use zeph_llm::any::AnyProvider;
18use zeph_llm::provider::{Message, Role};
19
20use crate::error::MemoryError;
21
/// System prompt sent ahead of the turn transcript.
///
/// It instructs the model to classify each entry as `procedural` or `episodic` and to answer
/// with a JSON array whose objects carry `kind`, `intent`, `outcome`, `tools_used` and
/// `confidence`. Each fragment below ends with an explicit `\n` exactly where the prompt has
/// a line break; fragments without one continue the same prompt line.
const EXTRACTION_SYSTEM_PROMPT: &str = concat!(
    "You are a trajectory memory extractor. Given messages from an agent turn that included ",
    "tool calls, extract reusable patterns and notable events.\n",
    "\n",
    "Rules:\n",
    "1. Classify each entry as 'procedural' (a reusable how-to pattern — general technique) ",
    "or 'episodic' (a one-off event — specific occurrence).\n",
    "2. Focus on the intent (goal), outcome (result), and tools used.\n",
    "3. Confidence: 0.8-1.0 for clear outcomes, 0.4-0.7 for ambiguous ones.\n",
    "4. Keep intent and outcome concise (one sentence each).\n",
    "5. Return empty array if no meaningful entries can be extracted.\n",
    "\n",
    "Output JSON array:\n",
    "[\n",
    "  {\n",
    "    \"kind\": \"procedural|episodic\",\n",
    "    \"intent\": \"what the agent was trying to accomplish\",\n",
    "    \"outcome\": \"what actually happened\",\n",
    "    \"tools_used\": [\"tool_name\", ...],\n",
    "    \"confidence\": 0.0-1.0\n",
    "  }\n",
    "]",
);
44
/// A single extracted trajectory entry (in-memory, before storage).
///
/// Entries compare by value (`PartialEq`), which keeps assertions in tests and
/// de-duplication in callers straightforward. `Eq` is intentionally not derived because
/// `confidence` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryEntry {
    /// Entry classification. Entries produced by the parser in this module are always
    /// `"procedural"` (reusable how-to pattern) or `"episodic"` (one-off event).
    pub kind: String,
    /// What the agent was trying to accomplish.
    pub intent: String,
    /// What actually happened.
    pub outcome: String,
    /// Names of the tools involved in the turn; may be empty.
    pub tools_used: Vec<String>,
    /// Extraction confidence. The parser in this module clamps it to `0.0..=1.0`.
    pub confidence: f64,
}
54
/// Configuration for trajectory extraction.
///
/// Derives `Debug` and `Clone` so the configuration can be logged and handed to spawned
/// extraction tasks, and `PartialEq`/`Eq` so configurations can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrajectoryExtractionConfig {
    /// Master switch. When `false`, extraction returns no entries without contacting the LLM.
    pub enabled: bool,
    /// Maximum messages fed to the LLM per extraction pass. When a turn has more messages,
    /// only the most recent ones are sent.
    pub max_messages: usize,
    /// LLM timeout in seconds. A call that exceeds it is abandoned and yields no entries.
    pub extraction_timeout_secs: u64,
}
63
/// Wire shape of one element of the JSON array returned by the model.
///
/// Field names are the JSON keys requested in the extraction system prompt. Values are
/// taken as-is here; filtering and clamping happen when converting to `TrajectoryEntry`.
#[derive(Debug, Deserialize, Serialize)]
struct RawEntry {
    /// Entry classification as emitted by the model (not yet validated).
    kind: String,
    /// Goal of the turn as described by the model.
    intent: String,
    /// Result of the turn as described by the model.
    outcome: String,
    /// Tool names; a missing key deserializes to an empty list.
    #[serde(default)]
    tools_used: Vec<String>,
    /// Confidence as emitted by the model (not yet clamped). No default is declared, so the
    /// key is required.
    confidence: f64,
}
73
74/// Extract trajectory entries from a turn's messages.
75///
76/// Returns the extracted entries. Parse failures are logged and treated as zero entries.
77///
78/// # Errors
79///
80/// Returns an error only for transport-level LLM failures.
81pub async fn extract_trajectory_entries(
82    provider: &AnyProvider,
83    messages: &[Message],
84    config: &TrajectoryExtractionConfig,
85) -> Result<Vec<TrajectoryEntry>, MemoryError> {
86    if !config.enabled || messages.is_empty() {
87        return Ok(Vec::new());
88    }
89
90    let messages_to_send: Vec<&Message> = messages
91        .iter()
92        .rev()
93        .take(config.max_messages)
94        .collect::<Vec<_>>()
95        .into_iter()
96        .rev()
97        .collect();
98
99    let user_prompt = build_extraction_prompt(&messages_to_send);
100
101    let llm_messages = [
102        Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
103        Message::from_legacy(Role::User, user_prompt),
104    ];
105
106    let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
107    let response = match timeout(
108        extraction_timeout,
109        provider.chat_with_named_provider("trajectory", &llm_messages),
110    )
111    .await
112    {
113        Ok(Ok(text)) => text,
114        Ok(Err(e)) => return Err(MemoryError::Llm(e)),
115        Err(_) => {
116            tracing::warn!(
117                "trajectory extraction timed out after {}s",
118                config.extraction_timeout_secs
119            );
120            return Ok(Vec::new());
121        }
122    };
123
124    let entries = parse_extraction_response(&response);
125    Ok(entries)
126}
127
128fn build_extraction_prompt(messages: &[&Message]) -> String {
129    let mut prompt = String::from("Agent turn messages:\n");
130    for (i, msg) in messages.iter().enumerate() {
131        use std::fmt::Write as _;
132        let role = format!("{:?}", msg.role);
133        let _ = writeln!(prompt, "[{}] {}: {}", i + 1, role, msg.content);
134    }
135    prompt.push_str("\nExtract trajectory entries as JSON array.");
136    prompt
137}
138
139fn parse_extraction_response(response: &str) -> Vec<TrajectoryEntry> {
140    let raw: Vec<RawEntry> = if let Ok(v) = serde_json::from_str(response) {
141        v
142    } else if let (Some(start), Some(end)) = (response.find('['), response.rfind(']'))
143        && end > start
144    {
145        serde_json::from_str(&response[start..=end]).unwrap_or_default()
146    } else {
147        tracing::warn!(
148            "trajectory extraction: failed to parse response (len={}): {:.200}",
149            response.len(),
150            response
151        );
152        return Vec::new();
153    };
154
155    raw.into_iter()
156        .filter(|e| !e.intent.is_empty() && !e.outcome.is_empty())
157        .filter(|e| matches!(e.kind.as_str(), "procedural" | "episodic"))
158        .map(|e| TrajectoryEntry {
159            kind: e.kind,
160            intent: e.intent,
161            outcome: e.outcome,
162            tools_used: e.tools_used,
163            confidence: e.confidence.clamp(0.0, 1.0),
164        })
165        .collect()
166}
167
/// Unit tests for response parsing: accepted shapes, rejected shapes, and normalization.
#[cfg(test)]
mod tests {
    use super::*;

    // --- accepted shapes -------------------------------------------------------------

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"kind":"procedural","intent":"read a file","outcome":"file read ok","tools_used":["read_file"],"confidence":0.9}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, "procedural");
        assert_eq!(entries[0].intent, "read a file");
        assert_eq!(entries[0].tools_used, vec!["read_file"]);
        assert!((entries[0].confidence - 0.9).abs() < 1e-9);
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        let response = "Here you go:\n[{\"kind\":\"episodic\",\"intent\":\"fixed a bug\",\"outcome\":\"patch applied\",\"tools_used\":[],\"confidence\":0.8}]\nDone.";
        let entries = parse_extraction_response(response);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, "episodic");
    }

    #[test]
    fn parse_json_inside_markdown_fence() {
        let response = "```json\n[{\"kind\":\"procedural\",\"intent\":\"run tests\",\"outcome\":\"all passed\",\"tools_used\":[\"shell\"],\"confidence\":0.9}]\n```";
        let entries = parse_extraction_response(response);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].tools_used, vec!["shell"]);
    }

    #[test]
    fn parse_missing_tools_used_defaults_to_empty() {
        let json = r#"[{"kind":"episodic","intent":"x","outcome":"y","confidence":0.6}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(entries[0].tools_used.is_empty());
    }

    #[test]
    fn parse_empty_array() {
        let entries = parse_extraction_response("[]");
        assert!(entries.is_empty());
    }

    // --- rejected shapes -------------------------------------------------------------

    #[test]
    fn parse_invalid_json_returns_empty() {
        let entries = parse_extraction_response("not json");
        assert!(entries.is_empty());
    }

    #[test]
    fn parse_malformed_bracketed_payload_returns_empty() {
        let entries = parse_extraction_response("Sure: [not, valid, json]");
        assert!(entries.is_empty(), "bracketed garbage must yield no entries");
    }

    #[test]
    fn parse_reversed_brackets_returns_empty() {
        let entries = parse_extraction_response("] nothing here [");
        assert!(entries.is_empty(), "a ']' before '[' is not an array");
    }

    #[test]
    fn parse_filters_unknown_kind() {
        let json =
            r#"[{"kind":"unknown","intent":"x","outcome":"y","tools_used":[],"confidence":0.5}]"#;
        let entries = parse_extraction_response(json);
        assert!(entries.is_empty(), "unknown kind must be filtered out");
    }

    #[test]
    fn parse_filters_empty_intent_or_outcome() {
        let empty_intent =
            r#"[{"kind":"procedural","intent":"","outcome":"y","tools_used":[],"confidence":0.8}]"#;
        assert!(
            parse_extraction_response(empty_intent).is_empty(),
            "empty intent must be filtered"
        );

        let empty_outcome =
            r#"[{"kind":"procedural","intent":"x","outcome":"","tools_used":[],"confidence":0.8}]"#;
        assert!(
            parse_extraction_response(empty_outcome).is_empty(),
            "empty outcome must be filtered"
        );
    }

    #[test]
    fn parse_keeps_valid_entries_next_to_filtered_ones() {
        let json = r#"[{"kind":"bogus","intent":"a","outcome":"b","tools_used":[],"confidence":0.5},{"kind":"episodic","intent":"c","outcome":"d","tools_used":[],"confidence":0.7}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].intent, "c");
    }

    // --- normalization ---------------------------------------------------------------

    #[test]
    fn parse_clamps_confidence() {
        let json = r#"[{"kind":"procedural","intent":"x","outcome":"y","tools_used":[],"confidence":1.5}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(
            (entries[0].confidence - 1.0).abs() < 1e-9,
            "confidence must be clamped to 1.0"
        );
    }

    #[test]
    fn parse_clamps_negative_confidence() {
        let json = r#"[{"kind":"procedural","intent":"x","outcome":"y","tools_used":[],"confidence":-0.3}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].confidence.abs() < 1e-9,
            "confidence must be clamped to 0.0"
        );
    }
}
229}