1use std::time::Duration;
13
14use serde::{Deserialize, Serialize};
15use tokio::time::timeout;
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, Role};
18
19use crate::error::MemoryError;
20use crate::store::DbStore;
21use crate::store::persona::PersonaFactRow;
22
/// System prompt for the LLM persona-fact extractor.
///
/// Instructs the model to emit a JSON array whose objects mirror
/// [`ExtractedFact`] (category/content/confidence/action/supersedes_id).
/// `parse_extraction_response` accepts either a bare array or one embedded
/// in surrounding prose; categories are re-validated by `is_valid_category`
/// since the model may ignore rule 5.
const EXTRACTION_SYSTEM_PROMPT: &str = "\
You are a persona fact extractor. Given a list of user messages and any existing persona \
facts for each category, extract factual claims the user makes about themselves: their \
preferences, domain knowledge, working style, communication style, and background.

Rules:
1. Only extract facts from first-person user statements (\"I prefer\", \"I work on\", \
 \"my team\", \"I use\", etc.). Ignore assistant messages.
2. Do NOT extract facts from questions, greetings, or tool outputs.
3. For each extracted fact, decide if it is NEW (no existing fact contradicts it) or \
 UPDATE (it contradicts or replaces an existing fact). For UPDATE, provide the \
 `supersedes_id` of the older fact.
4. Confidence: 0.8-1.0 for explicit statements (\"I prefer X\"), 0.4-0.7 for inferences.
5. Categories: preference, domain_knowledge, working_style, communication, background.
6. Keep content concise (one sentence max). Normalize to English.
7. Return empty array if no facts are found.

Output JSON array of objects:
[
 {
 \"category\": \"preference|domain_knowledge|working_style|communication|background\",
 \"content\": \"concise factual statement\",
 \"confidence\": 0.0-1.0,
 \"action\": \"new|update\",
 \"supersedes_id\": null or integer id of the fact being replaced
 }
]";
50
/// Tuning knobs for persona fact extraction.
///
/// Consumed by [`extract_persona_facts`], which checks the gates in order:
/// `enabled`, then `min_messages`, then the self-reference heuristic.
#[derive(Debug, Clone)]
pub struct PersonaExtractionConfig {
    /// Master switch; when `false`, extraction is skipped entirely.
    pub enabled: bool,
    /// Minimum number of user messages required before extraction runs.
    pub min_messages: usize,
    /// Upper bound on how many (most recent) messages are sent to the LLM.
    pub max_messages: usize,
    /// How long to wait for the LLM response before giving up, in seconds.
    pub extraction_timeout_secs: u64,
}
61
/// One fact as returned by the LLM in the extraction response JSON array.
#[derive(Debug, Deserialize, Serialize)]
struct ExtractedFact {
    // One of the categories listed in rule 5 of `EXTRACTION_SYSTEM_PROMPT`;
    // re-validated by `is_valid_category` since the model may not comply.
    category: String,
    // Concise factual statement (prompt asks for one sentence max).
    content: String,
    // Model-reported confidence; clamped to [0.0, 1.0] before upsert.
    confidence: f64,
    // "new" or "update" per the prompt contract (currently informational;
    // the upsert path keys off `supersedes_id`, not this field).
    action: String,
    // Id of the existing fact this one replaces, when the action is "update".
    supersedes_id: Option<i64>,
}
70
/// Cheap heuristic gate: does `text` contain first-person, self-referential
/// language ("I", "my", "me", "mine", "I'm", "I am", ...)?
///
/// Used by [`extract_persona_facts`] to skip the LLM call for batches that
/// cannot contain persona facts. Matching is case-insensitive. The lowered
/// text is padded with a space on both ends so tokens match at the start
/// AND end of the string (e.g. "this project helps me"), which the previous
/// unpadded `contains`/`starts_with` combination missed.
#[must_use]
pub fn contains_self_referential_language(text: &str) -> bool {
    let padded = format!(" {} ", text.to_lowercase());
    // " i'" covers contractions (i'm, i've, i'd, i'll) anywhere, including
    // at the start of the string thanks to the leading pad.
    let tokens = [" i ", " i'", " my ", " me ", " mine ", " i am "];
    tokens.iter().any(|t| padded.contains(t))
}
81
82pub async fn extract_persona_facts(
91 store: &DbStore,
92 provider: &AnyProvider,
93 user_messages: &[&str],
94 config: &PersonaExtractionConfig,
95 conversation_id: Option<i64>,
96) -> Result<usize, MemoryError> {
97 if !config.enabled || user_messages.len() < config.min_messages {
98 return Ok(0);
99 }
100
101 let has_self_ref = user_messages
103 .iter()
104 .any(|m| contains_self_referential_language(m));
105 if !has_self_ref {
106 return Ok(0);
107 }
108
109 let messages_to_send: Vec<&str> = user_messages
110 .iter()
111 .rev()
112 .take(config.max_messages)
113 .copied()
114 .collect::<Vec<_>>()
115 .into_iter()
116 .rev()
117 .collect();
118
119 let existing_facts = store.load_persona_facts(0.0).await?;
121 let user_prompt = build_extraction_prompt(&messages_to_send, &existing_facts);
122
123 let llm_messages = [
124 Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
125 Message::from_legacy(Role::User, user_prompt),
126 ];
127
128 let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
129 let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
130 Ok(Ok(text)) => text,
131 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
132 Err(_) => {
133 tracing::warn!(
134 "persona extraction timed out after {}s",
135 config.extraction_timeout_secs
136 );
137 return Ok(0);
138 }
139 };
140
141 let facts = parse_extraction_response(&response);
142 if facts.is_empty() {
143 return Ok(0);
144 }
145
146 let mut upserted = 0usize;
147 for fact in facts {
148 if fact.category.is_empty() || fact.content.is_empty() {
149 continue;
150 }
151 if !is_valid_category(&fact.category) {
152 tracing::debug!(
153 category = %fact.category,
154 "persona extraction: skipping unknown category"
155 );
156 continue;
157 }
158 match store
159 .upsert_persona_fact(
160 &fact.category,
161 &fact.content,
162 fact.confidence.clamp(0.0, 1.0),
163 conversation_id,
164 fact.supersedes_id,
165 )
166 .await
167 {
168 Ok(_) => upserted += 1,
169 Err(e) => {
170 tracing::warn!(error = %e, "persona extraction: failed to upsert fact");
171 }
172 }
173 }
174
175 tracing::debug!(upserted, "persona extraction complete");
176 Ok(upserted)
177}
178
/// Returns `true` when `category` is one of the five persona categories the
/// extraction prompt allows (rule 5 of `EXTRACTION_SYSTEM_PROMPT`).
fn is_valid_category(category: &str) -> bool {
    const VALID_CATEGORIES: [&str; 5] = [
        "preference",
        "domain_knowledge",
        "working_style",
        "communication",
        "background",
    ];
    VALID_CATEGORIES.contains(&category)
}
185
186fn build_extraction_prompt(messages: &[&str], existing_facts: &[PersonaFactRow]) -> String {
187 let mut prompt = String::from("User messages to analyze:\n");
188 for (i, msg) in messages.iter().enumerate() {
189 use std::fmt::Write as _;
190 let _ = writeln!(prompt, "[{}] {}", i + 1, msg);
191 }
192
193 if !existing_facts.is_empty() {
194 prompt.push_str("\nExisting persona facts (for contradiction detection):\n");
195 for fact in existing_facts {
196 use std::fmt::Write as _;
197 let _ = writeln!(
198 prompt,
199 " id={} category={} content=\"{}\"",
200 fact.id, fact.category, fact.content
201 );
202 }
203 }
204
205 prompt.push_str("\nExtract persona facts as JSON array.");
206 prompt
207}
208
209fn parse_extraction_response(response: &str) -> Vec<ExtractedFact> {
210 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(response) {
212 return facts;
213 }
214
215 if let (Some(start), Some(end)) = (response.find('['), response.rfind(']'))
217 && end > start
218 {
219 let slice = &response[start..=end];
220 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(slice) {
221 return facts;
222 }
223 }
224
225 tracing::warn!(
226 "persona extraction: failed to parse LLM response (len={}): {:.200}",
227 response.len(),
228 response
229 );
230 Vec::new()
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    // Builds a single-connection in-memory store so each test is isolated
    // and needs no filesystem cleanup.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    // --- contains_self_referential_language: positive cases ---

    #[test]
    fn self_ref_detects_i_prefer() {
        assert!(contains_self_referential_language("I prefer dark mode"));
    }

    #[test]
    fn self_ref_detects_my_team() {
        assert!(contains_self_referential_language("my team uses Rust"));
    }

    #[test]
    fn self_ref_detects_sentence_starting_with_i() {
        assert!(contains_self_referential_language("I work remotely"));
    }

    #[test]
    fn self_ref_detects_inline_i() {
        assert!(contains_self_referential_language(
            "Sometimes I prefer async"
        ));
    }

    #[test]
    fn self_ref_detects_me_inline() {
        assert!(contains_self_referential_language(
            "That helps me understand"
        ));
    }

    // --- contains_self_referential_language: negative cases ---

    #[test]
    fn self_ref_no_match_for_third_person() {
        assert!(!contains_self_referential_language("The team uses Python"));
    }

    #[test]
    fn self_ref_no_match_for_tool_output() {
        assert!(!contains_self_referential_language("Error: file not found"));
    }

    #[test]
    fn self_ref_no_match_for_empty_string() {
        assert!(!contains_self_referential_language(""));
    }

    // NOTE(review): despite its name, this test never calls
    // extract_persona_facts — it only constructs the three gate configs and
    // asserts the min_messages precondition directly (there is no mock
    // AnyProvider here). Consider wiring a stub provider so the gates are
    // exercised through the real entry point.
    #[tokio::test]
    async fn extraction_gate_skips_when_no_self_ref() {
        let store = make_store().await;
        let cfg = PersonaExtractionConfig {
            enabled: true,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_disabled = PersonaExtractionConfig {
            enabled: false,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_min = PersonaExtractionConfig {
            enabled: true,
            min_messages: 5,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let messages: Vec<&str> = vec![];
        assert!(messages.len() < cfg_min.min_messages);
        let _ = (store, cfg, cfg_disabled, cfg_min); }

    // --- parse_extraction_response ---

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"new","supersedes_id":null}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "preference");
        assert_eq!(facts[0].content, "I prefer dark mode");
        assert!((facts[0].confidence - 0.9).abs() < 1e-9);
        assert_eq!(facts[0].action, "new");
        assert!(facts[0].supersedes_id.is_none());
    }

    // Exercises the bracket-slicing fallback path (JSON wrapped in prose).
    #[test]
    fn parse_json_embedded_in_prose() {
        let response = "Sure! Here are the facts:\n[{\"category\":\"domain_knowledge\",\"content\":\"Uses Rust\",\"confidence\":0.8,\"action\":\"new\",\"supersedes_id\":null}]\nThat's all.";
        let facts = parse_extraction_response(response);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "domain_knowledge");
    }

    #[test]
    fn parse_empty_array() {
        let facts = parse_extraction_response("[]");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_empty() {
        let facts = parse_extraction_response("not json at all");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_supersedes_id_populated() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"update","supersedes_id":7}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts[0].supersedes_id, Some(7));
        assert_eq!(facts[0].action, "update");
    }

    // End-to-end store behavior: an upsert with supersedes_id hides the older
    // fact from load_persona_facts.
    #[tokio::test]
    async fn contradiction_second_fact_supersedes_first() {
        let store = make_store().await;
        let old_id = store
            .upsert_persona_fact("preference", "I prefer light mode", 0.8, None, None)
            .await
            .expect("old fact");

        store
            .upsert_persona_fact("preference", "I prefer dark mode", 0.9, None, Some(old_id))
            .await
            .expect("new fact");

        let facts = store.load_persona_facts(0.0).await.expect("load");
        assert_eq!(facts.len(), 1, "superseded fact should be excluded");
        assert_eq!(facts[0].content, "I prefer dark mode");
    }
}