1use std::time::Duration;
13
14use serde::{Deserialize, Serialize};
15use tokio::time::timeout;
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, Role};
18
19use crate::error::MemoryError;
20use crate::store::DbStore;
21use crate::store::persona::PersonaFactRow;
22
/// System prompt instructing the model to emit persona facts as a strict
/// JSON array whose objects deserialize into [`ExtractedFact`]
/// (category / content / confidence / action / supersedes_id).
///
/// Existing facts are included in the user prompt so the model can mark a
/// new fact as an `update` that supersedes an older, contradicted one.
const EXTRACTION_SYSTEM_PROMPT: &str = "\
You are a persona fact extractor. Given a list of user messages and any existing persona \
facts for each category, extract factual claims the user makes about themselves: their \
preferences, domain knowledge, working style, communication style, and background.

Rules:
1. Only extract facts from first-person user statements (\"I prefer\", \"I work on\", \
 \"my team\", \"I use\", etc.). Ignore assistant messages.
2. Do NOT extract facts from questions, greetings, or tool outputs.
3. For each extracted fact, decide if it is NEW (no existing fact contradicts it) or \
 UPDATE (it contradicts or replaces an existing fact). For UPDATE, provide the \
 `supersedes_id` of the older fact.
4. Confidence: 0.8-1.0 for explicit statements (\"I prefer X\"), 0.4-0.7 for inferences.
5. Categories: preference, domain_knowledge, working_style, communication, background.
6. Keep content concise (one sentence max). Normalize to English.
7. Return empty array if no facts are found.

Output JSON array of objects:
[
 {
 \"category\": \"preference|domain_knowledge|working_style|communication|background\",
 \"content\": \"concise factual statement\",
 \"confidence\": 0.0-1.0,
 \"action\": \"new|update\",
 \"supersedes_id\": null or integer id of the fact being replaced
 }
]";
50
/// Tuning knobs for background persona-fact extraction.
//
// `Debug` and `Clone` are derived so the config can be logged and passed
// around freely; both are backward-compatible additions for a plain-data
// public struct.
#[derive(Debug, Clone)]
pub struct PersonaExtractionConfig {
    /// Master switch; when `false`, `extract_persona_facts` is a no-op.
    pub enabled: bool,
    /// Minimum number of user messages required before extraction runs.
    pub min_messages: usize,
    /// Upper bound on how many of the most recent messages are sent to the LLM.
    pub max_messages: usize,
    /// Timeout for the LLM call in seconds; on expiry extraction is skipped.
    pub extraction_timeout_secs: u64,
}
61
/// One fact as deserialized from the LLM's JSON-array response.
#[derive(Debug, Deserialize, Serialize)]
struct ExtractedFact {
    /// Expected to be one of: preference, domain_knowledge, working_style,
    /// communication, background — unknown values are skipped by the caller.
    category: String,
    /// Concise factual statement about the user.
    content: String,
    /// Model-reported confidence; clamped to 0.0..=1.0 before persisting.
    confidence: f64,
    /// "new" or "update" per the prompt contract. NOTE(review): only
    /// `supersedes_id` is passed to the store; confirm `action` is
    /// intentionally informational.
    action: String,
    /// Id of an existing fact this one replaces, when the model chose "update".
    supersedes_id: Option<i64>,
}
70
/// Returns `true` when `text` contains first-person language ("I", "my",
/// "me", "mine", …), case-insensitively.
///
/// Used as a cheap pre-filter: messages with no self-referential language
/// cannot yield persona facts, so the LLM call is skipped entirely.
#[must_use]
pub fn contains_self_referential_language(text: &str) -> bool {
    let lowered = text.to_lowercase();

    // A sentence opening in the first person has no leading space, so the
    // space-delimited markers below would miss it — check the prefix first.
    if lowered.starts_with("i ") || lowered.starts_with("my ") {
        return true;
    }

    // Space-delimited markers avoid false positives on words that merely
    // contain these letters (e.g. "mine" inside "determine").
    const MARKERS: [&str; 7] = [" i ", " i'", " my ", " me ", " mine ", "i am ", "i'm "];
    MARKERS.into_iter().any(|marker| lowered.contains(marker))
}
81
82#[cfg_attr(
91 feature = "profiling",
92 tracing::instrument(name = "memory.persona_extract", skip_all, fields(fact_count = tracing::field::Empty))
93)]
94pub async fn extract_persona_facts(
95 store: &DbStore,
96 provider: &AnyProvider,
97 user_messages: &[&str],
98 config: &PersonaExtractionConfig,
99 conversation_id: Option<i64>,
100) -> Result<usize, MemoryError> {
101 if !config.enabled || user_messages.len() < config.min_messages {
102 return Ok(0);
103 }
104
105 let has_self_ref = user_messages
107 .iter()
108 .any(|m| contains_self_referential_language(m));
109 if !has_self_ref {
110 return Ok(0);
111 }
112
113 let messages_to_send: Vec<&str> = user_messages
114 .iter()
115 .rev()
116 .take(config.max_messages)
117 .copied()
118 .collect::<Vec<_>>()
119 .into_iter()
120 .rev()
121 .collect();
122
123 let existing_facts = store.load_persona_facts(0.0).await?;
125 let user_prompt = build_extraction_prompt(&messages_to_send, &existing_facts);
126
127 let llm_messages = [
128 Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
129 Message::from_legacy(Role::User, user_prompt),
130 ];
131
132 let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
133 let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
134 Ok(Ok(text)) => text,
135 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
136 Err(_) => {
137 tracing::warn!(
138 "persona extraction timed out after {}s",
139 config.extraction_timeout_secs
140 );
141 return Ok(0);
142 }
143 };
144
145 let facts = parse_extraction_response(&response);
146 if facts.is_empty() {
147 return Ok(0);
148 }
149
150 let mut upserted = 0usize;
151 for fact in facts {
152 if fact.category.is_empty() || fact.content.is_empty() {
153 continue;
154 }
155 if !is_valid_category(&fact.category) {
156 tracing::debug!(
157 category = %fact.category,
158 "persona extraction: skipping unknown category"
159 );
160 continue;
161 }
162 match store
163 .upsert_persona_fact(
164 &fact.category,
165 &fact.content,
166 fact.confidence.clamp(0.0, 1.0),
167 conversation_id,
168 fact.supersedes_id,
169 )
170 .await
171 {
172 Ok(_) => upserted += 1,
173 Err(e) => {
174 tracing::warn!(error = %e, "persona extraction: failed to upsert fact");
175 }
176 }
177 }
178
179 tracing::debug!(upserted, "persona extraction complete");
180 #[cfg(feature = "profiling")]
181 tracing::Span::current().record("fact_count", upserted);
182 Ok(upserted)
183}
184
/// Returns `true` for the five persona categories the extraction prompt
/// allows; anything else (including different casing) is rejected.
fn is_valid_category(category: &str) -> bool {
    const VALID_CATEGORIES: [&str; 5] = [
        "preference",
        "domain_knowledge",
        "working_style",
        "communication",
        "background",
    ];
    VALID_CATEGORIES.contains(&category)
}
191
192fn build_extraction_prompt(messages: &[&str], existing_facts: &[PersonaFactRow]) -> String {
193 let mut prompt = String::from("User messages to analyze:\n");
194 for (i, msg) in messages.iter().enumerate() {
195 use std::fmt::Write as _;
196 let _ = writeln!(prompt, "[{}] {}", i + 1, msg);
197 }
198
199 if !existing_facts.is_empty() {
200 prompt.push_str("\nExisting persona facts (for contradiction detection):\n");
201 for fact in existing_facts {
202 use std::fmt::Write as _;
203 let _ = writeln!(
204 prompt,
205 " id={} category={} content=\"{}\"",
206 fact.id, fact.category, fact.content
207 );
208 }
209 }
210
211 prompt.push_str("\nExtract persona facts as JSON array.");
212 prompt
213}
214
215fn parse_extraction_response(response: &str) -> Vec<ExtractedFact> {
216 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(response) {
218 return facts;
219 }
220
221 if let (Some(start), Some(end)) = (response.find('['), response.rfind(']'))
223 && end > start
224 {
225 let slice = &response[start..=end];
226 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(slice) {
227 return facts;
228 }
229 }
230
231 tracing::warn!(
232 "persona extraction: failed to parse LLM response (len={}): {:.200}",
233 response.len(),
234 response
235 );
236 Vec::new()
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Builds a single-connection in-memory store so each test starts from
    /// an empty persona-fact table.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    // --- contains_self_referential_language: positive cases ---

    #[test]
    fn self_ref_detects_i_prefer() {
        assert!(contains_self_referential_language("I prefer dark mode"));
    }

    #[test]
    fn self_ref_detects_my_team() {
        assert!(contains_self_referential_language("my team uses Rust"));
    }

    #[test]
    fn self_ref_detects_sentence_starting_with_i() {
        assert!(contains_self_referential_language("I work remotely"));
    }

    #[test]
    fn self_ref_detects_inline_i() {
        // " i " marker: first-person pronoun mid-sentence.
        assert!(contains_self_referential_language(
            "Sometimes I prefer async"
        ));
    }

    #[test]
    fn self_ref_detects_me_inline() {
        // " me " marker: object pronoun mid-sentence.
        assert!(contains_self_referential_language(
            "That helps me understand"
        ));
    }

    // --- contains_self_referential_language: negative cases ---

    #[test]
    fn self_ref_no_match_for_third_person() {
        assert!(!contains_self_referential_language("The team uses Python"));
    }

    #[test]
    fn self_ref_no_match_for_tool_output() {
        assert!(!contains_self_referential_language("Error: file not found"));
    }

    #[test]
    fn self_ref_no_match_for_empty_string() {
        assert!(!contains_self_referential_language(""));
    }

    // Exercises only config construction and the min-message gate condition.
    // NOTE(review): running extract_persona_facts end-to-end needs a live
    // AnyProvider; this stub asserts nothing about the function itself —
    // consider a mock provider to cover the enabled/min_messages gates.
    #[tokio::test]
    async fn extraction_gate_skips_when_no_self_ref() {
        let store = make_store().await;
        let cfg = PersonaExtractionConfig {
            enabled: true,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_disabled = PersonaExtractionConfig {
            enabled: false,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_min = PersonaExtractionConfig {
            enabled: true,
            min_messages: 5,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let messages: Vec<&str> = vec![];
        // An empty message list is below the min_messages threshold.
        assert!(messages.len() < cfg_min.min_messages);
        let _ = (store, cfg, cfg_disabled, cfg_min); }

    // --- parse_extraction_response ---

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"new","supersedes_id":null}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "preference");
        assert_eq!(facts[0].content, "I prefer dark mode");
        assert!((facts[0].confidence - 0.9).abs() < 1e-9);
        assert_eq!(facts[0].action, "new");
        assert!(facts[0].supersedes_id.is_none());
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        // Fallback path: the array is wrapped in conversational text.
        let response = "Sure! Here are the facts:\n[{\"category\":\"domain_knowledge\",\"content\":\"Uses Rust\",\"confidence\":0.8,\"action\":\"new\",\"supersedes_id\":null}]\nThat's all.";
        let facts = parse_extraction_response(response);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "domain_knowledge");
    }

    #[test]
    fn parse_empty_array() {
        let facts = parse_extraction_response("[]");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_empty() {
        let facts = parse_extraction_response("not json at all");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_supersedes_id_populated() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"update","supersedes_id":7}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts[0].supersedes_id, Some(7));
        assert_eq!(facts[0].action, "update");
    }

    // Store-level contract: upserting with supersedes_id hides the old fact.
    #[tokio::test]
    async fn contradiction_second_fact_supersedes_first() {
        let store = make_store().await;
        let old_id = store
            .upsert_persona_fact("preference", "I prefer light mode", 0.8, None, None)
            .await
            .expect("old fact");

        store
            .upsert_persona_fact("preference", "I prefer dark mode", 0.9, None, Some(old_id))
            .await
            .expect("new fact");

        let facts = store.load_persona_facts(0.0).await.expect("load");
        assert_eq!(facts.len(), 1, "superseded fact should be excluded");
        assert_eq!(facts[0].content, "I prefer dark mode");
    }
}