Skip to main content

zeph_memory/graph/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt sent with every extraction request.
///
/// It lists the extraction rules (what to extract, what to skip, the allowed entity and
/// edge types) and ends with the JSON shape the model is asked to produce; that shape
/// mirrors the fields of `ExtractionResult`, `ExtractedEntity` and `ExtractedEdge`.
///
/// A trailing `\` inside the literal is a line continuation: the newline and the leading
/// whitespace of the following source line are not part of the prompt text.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, tool, language, organization, concept. \
\"tool\" covers frameworks, software tools, and libraries. \
\"language\" covers programming and natural languages. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Each edge must include an \"edge_type\" field classifying the relationship:
  - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
  - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
  - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
  - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
  Default to \"semantic\" if the relationship type is uncertain.
11. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
12. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
13. Always output entity names and relation verbs in English. Translate if needed.
14. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\"}
  ]
}";
57
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
59pub struct ExtractionResult {
60    pub entities: Vec<ExtractedEntity>,
61    pub edges: Vec<ExtractedEdge>,
62}
63
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
65pub struct ExtractedEntity {
66    pub name: String,
67    #[serde(rename = "type")]
68    pub entity_type: String,
69    pub summary: Option<String>,
70}
71
/// Serde default for `ExtractedEdge::edge_type`: yields `"semantic"` when the key is
/// absent from the deserialized JSON.
fn default_semantic() -> String {
    "semantic".to_owned()
}
75
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
77pub struct ExtractedEdge {
78    pub source: String,
79    pub target: String,
80    pub relation: String,
81    pub fact: String,
82    pub temporal_hint: Option<String>,
83    /// MAGMA edge type classification. Defaults to "semantic" when omitted by the LLM.
84    #[serde(default = "default_semantic")]
85    pub edge_type: String,
86}
87
88pub struct GraphExtractor {
89    provider: AnyProvider,
90    max_entities: usize,
91    max_edges: usize,
92}
93
94impl GraphExtractor {
95    #[must_use]
96    pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
97        Self {
98            provider,
99            max_entities,
100            max_edges,
101        }
102    }
103
104    /// Extract entities and relations from a message with surrounding context.
105    ///
106    /// Returns `None` if the message is empty, extraction fails, or the LLM returns
107    /// unparseable output. Callers should treat `None` as a graceful degradation.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error only for transport-level failures (network, auth).
112    /// JSON parse failures are logged and return `Ok(None)`.
113    pub async fn extract(
114        &self,
115        message: &str,
116        context_messages: &[&str],
117    ) -> Result<Option<ExtractionResult>, MemoryError> {
118        if message.trim().is_empty() {
119            return Ok(None);
120        }
121
122        let user_prompt = build_user_prompt(message, context_messages);
123        let messages = [
124            Message::from_legacy(Role::System, SYSTEM_PROMPT),
125            Message::from_legacy(Role::User, user_prompt),
126        ];
127
128        match self
129            .provider
130            .chat_typed_erased::<ExtractionResult>(&messages)
131            .await
132        {
133            Ok(mut result) => {
134                result.entities.truncate(self.max_entities);
135                result.edges.truncate(self.max_edges);
136                Ok(Some(result))
137            }
138            Err(LlmError::StructuredParse(msg)) => {
139                tracing::warn!(
140                    "graph extraction: LLM returned unparseable output (len={}): {:.200}",
141                    msg.len(),
142                    msg
143                );
144                Ok(None)
145            }
146            Err(other) => Err(MemoryError::Llm(other)),
147        }
148    }
149}
150
/// Builds the user-role prompt for one extraction request.
///
/// With no context the prompt contains only the current message followed by the
/// extraction instruction. Otherwise a `Context (last N message[s]):` section comes
/// first, listing `context_messages` in the given order, one per line.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    if context_messages.is_empty() {
        return format!(
            "Current message:\n{message}\n\nExtract entities and relationships as JSON."
        );
    }

    let n = context_messages.len();
    // Avoid the ungrammatical "last 1 messages" when a single context entry is supplied.
    let noun = if n == 1 { "message" } else { "messages" };
    let context = context_messages.join("\n");
    format!(
        "Context (last {n} {noun}):\n{context}\n\nCurrent message:\n{message}\n\nExtract entities and relationships as JSON."
    )
}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`ExtractedEntity`] from borrowed parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    /// Builds an [`ExtractedEdge`] with `edge_type` fixed to `"semantic"`.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
            edge_type: "semantic".into(),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        // Optional fields sent as explicit nulls.
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
        // edge_type defaults to "semantic" when omitted
        assert_eq!(result.edges[0].edge_type, "semantic");

        // Optional fields left out entirely must deserialize the same way.
        let json = r#"{"entities":[{"name":"Alice","type":"person"}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust"}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
        assert_eq!(result.edges[0].edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        // When LLM omits edge_type, serde(default) must provide "semantic".
        let json = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let edge: ExtractedEdge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for et in &["semantic", "temporal", "causal", "entity"] {
            let json = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{et}"}}"#
            );
            let edge: ExtractedEdge = serde_json::from_str(&json).unwrap();
            assert_eq!(&edge.edge_type, et);
        }
    }

    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let original = ExtractionResult {
            entities: vec![],
            edges: vec![
                ExtractedEdge {
                    source: "A".into(),
                    target: "B".into(),
                    relation: "caused".into(),
                    fact: "A caused B".into(),
                    temporal_hint: None,
                    edge_type: "causal".into(),
                },
                ExtractedEdge {
                    source: "B".into(),
                    target: "C".into(),
                    relation: "preceded_by".into(),
                    fact: "B preceded_by C".into(),
                    temporal_hint: None,
                    edge_type: "temporal".into(),
                },
            ],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
        assert_eq!(restored.edges[0].edge_type, "causal");
        assert_eq!(restored.edges[1].edge_type, "temporal");
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        // The Rust field name must never leak into the wire format.
        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    /// Tests that drive `GraphExtractor::extract` through the mock provider.
    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// JSON for an extraction result holding `count` entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        /// JSON for an extraction result holding `count` edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 5, 100);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 100, 3);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            // The queued responses are *valid* extraction JSON on purpose: if the
            // blank-message guard were missing and the provider got called, `extract`
            // would return `Some(..)` and the assertions below would fail. An
            // unparseable canned response would also map to `None` and hide that.
            let mock =
                MockProvider::with_responses(vec![make_entities_json(1), make_entities_json(1)]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract("   \t\n  ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}