Skip to main content

zeph_memory/graph/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt sent (as the `System` message) with every extraction request.
///
/// It lists the extraction rules and the JSON shape that the reply is deserialized into
/// ([`ExtractionResult`]). Rules are numbered sequentially 1–15; keep the numbering
/// contiguous when adding or removing rules so the model never sees duplicate rule ids.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, tool, language, organization, concept. \
\"tool\" covers frameworks, software tools, and libraries. \
\"language\" covers programming and natural languages. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Each edge must include an \"edge_type\" field classifying the relationship:
  - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
  - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
  - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
  - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
  Default to \"semantic\" if the relationship type is uncertain.
11. Each edge must include a \"confidence\" field: a float in [0.0, 1.0] reflecting how \
certain you are that this relationship was explicitly stated or strongly implied by the text. \
Use 1.0 only for direct verbatim statements. Use 0.5–0.8 for clear implications. \
Use 0.3–0.5 for weak inferences. Omit or use null if you cannot assign a meaningful score.
12. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
13. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
14. Always output entity names and relation verbs in English. Translate if needed.
15. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\", \"confidence\": 0.0}
  ]
}";
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
63pub struct ExtractionResult {
64    pub entities: Vec<ExtractedEntity>,
65    pub edges: Vec<ExtractedEdge>,
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
69pub struct ExtractedEntity {
70    pub name: String,
71    #[serde(rename = "type")]
72    pub entity_type: String,
73    pub summary: Option<String>,
74}
75
/// Serde fallback for [`ExtractedEdge::edge_type`]: an edge the LLM leaves untyped is
/// classified as `"semantic"`.
fn default_semantic() -> String {
    String::from("semantic")
}
79
// One relationship between two extracted entities.
//
// NOTE: the `///` comments below are picked up by schemars as schema descriptions, so
// additional notes in this type are plain `//` comments to leave that schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEdge {
    // Name of the source entity. Presumably matches an `ExtractedEntity::name` in the same
    // result, but nothing here checks that.
    pub source: String,
    // Name of the target entity (same caveat as `source`).
    pub target: String,
    // Short verb phrase such as "uses" or "works_on", as requested by the system prompt.
    pub relation: String,
    // Human-readable sentence summarizing the relationship.
    pub fact: String,
    // Optional temporal qualifier (e.g. "replaced X"); `None` when omitted or null.
    pub temporal_hint: Option<String>,
    /// MAGMA edge type classification. Defaults to "semantic" when omitted by the LLM.
    #[serde(default = "default_semantic")]
    pub edge_type: String,
    /// Extractor confidence in the relationship, in `[0.0, 1.0]`.
    ///
    /// Assigned by the LLM during extraction. `None` means the LLM omitted the field;
    /// callers should treat `None` as `1.0` (direct statement, commit immediately).
    /// Values below `BeliefMemConfig::promote_threshold` route the edge to
    /// `BeliefStore` for evidence accumulation instead of immediate commit.
    #[serde(default)]
    pub confidence: Option<f32>,
}
99
100pub struct GraphExtractor {
101    provider: AnyProvider,
102    max_entities: usize,
103    max_edges: usize,
104    llm_timeout_secs: u64,
105}
106
107impl GraphExtractor {
108    #[must_use]
109    pub fn new(
110        provider: AnyProvider,
111        max_entities: usize,
112        max_edges: usize,
113        llm_timeout_secs: u64,
114    ) -> Self {
115        Self {
116            provider,
117            max_entities,
118            max_edges,
119            llm_timeout_secs,
120        }
121    }
122
123    /// Extract entities and relations from a message with surrounding context.
124    ///
125    /// Returns `None` if the message is empty, extraction fails, or the LLM returns
126    /// unparseable output. Callers should treat `None` as a graceful degradation.
127    ///
128    /// # Errors
129    ///
130    /// Returns an error only for transport-level failures (network, auth).
131    /// JSON parse failures are logged and return `Ok(None)`.
132    #[tracing::instrument(name = "memory.graph.extract", skip_all, level = "debug", err)]
133    pub async fn extract(
134        &self,
135        message: &str,
136        context_messages: &[&str],
137    ) -> Result<Option<ExtractionResult>, MemoryError> {
138        if message.trim().is_empty() {
139            return Ok(None);
140        }
141
142        let user_prompt = build_user_prompt(message, context_messages);
143        let messages = [
144            Message::from_legacy(Role::System, SYSTEM_PROMPT),
145            Message::from_legacy(Role::User, user_prompt),
146        ];
147
148        match tokio::time::timeout(
149            std::time::Duration::from_secs(self.llm_timeout_secs),
150            self.provider
151                .chat_typed_erased::<ExtractionResult>(&messages),
152        )
153        .await
154        {
155            Err(_elapsed) => {
156                let t = self.llm_timeout_secs;
157                tracing::warn!("graph_extractor: extract LLM call timed out after {t}s");
158                return Ok(None);
159            }
160            Ok(Ok(mut result)) => {
161                result.entities.truncate(self.max_entities);
162                result.edges.truncate(self.max_edges);
163                return Ok(Some(result));
164            }
165            Ok(Err(LlmError::StructuredParse(msg))) => {
166                tracing::warn!(
167                    "graph extraction: LLM returned unparseable output (len={}): {:.200}",
168                    msg.len(),
169                    msg
170                );
171                return Ok(None);
172            }
173            Ok(Err(other)) => return Err(MemoryError::Llm(other)),
174        }
175    }
176}
177
/// Builds the user-turn prompt for one extraction request.
///
/// The current message is always included, followed by the extraction instruction. When
/// `context_messages` is non-empty, those messages are prepended (newline-joined) under a
/// `Context (last N messages):` header.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    let tail = format!(
        "Current message:\n{message}\n\nExtract entities and relationships as JSON."
    );
    match context_messages {
        [] => tail,
        history => {
            let count = history.len();
            let joined = history.join("\n");
            format!("Context (last {count} messages):\n{joined}\n\n{tail}")
        }
    }
}
189
// Unit tests for the extraction types (serde shape, defaults, JSON schema), for prompt
// construction, and for `GraphExtractor::extract` driven by zeph_llm's `MockProvider`.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: builds an `ExtractedEntity` from borrowed strings.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    // Fixture: builds a "semantic" edge with no confidence score.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
            edge_type: "semantic".into(),
            confidence: None,
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    // Optional fields sent as explicit nulls; `edge_type` and `confidence` keys absent.
    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
        // edge_type defaults to "semantic" when omitted
        assert_eq!(result.edges[0].edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        // When LLM omits edge_type, serde(default) must provide "semantic".
        let json = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let edge: ExtractedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.edge_type, "semantic");
    }

    // `edge_type` is a plain String: every value the prompt allows must pass through as-is.
    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for et in &["semantic", "temporal", "causal", "entity"] {
            let json = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{et}"}}"#
            );
            let edge: ExtractedEdge = serde_json::from_str(&json).unwrap();
            assert_eq!(&edge.edge_type, et);
        }
    }

    // Round-trips edges with non-default types and with/without a confidence score.
    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let original = ExtractionResult {
            entities: vec![],
            edges: vec![
                ExtractedEdge {
                    source: "A".into(),
                    target: "B".into(),
                    relation: "caused".into(),
                    fact: "A caused B".into(),
                    temporal_hint: None,
                    edge_type: "causal".into(),
                    confidence: Some(0.9),
                },
                ExtractedEdge {
                    source: "B".into(),
                    target: "C".into(),
                    relation: "preceded_by".into(),
                    fact: "B preceded_by C".into(),
                    temporal_hint: None,
                    edge_type: "temporal".into(),
                    confidence: None,
                },
            ],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
        assert_eq!(restored.edges[0].edge_type, "causal");
        assert_eq!(restored.edges[1].edge_type, "temporal");
    }

    // The Rust field `entity_type` must map to the JSON key "type" in both directions.
    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    // Smoke test only: checks that the generated schema mentions the top-level fields,
    // not its exact structure.
    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    // End-to-end tests of `GraphExtractor::extract` against zeph_llm's `MockProvider`.
    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        // JSON payload with `count` "concept" entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        // JSON payload with `count` edges from "A" to distinct targets and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 5, 100, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 100, 3, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
        }

        // Presumably the mock surfaces non-JSON text as `LlmError::StructuredParse`, which
        // `extract` maps to `Ok(None)` — confirm against MockProvider if this ever flakes.
        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        // Any LLM error other than a parse failure must propagate as `MemoryError::Llm`.
        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        #[test]
        fn graph_extractor_stores_custom_llm_timeout() {
            let extractor = GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(MockProvider::default()),
                10,
                5,
                42,
            );
            assert_eq!(extractor.llm_timeout_secs, 42);
        }

        // Empty and whitespace-only messages short-circuit to `Ok(None)`.
        // NOTE(review): only the return value is checked; the test does not assert that the
        // queued mock response was left unconsumed (i.e. that the LLM was never called).
        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10, 30);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract("   \t\n  ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}