1use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt for the entity/relationship extraction call.
///
/// Rules are numbered sequentially (1-15) so each number refers to exactly one rule.
/// The output schema at the end lists the same fields as [`ExtractionResult`],
/// [`ExtractedEntity`] and [`ExtractedEdge`].
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, tool, language, organization, concept. \
\"tool\" covers frameworks, software tools, and libraries. \
\"language\" covers programming and natural languages. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Each edge must include an \"edge_type\" field classifying the relationship:
   - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
   - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
   - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
   - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
   Default to \"semantic\" if the relationship type is uncertain.
11. Each edge must include a \"confidence\" field: a float in [0.0, 1.0] reflecting how \
certain you are that this relationship was explicitly stated or strongly implied by the text. \
Use 1.0 only for direct verbatim statements. Use 0.5–0.8 for clear implications. \
Use 0.3–0.5 for weak inferences. Omit or use null if you cannot assign a meaningful score.
12. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
13. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
14. Always output entity names and relation verbs in English. Translate if needed.
15. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\", \"confidence\": 0.0}
  ]
}";
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
63pub struct ExtractionResult {
64 pub entities: Vec<ExtractedEntity>,
65 pub edges: Vec<ExtractedEdge>,
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
69pub struct ExtractedEntity {
70 pub name: String,
71 #[serde(rename = "type")]
72 pub entity_type: String,
73 pub summary: Option<String>,
74}
75
/// Serde default for [`ExtractedEdge::edge_type`] when the key is absent from the JSON.
fn default_semantic() -> String {
    String::from("semantic")
}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
81pub struct ExtractedEdge {
82 pub source: String,
83 pub target: String,
84 pub relation: String,
85 pub fact: String,
86 pub temporal_hint: Option<String>,
87 #[serde(default = "default_semantic")]
89 pub edge_type: String,
90 #[serde(default)]
97 pub confidence: Option<f32>,
98}
99
100pub struct GraphExtractor {
101 provider: AnyProvider,
102 max_entities: usize,
103 max_edges: usize,
104 llm_timeout_secs: u64,
105}
106
107impl GraphExtractor {
108 #[must_use]
109 pub fn new(
110 provider: AnyProvider,
111 max_entities: usize,
112 max_edges: usize,
113 llm_timeout_secs: u64,
114 ) -> Self {
115 Self {
116 provider,
117 max_entities,
118 max_edges,
119 llm_timeout_secs,
120 }
121 }
122
123 #[tracing::instrument(name = "memory.graph.extract", skip_all, level = "debug", err)]
133 pub async fn extract(
134 &self,
135 message: &str,
136 context_messages: &[&str],
137 ) -> Result<Option<ExtractionResult>, MemoryError> {
138 if message.trim().is_empty() {
139 return Ok(None);
140 }
141
142 let user_prompt = build_user_prompt(message, context_messages);
143 let messages = [
144 Message::from_legacy(Role::System, SYSTEM_PROMPT),
145 Message::from_legacy(Role::User, user_prompt),
146 ];
147
148 match tokio::time::timeout(
149 std::time::Duration::from_secs(self.llm_timeout_secs),
150 self.provider
151 .chat_typed_erased::<ExtractionResult>(&messages),
152 )
153 .await
154 {
155 Err(_elapsed) => {
156 let t = self.llm_timeout_secs;
157 tracing::warn!("graph_extractor: extract LLM call timed out after {t}s");
158 return Ok(None);
159 }
160 Ok(Ok(mut result)) => {
161 result.entities.truncate(self.max_entities);
162 result.edges.truncate(self.max_edges);
163 return Ok(Some(result));
164 }
165 Ok(Err(LlmError::StructuredParse(msg))) => {
166 tracing::warn!(
167 "graph extraction: LLM returned unparseable output (len={}): {:.200}",
168 msg.len(),
169 msg
170 );
171 return Ok(None);
172 }
173 Ok(Err(other)) => return Err(MemoryError::Llm(other)),
174 }
175 }
176}
177
/// Builds the user-turn prompt for extraction.
///
/// The prompt always ends with the current `message` and the extraction instruction.
/// When `context_messages` is non-empty, they are prepended one per line under a
/// "Context (last N messages):" header.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    let current_section =
        format!("Current message:\n{message}\n\nExtract entities and relationships as JSON.");

    match context_messages {
        [] => current_section,
        history => {
            let count = history.len();
            let joined = history.join("\n");
            format!("Context (last {count} messages):\n{joined}\n\n{current_section}")
        }
    }
}
189
// Unit tests for the extraction data types, prompt building, and `GraphExtractor::extract`.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an `ExtractedEntity` fixture from borrowed string parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    // Builds an `ExtractedEdge` fixture with `edge_type` fixed to "semantic" and no confidence.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
            edge_type: "semantic".into(),
            confidence: None,
        }
    }

    // A fully populated entity deserializes with the `type` key mapped to `entity_type`.
    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    // Empty arrays are a valid "nothing extracted" payload.
    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    // Explicit `null` optionals become `None`; a missing `edge_type` falls back to "semantic".
    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
        assert_eq!(result.edges[0].edge_type, "semantic");
    }

    // An edge with no `edge_type` key (and no optionals at all) still deserializes.
    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        let json = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let edge: ExtractedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.edge_type, "semantic");
    }

    // Each edge type named in SYSTEM_PROMPT is accepted verbatim.
    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for et in &["semantic", "temporal", "causal", "entity"] {
            let json = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{et}"}}"#
            );
            let edge: ExtractedEdge = serde_json::from_str(&json).unwrap();
            assert_eq!(&edge.edge_type, et);
        }
    }

    // Serialize -> deserialize preserves non-default edge types and confidence values.
    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let original = ExtractionResult {
            entities: vec![],
            edges: vec![
                ExtractedEdge {
                    source: "A".into(),
                    target: "B".into(),
                    relation: "caused".into(),
                    fact: "A caused B".into(),
                    temporal_hint: None,
                    edge_type: "causal".into(),
                    confidence: Some(0.9),
                },
                ExtractedEdge {
                    source: "B".into(),
                    target: "C".into(),
                    relation: "preceded_by".into(),
                    fact: "B preceded_by C".into(),
                    temporal_hint: None,
                    edge_type: "temporal".into(),
                    confidence: None,
                },
            ],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
        assert_eq!(restored.edges[0].edge_type, "causal");
        assert_eq!(restored.edges[1].edge_type, "temporal");
    }

    // The `type` rename applies in both directions; the Rust field name never hits the wire.
    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    // Fixture-built values survive a JSON roundtrip unchanged.
    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    // Smoke test: the derived JSON schema mentions both top-level fields.
    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    // With context, the prompt has a counted context header and newline-joined history.
    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    // Without context, no context header is emitted at all.
    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    // Tests that drive `GraphExtractor::extract` through `zeph_llm`'s mock provider.
    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        // JSON payload with `count` concept entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        // JSON payload with `count` edges (A -> B{i}) and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        // 20 entities from the model are capped at max_entities = 5.
        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 5, 100, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
        }

        // 15 edges from the model are capped at max_edges = 3.
        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 100, 3, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
        }

        // Non-JSON model output is treated as "nothing extracted", not as an error.
        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        // Non-parse LLM errors propagate as `MemoryError::Llm`.
        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        // The constructor stores the timeout argument unchanged.
        #[test]
        fn graph_extractor_stores_custom_llm_timeout() {
            let extractor = GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(MockProvider::default()),
                10,
                5,
                42,
            );
            assert_eq!(extractor.llm_timeout_secs, 42);
        }

        // Empty and whitespace-only messages short-circuit to `None`.
        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract(" \t\n ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}