1use std::time::Duration;
13
14use serde::{Deserialize, Serialize};
15use tokio::time::timeout;
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{Message, Role};
18
19use crate::error::MemoryError;
20use crate::store::DbStore;
21use crate::store::persona::PersonaFactRow;
22
/// System prompt sent with every persona-extraction request.
///
/// The JSON object shape it asks for mirrors the fields of `ExtractedFact`
/// (`category`, `content`, `confidence`, `action`, `supersedes_id`), and its category
/// list is the same set accepted by `is_valid_category`; keep all three in sync when
/// adding or renaming a category.
const EXTRACTION_SYSTEM_PROMPT: &str = "\
You are a persona fact extractor. Given a list of user messages and any existing persona \
facts for each category, extract factual claims the user makes about themselves: their \
preferences, domain knowledge, working style, communication style, and background.

Rules:
1. Only extract facts from first-person user statements (\"I prefer\", \"I work on\", \
   \"my team\", \"I use\", etc.). Ignore assistant messages.
2. Do NOT extract facts from questions, greetings, or tool outputs.
3. For each extracted fact, decide if it is NEW (no existing fact contradicts it) or \
   UPDATE (it contradicts or replaces an existing fact). For UPDATE, provide the \
   `supersedes_id` of the older fact.
4. Confidence: 0.8-1.0 for explicit statements (\"I prefer X\"), 0.4-0.7 for inferences.
5. Categories: preference, domain_knowledge, working_style, communication, background.
6. Keep content concise (one sentence max). Normalize to English.
7. Return empty array if no facts are found.

Output JSON array of objects:
[
 {
 \"category\": \"preference|domain_knowledge|working_style|communication|background\",
 \"content\": \"concise factual statement\",
 \"confidence\": 0.0-1.0,
 \"action\": \"new|update\",
 \"supersedes_id\": null or integer id of the fact being replaced
 }
]";
50
/// Settings that control when and how [`extract_persona_facts`] runs.
///
/// Derives `Debug` so the effective configuration can be logged, and `Clone` so it
/// can be copied into spawned tasks.
#[derive(Debug, Clone)]
pub struct PersonaExtractionConfig {
    /// Master switch: when `false`, extraction returns `Ok(0)` without doing any work.
    pub enabled: bool,
    /// Name of the LLM provider to call; an empty string selects `"persona"`.
    pub persona_provider: String,
    /// Minimum number of user messages required before extraction is attempted.
    pub min_messages: usize,
    /// Upper bound on how many of the most recent user messages go into the prompt.
    pub max_messages: usize,
    /// Timeout for the LLM call, in seconds. On expiry a warning is logged and
    /// extraction returns `Ok(0)`.
    pub extraction_timeout_secs: u64,
}
63
64#[derive(Debug, Deserialize, Serialize)]
65struct ExtractedFact {
66 category: String,
67 content: String,
68 confidence: f64,
69 action: String,
70 supersedes_id: Option<i64>,
71}
72
/// Returns `true` when `text` appears to contain a first-person statement.
///
/// This is a cheap pre-filter that decides whether an LLM extraction call is worth
/// making. The text is lowercased, typographic apostrophes (`’`) are normalized to
/// `'`, and the result is split on whitespace with surrounding punctuation stripped
/// from each word. A word matches when it is `i`, `me`, `my`, `mine`, or `myself`, or
/// a contraction starting with `i'` (`i'm`, `i've`, `i'll`, `i'd`).
///
/// Matching whole words means first-person words are found at the start of any line,
/// at the end of the text and next to punctuation (`"help me?"`), while words that
/// merely contain these letters (`"determine"`, `"i.e."`) do not match.
#[must_use]
pub fn contains_self_referential_language(text: &str) -> bool {
    // Mobile keyboards and editors often emit U+2019 instead of an ASCII apostrophe.
    let normalized = text.to_lowercase().replace('\u{2019}', "'");
    normalized
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .any(|word| {
            matches!(word, "i" | "me" | "my" | "mine" | "myself") || word.starts_with("i'")
        })
}
83
84pub async fn extract_persona_facts(
93 store: &DbStore,
94 provider: &AnyProvider,
95 user_messages: &[&str],
96 config: &PersonaExtractionConfig,
97 conversation_id: Option<i64>,
98) -> Result<usize, MemoryError> {
99 if !config.enabled || user_messages.len() < config.min_messages {
100 return Ok(0);
101 }
102
103 let has_self_ref = user_messages
105 .iter()
106 .any(|m| contains_self_referential_language(m));
107 if !has_self_ref {
108 return Ok(0);
109 }
110
111 let messages_to_send: Vec<&str> = user_messages
112 .iter()
113 .rev()
114 .take(config.max_messages)
115 .copied()
116 .collect::<Vec<_>>()
117 .into_iter()
118 .rev()
119 .collect();
120
121 let existing_facts = store.load_persona_facts(0.0).await?;
123 let user_prompt = build_extraction_prompt(&messages_to_send, &existing_facts);
124
125 let llm_messages = [
126 Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
127 Message::from_legacy(Role::User, user_prompt),
128 ];
129
130 let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
131 let provider_name = if config.persona_provider.is_empty() {
133 "persona"
134 } else {
135 &config.persona_provider
136 };
137 let response = match timeout(
138 extraction_timeout,
139 provider.chat_with_named_provider(provider_name, &llm_messages),
140 )
141 .await
142 {
143 Ok(Ok(text)) => text,
144 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
145 Err(_) => {
146 tracing::warn!(
147 "persona extraction timed out after {}s",
148 config.extraction_timeout_secs
149 );
150 return Ok(0);
151 }
152 };
153
154 let facts = parse_extraction_response(&response);
155 if facts.is_empty() {
156 return Ok(0);
157 }
158
159 let mut upserted = 0usize;
160 for fact in facts {
161 if fact.category.is_empty() || fact.content.is_empty() {
162 continue;
163 }
164 if !is_valid_category(&fact.category) {
165 tracing::debug!(
166 category = %fact.category,
167 "persona extraction: skipping unknown category"
168 );
169 continue;
170 }
171 match store
172 .upsert_persona_fact(
173 &fact.category,
174 &fact.content,
175 fact.confidence.clamp(0.0, 1.0),
176 conversation_id,
177 fact.supersedes_id,
178 )
179 .await
180 {
181 Ok(_) => upserted += 1,
182 Err(e) => {
183 tracing::warn!(error = %e, "persona extraction: failed to upsert fact");
184 }
185 }
186 }
187
188 tracing::debug!(upserted, "persona extraction complete");
189 Ok(upserted)
190}
191
/// Reports whether `category` is one of the five persona categories named in
/// `EXTRACTION_SYSTEM_PROMPT`. The comparison is exact: case-sensitive, no trimming.
fn is_valid_category(category: &str) -> bool {
    const KNOWN_CATEGORIES: [&str; 5] = [
        "preference",
        "domain_knowledge",
        "working_style",
        "communication",
        "background",
    ];
    KNOWN_CATEGORIES.contains(&category)
}
198
199fn build_extraction_prompt(messages: &[&str], existing_facts: &[PersonaFactRow]) -> String {
200 let mut prompt = String::from("User messages to analyze:\n");
201 for (i, msg) in messages.iter().enumerate() {
202 use std::fmt::Write as _;
203 let _ = writeln!(prompt, "[{}] {}", i + 1, msg);
204 }
205
206 if !existing_facts.is_empty() {
207 prompt.push_str("\nExisting persona facts (for contradiction detection):\n");
208 for fact in existing_facts {
209 use std::fmt::Write as _;
210 let _ = writeln!(
211 prompt,
212 " id={} category={} content=\"{}\"",
213 fact.id, fact.category, fact.content
214 );
215 }
216 }
217
218 prompt.push_str("\nExtract persona facts as JSON array.");
219 prompt
220}
221
222fn parse_extraction_response(response: &str) -> Vec<ExtractedFact> {
223 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(response) {
225 return facts;
226 }
227
228 if let (Some(start), Some(end)) = (response.find('['), response.rfind(']'))
230 && end > start
231 {
232 let slice = &response[start..=end];
233 if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFact>>(slice) {
234 return facts;
235 }
236 }
237
238 tracing::warn!(
239 "persona extraction: failed to parse LLM response (len={}): {:.200}",
240 response.len(),
241 response
242 );
243 Vec::new()
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Opens a fresh in-memory `DbStore` with a pool size of 1.
    // NOTE(review): a single-connection pool presumably keeps every query on the same
    // `:memory:` database (each SQLite connection would otherwise get its own) — confirm.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    // --- contains_self_referential_language: inputs that must pass the gate ---

    #[test]
    fn self_ref_detects_i_prefer() {
        assert!(contains_self_referential_language("I prefer dark mode"));
    }

    #[test]
    fn self_ref_detects_my_team() {
        assert!(contains_self_referential_language("my team uses Rust"));
    }

    #[test]
    fn self_ref_detects_sentence_starting_with_i() {
        assert!(contains_self_referential_language("I work remotely"));
    }

    #[test]
    fn self_ref_detects_inline_i() {
        assert!(contains_self_referential_language(
            "Sometimes I prefer async"
        ));
    }

    #[test]
    fn self_ref_detects_me_inline() {
        assert!(contains_self_referential_language(
            "That helps me understand"
        ));
    }

    // --- contains_self_referential_language: inputs that must be rejected ---

    #[test]
    fn self_ref_no_match_for_third_person() {
        assert!(!contains_self_referential_language("The team uses Python"));
    }

    #[test]
    fn self_ref_no_match_for_tool_output() {
        assert!(!contains_self_referential_language("Error: file not found"));
    }

    #[test]
    fn self_ref_no_match_for_empty_string() {
        assert!(!contains_self_referential_language(""));
    }

    #[tokio::test]
    async fn extraction_gate_skips_when_no_self_ref() {
        // NOTE(review): despite its name, this test never calls `extract_persona_facts`.
        // It only builds three configs (enabled, disabled, min_messages = 5) and checks
        // the `min_messages` precondition by hand — presumably because no `AnyProvider`
        // can be constructed here. Consider exercising the early-return paths for real.
        let store = make_store().await;
        let cfg = PersonaExtractionConfig {
            enabled: true,
            persona_provider: String::new(),
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_disabled = PersonaExtractionConfig {
            enabled: false,
            persona_provider: String::new(),
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_min = PersonaExtractionConfig {
            enabled: true,
            persona_provider: String::new(),
            min_messages: 5,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let messages: Vec<&str> = vec![];
        assert!(messages.len() < cfg_min.min_messages);
        // Moves the otherwise-unused bindings, presumably to silence unused warnings.
        let _ = (store, cfg, cfg_disabled, cfg_min);
    }

    // --- parse_extraction_response ---

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"new","supersedes_id":null}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "preference");
        assert_eq!(facts[0].content, "I prefer dark mode");
        assert!((facts[0].confidence - 0.9).abs() < 1e-9);
        assert_eq!(facts[0].action, "new");
        assert!(facts[0].supersedes_id.is_none());
    }

    /// The array is surrounded by chatty prose, which the fallback path must skip.
    #[test]
    fn parse_json_embedded_in_prose() {
        let response = "Sure! Here are the facts:\n[{\"category\":\"domain_knowledge\",\"content\":\"Uses Rust\",\"confidence\":0.8,\"action\":\"new\",\"supersedes_id\":null}]\nThat's all.";
        let facts = parse_extraction_response(response);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "domain_knowledge");
    }

    #[test]
    fn parse_empty_array() {
        let facts = parse_extraction_response("[]");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_empty() {
        let facts = parse_extraction_response("not json at all");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_supersedes_id_populated() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"update","supersedes_id":7}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts[0].supersedes_id, Some(7));
        assert_eq!(facts[0].action, "update");
    }

    /// Storing a fact with `supersedes_id = Some(old_id)` must hide the old fact from
    /// `load_persona_facts`, leaving only the newer one.
    #[tokio::test]
    async fn contradiction_second_fact_supersedes_first() {
        let store = make_store().await;
        let old_id = store
            .upsert_persona_fact("preference", "I prefer light mode", 0.8, None, None)
            .await
            .expect("old fact");

        store
            .upsert_persona_fact("preference", "I prefer dark mode", 0.9, None, Some(old_id))
            .await
            .expect("new fact");

        let facts = store.load_persona_facts(0.0).await.expect("load");
        assert_eq!(facts.len(), 1, "superseded fact should be excluded");
        assert_eq!(facts[0].content, "I prefer dark mode");
    }
}
411}