1use std::time::Duration;
14
15use serde::{Deserialize, Serialize};
16use tokio::time::timeout;
17use zeph_llm::any::AnyProvider;
18use zeph_llm::provider::{LlmProvider as _, Message, Role};
19
20use crate::error::MemoryError;
21
/// System prompt sent as the `Role::System` message of every trajectory extraction request.
///
/// It instructs the model to classify each entry as `procedural` or `episodic`, to report
/// intent, outcome, tools used and a confidence score, and to answer with a JSON array
/// (an empty array when nothing is worth extracting). The two `kind` labels named here are
/// the only ones `parse_extraction_response` keeps.
const EXTRACTION_SYSTEM_PROMPT: &str = "\
You are a trajectory memory extractor. Given messages from an agent turn that included tool \
calls, extract reusable patterns and notable events.

Rules:
1. Classify each entry as 'procedural' (a reusable how-to pattern — general technique) or \
 'episodic' (a one-off event — specific occurrence).
2. Focus on the intent (goal), outcome (result), and tools used.
3. Confidence: 0.8-1.0 for clear outcomes, 0.4-0.7 for ambiguous ones.
4. Keep intent and outcome concise (one sentence each).
5. Return empty array if no meaningful entries can be extracted.

Output JSON array:
[
 {
 \"kind\": \"procedural|episodic\",
 \"intent\": \"what the agent was trying to accomplish\",
 \"outcome\": \"what actually happened\",
 \"tools_used\": [\"tool_name\", ...],
 \"confidence\": 0.0-1.0
 }
]";
44
/// A validated trajectory memory entry extracted from an agent turn.
///
/// Entries returned by `parse_extraction_response` in this module always have a `kind` of
/// `"procedural"` or `"episodic"`, a non-empty `intent` and `outcome`, and a `confidence`
/// clamped to `0.0..=1.0`. The fields are public, so values constructed elsewhere carry no
/// such guarantee.
///
/// `PartialEq` is derived so entries can be compared directly (e.g. with `assert_eq!`);
/// `Eq` is not available because `confidence` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryEntry {
    /// Classification label reported by the model (`"procedural"` or `"episodic"`).
    pub kind: String,
    /// What the agent was trying to accomplish.
    pub intent: String,
    /// What actually happened.
    pub outcome: String,
    /// Names of the tools involved; may be empty.
    pub tools_used: Vec<String>,
    /// Model-reported confidence score.
    pub confidence: f64,
}
54
/// Settings that control trajectory extraction.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so the configuration can be logged,
/// duplicated and compared like any other plain-data settings type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrajectoryExtractionConfig {
    /// When `false`, `extract_trajectory_entries` returns an empty list without calling the LLM.
    pub enabled: bool,
    /// Upper bound on how many of the most recent messages are included in the prompt.
    pub max_messages: usize,
    /// Time budget for the LLM call, in seconds; on expiry extraction yields no entries.
    pub extraction_timeout_secs: u64,
}
63
/// Unvalidated entry exactly as it appears in the model's JSON output.
///
/// This is the deserialization target only; `parse_extraction_response` filters and
/// clamps these values before turning them into [`TrajectoryEntry`].
#[derive(Debug, Deserialize, Serialize)]
struct RawEntry {
    // Free-form label from the model; validated later against the accepted kinds.
    kind: String,
    // Goal description; entries with an empty value are dropped by the parser.
    intent: String,
    // Result description; entries with an empty value are dropped by the parser.
    outcome: String,
    // Optional in the payload: a missing `tools_used` key deserializes to an empty list.
    #[serde(default)]
    tools_used: Vec<String>,
    // Raw score; may be out of range here, the parser clamps it to 0.0..=1.0.
    confidence: f64,
}
73
74pub async fn extract_trajectory_entries(
82 provider: &AnyProvider,
83 messages: &[Message],
84 config: &TrajectoryExtractionConfig,
85) -> Result<Vec<TrajectoryEntry>, MemoryError> {
86 if !config.enabled || messages.is_empty() {
87 return Ok(Vec::new());
88 }
89
90 let messages_to_send: Vec<&Message> = messages
91 .iter()
92 .rev()
93 .take(config.max_messages)
94 .collect::<Vec<_>>()
95 .into_iter()
96 .rev()
97 .collect();
98
99 let user_prompt = build_extraction_prompt(&messages_to_send);
100
101 let llm_messages = [
102 Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
103 Message::from_legacy(Role::User, user_prompt),
104 ];
105
106 let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
107 let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
108 Ok(Ok(text)) => text,
109 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
110 Err(_) => {
111 tracing::warn!(
112 "trajectory extraction timed out after {}s",
113 config.extraction_timeout_secs
114 );
115 return Ok(Vec::new());
116 }
117 };
118
119 let entries = parse_extraction_response(&response);
120 Ok(entries)
121}
122
123fn build_extraction_prompt(messages: &[&Message]) -> String {
124 let mut prompt = String::from("Agent turn messages:\n");
125 for (i, msg) in messages.iter().enumerate() {
126 use std::fmt::Write as _;
127 let role = format!("{:?}", msg.role);
128 let _ = writeln!(prompt, "[{}] {}: {}", i + 1, role, msg.content);
129 }
130 prompt.push_str("\nExtract trajectory entries as JSON array.");
131 prompt
132}
133
134fn parse_extraction_response(response: &str) -> Vec<TrajectoryEntry> {
135 let raw: Vec<RawEntry> = if let Ok(v) = serde_json::from_str(response) {
136 v
137 } else if let (Some(start), Some(end)) = (response.find('['), response.rfind(']'))
138 && end > start
139 {
140 serde_json::from_str(&response[start..=end]).unwrap_or_default()
141 } else {
142 tracing::warn!(
143 "trajectory extraction: failed to parse response (len={}): {:.200}",
144 response.len(),
145 response
146 );
147 return Vec::new();
148 };
149
150 raw.into_iter()
151 .filter(|e| !e.intent.is_empty() && !e.outcome.is_empty())
152 .filter(|e| matches!(e.kind.as_str(), "procedural" | "episodic"))
153 .map(|e| TrajectoryEntry {
154 kind: e.kind,
155 intent: e.intent,
156 outcome: e.outcome,
157 tools_used: e.tools_used,
158 confidence: e.confidence.clamp(0.0, 1.0),
159 })
160 .collect()
161}
162
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_extraction_response`: accepted input shapes, filtering
    //! rules, defaults and confidence clamping.

    use super::*;

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"kind":"procedural","intent":"read a file","outcome":"file read ok","tools_used":["read_file"],"confidence":0.9}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, "procedural");
        assert_eq!(entries[0].intent, "read a file");
        assert_eq!(entries[0].tools_used, vec!["read_file"]);
        assert!((entries[0].confidence - 0.9).abs() < 1e-9);
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        let response = "Here you go:\n[{\"kind\":\"episodic\",\"intent\":\"fixed a bug\",\"outcome\":\"patch applied\",\"tools_used\":[],\"confidence\":0.8}]\nDone.";
        let entries = parse_extraction_response(response);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, "episodic");
    }

    #[test]
    fn parse_empty_array() {
        let entries = parse_extraction_response("[]");
        assert!(entries.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_empty() {
        let entries = parse_extraction_response("not json");
        assert!(entries.is_empty());
    }

    #[test]
    fn parse_filters_unknown_kind() {
        let json =
            r#"[{"kind":"unknown","intent":"x","outcome":"y","tools_used":[],"confidence":0.5}]"#;
        let entries = parse_extraction_response(json);
        assert!(entries.is_empty(), "unknown kind must be filtered out");
    }

    #[test]
    fn parse_clamps_confidence() {
        let json = r#"[{"kind":"procedural","intent":"x","outcome":"y","tools_used":[],"confidence":1.5}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(
            (entries[0].confidence - 1.0).abs() < 1e-9,
            "confidence must be clamped to 1.0"
        );
    }

    #[test]
    fn parse_clamps_negative_confidence() {
        let json = r#"[{"kind":"episodic","intent":"x","outcome":"y","tools_used":[],"confidence":-0.3}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].confidence.abs() < 1e-9,
            "confidence must be clamped to 0.0"
        );
    }

    #[test]
    fn parse_defaults_missing_tools_used() {
        let json = r#"[{"kind":"procedural","intent":"x","outcome":"y","confidence":0.7}]"#;
        let entries = parse_extraction_response(json);
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].tools_used.is_empty(),
            "missing tools_used must default to an empty list"
        );
    }

    #[test]
    fn parse_filters_empty_intent_or_outcome() {
        let empty_intent =
            r#"[{"kind":"procedural","intent":"","outcome":"y","tools_used":[],"confidence":0.8}]"#;
        let entries = parse_extraction_response(empty_intent);
        assert!(entries.is_empty(), "empty intent must be filtered");

        // The test name promises both halves, so cover the outcome side as well.
        let empty_outcome =
            r#"[{"kind":"procedural","intent":"x","outcome":"","tools_used":[],"confidence":0.8}]"#;
        let entries = parse_extraction_response(empty_outcome);
        assert!(entries.is_empty(), "empty outcome must be filtered");
    }
}