Skip to main content

zeph_memory/graph/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt sent as the `Role::System` message of every extraction request.
///
/// It lists the extraction rules (numbered 1 through 15, each number used exactly
/// once) followed by the JSON shape the model is asked to produce. The field names
/// in that shape match the serialized field names of `ExtractionResult`,
/// `ExtractedEntity` and `ExtractedEdge`.
///
/// Lines ending in `\` are string continuations: the newline and the next line's
/// leading whitespace are not part of the prompt text.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, tool, language, organization, concept. \
\"tool\" covers frameworks, software tools, and libraries. \
\"language\" covers programming and natural languages. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Each edge must include an \"edge_type\" field classifying the relationship:
  - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
  - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
  - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
  - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
  Default to \"semantic\" if the relationship type is uncertain.
11. Each edge must include a \"confidence\" field: a float in [0.0, 1.0] reflecting how \
certain you are that this relationship was explicitly stated or strongly implied by the text. \
Use 1.0 only for direct verbatim statements. Use 0.5–0.8 for clear implications. \
Use 0.3–0.5 for weak inferences. Omit or use null if you cannot assign a meaningful score.
12. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
13. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
14. Always output entity names and relation verbs in English. Translate if needed.
15. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\", \"confidence\": 0.0}
  ]
}";
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
63pub struct ExtractionResult {
64    pub entities: Vec<ExtractedEntity>,
65    pub edges: Vec<ExtractedEdge>,
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
69pub struct ExtractedEntity {
70    pub name: String,
71    #[serde(rename = "type")]
72    pub entity_type: String,
73    pub summary: Option<String>,
74}
75
/// Serde default for `ExtractedEdge::edge_type`: the owned string `"semantic"`.
fn default_semantic() -> String {
    String::from("semantic")
}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
81pub struct ExtractedEdge {
82    pub source: String,
83    pub target: String,
84    pub relation: String,
85    pub fact: String,
86    pub temporal_hint: Option<String>,
87    /// MAGMA edge type classification. Defaults to "semantic" when omitted by the LLM.
88    #[serde(default = "default_semantic")]
89    pub edge_type: String,
90    /// Extractor confidence in the relationship, in `[0.0, 1.0]`.
91    ///
92    /// Assigned by the LLM during extraction. `None` means the LLM omitted the field;
93    /// callers should treat `None` as `1.0` (direct statement, commit immediately).
94    /// Values below `BeliefMemConfig::promote_threshold` route the edge to
95    /// `BeliefStore` for evidence accumulation instead of immediate commit.
96    #[serde(default)]
97    pub confidence: Option<f32>,
98}
99
100pub struct GraphExtractor {
101    provider: AnyProvider,
102    max_entities: usize,
103    max_edges: usize,
104}
105
106impl GraphExtractor {
107    #[must_use]
108    pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
109        Self {
110            provider,
111            max_entities,
112            max_edges,
113        }
114    }
115
116    /// Extract entities and relations from a message with surrounding context.
117    ///
118    /// Returns `None` if the message is empty, extraction fails, or the LLM returns
119    /// unparseable output. Callers should treat `None` as a graceful degradation.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error only for transport-level failures (network, auth).
124    /// JSON parse failures are logged and return `Ok(None)`.
125    pub async fn extract(
126        &self,
127        message: &str,
128        context_messages: &[&str],
129    ) -> Result<Option<ExtractionResult>, MemoryError> {
130        if message.trim().is_empty() {
131            return Ok(None);
132        }
133
134        let user_prompt = build_user_prompt(message, context_messages);
135        let messages = [
136            Message::from_legacy(Role::System, SYSTEM_PROMPT),
137            Message::from_legacy(Role::User, user_prompt),
138        ];
139
140        match self
141            .provider
142            .chat_typed_erased::<ExtractionResult>(&messages)
143            .await
144        {
145            Ok(mut result) => {
146                result.entities.truncate(self.max_entities);
147                result.edges.truncate(self.max_edges);
148                Ok(Some(result))
149            }
150            Err(LlmError::StructuredParse(msg)) => {
151                tracing::warn!(
152                    "graph extraction: LLM returned unparseable output (len={}): {:.200}",
153                    msg.len(),
154                    msg
155                );
156                Ok(None)
157            }
158            Err(other) => Err(MemoryError::Llm(other)),
159        }
160    }
161}
162
/// Builds the user-role prompt for one extraction request.
///
/// With no context messages the prompt contains only the current message and the
/// closing instruction. Otherwise it is prefixed with a `Context (last N messages):`
/// section in which the context messages are joined by single newlines, in the
/// order given.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    match context_messages {
        [] => format!("Current message:\n{message}\n\nExtract entities and relationships as JSON."),
        history => format!(
            "Context (last {n} messages):\n{context}\n\nCurrent message:\n{message}\n\nExtract entities and relationships as JSON.",
            n = history.len(),
            context = history.join("\n"),
        ),
    }
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExtractedEntity` from borrowed parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: entity_type.to_owned(),
            summary: summary.map(str::to_owned),
        }
    }

    /// Builds a `"semantic"` edge with no confidence from borrowed parts.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.to_owned(),
            target: target.to_owned(),
            relation: relation.to_owned(),
            fact: fact.to_owned(),
            temporal_hint: temporal_hint.map(str::to_owned),
            edge_type: "semantic".to_owned(),
            confidence: None,
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let parsed: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.entities.len(), 1);
        let entity = &parsed.entities[0];
        assert_eq!(entity.name, "Rust");
        assert_eq!(entity.entity_type, "language");
        assert_eq!(entity.summary.as_deref(), Some("A systems language"));
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let parsed: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(parsed.entities.is_empty());
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let parsed: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.entities[0].summary, None);
        let edge = &parsed.edges[0];
        assert_eq!(edge.temporal_hint, None);
        // The payload has no edge_type, so the serde default must kick in.
        assert_eq!(edge.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        // An edge without edge_type must still parse, falling back to "semantic".
        let json = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let parsed: ExtractedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for et in ["semantic", "temporal", "causal", "entity"] {
            let json = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{et}"}}"#
            );
            let parsed: ExtractedEdge = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed.edge_type, et);
        }
    }

    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let causal = ExtractedEdge {
            source: "A".into(),
            target: "B".into(),
            relation: "caused".into(),
            fact: "A caused B".into(),
            temporal_hint: None,
            edge_type: "causal".into(),
            confidence: Some(0.9),
        };
        let temporal = ExtractedEdge {
            source: "B".into(),
            target: "C".into(),
            relation: "preceded_by".into(),
            fact: "B preceded_by C".into(),
            temporal_hint: None,
            edge_type: "temporal".into(),
            confidence: None,
        };
        let original = ExtractionResult {
            entities: Vec::new(),
            edges: vec![causal, temporal],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExtractionResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
        assert_eq!(decoded.edges[0].edge_type, "causal");
        assert_eq!(decoded.edges[1].edge_type, "temporal");
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let parsed: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.entity_type, "tool");

        // Serialization must use the renamed key, never the Rust field name.
        let encoded = serde_json::to_string(&parsed).unwrap();
        assert!(encoded.contains("\"type\""));
        assert!(!encoded.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExtractionResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);

        let as_value = serde_json::to_value(&schema).unwrap();
        let top_level = as_value.as_object().unwrap();
        assert!(
            top_level.contains_key("title") || top_level.contains_key("properties"),
            "schema should have top-level keys"
        );

        let as_text = serde_json::to_string(&schema).unwrap();
        assert!(
            as_text.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(as_text.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    /// Tests that drive `GraphExtractor::extract` against a mock provider.
    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// Wraps `mock` in a `GraphExtractor` with the given caps.
        fn extractor_with(
            mock: MockProvider,
            max_entities: usize,
            max_edges: usize,
        ) -> GraphExtractor {
            GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(mock),
                max_entities,
                max_edges,
            )
        }

        /// JSON payload with `count` concept entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let body = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect::<Vec<String>>()
                .join(",");
            format!(r#"{{"entities":[{}],"edges":[]}}"#, body)
        }

        /// JSON payload with `count` edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let body = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect::<Vec<String>>()
                .join(",");
            format!(r#"{{"entities":[],"edges":[{}]}}"#, body)
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let mock = MockProvider::with_responses(vec![make_entities_json(20)]);
            let extractor = extractor_with(mock, 5, 100);
            let extracted = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let mock = MockProvider::with_responses(vec![make_edges_json(15)]);
            let extractor = extractor_with(mock, 100, 3);
            let extracted = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = extractor_with(mock, 10, 10);
            let outcome = extractor.extract("test message", &[]).await.unwrap();
            assert!(outcome.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = extractor_with(mock, 10, 10);
            let outcome = extractor.extract("test message", &[]).await;
            assert!(outcome.is_err());
            assert!(matches!(outcome, Err(MemoryError::Llm(_))));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = extractor_with(mock, 10, 10);

            let for_empty = extractor.extract("", &[]).await.unwrap();
            assert!(for_empty.is_none());

            let for_whitespace = extractor.extract("   \t\n  ", &[]).await.unwrap();
            assert!(for_whitespace.is_none());
        }
    }
}