Skip to main content

zeph_memory/graph/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System-role prompt used for graph extraction.
///
/// It tells the model to return entities and edges as JSON, lists the allowed
/// entity types and example relation verbs, and spells out the expected output
/// shape. `GraphExtractor::extract` sends it as the `Role::System` message
/// ahead of the user prompt.
///
/// The text is part of the runtime behaviour: any edit changes what the model
/// is asked to do.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract entities mentioned or implied in the current message.
2. Extract relationships between entities.
3. Entity types must be one of: person, tool, concept, project, language, file, config, organization.
4. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \"created\", \"depends_on\", \"replaces\", \"configured_with\".
5. The \"fact\" field is a human-readable sentence summarizing the relationship.
6. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a temporal_hint like \"replaced X\" or \"since January 2026\".
7. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
8. Do not extract personal identifiable information as entity names: email addresses, phone numbers, physical addresses, SSNs, or API keys. Use generic references instead (e.g., \"User\" instead of \"Alice Smith\").
9. Always output entity names and relation verbs in English, even if the conversation is in another language. Translate entity names if needed.
10. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|tool|concept|project|language|file|config|organization\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\"}
  ]
}";
37
// Top-level payload the LLM is asked to produce: the entities found in a
// message plus the relationships ("edges") between them.
//
// NOTE: plain `//` comments are used here on purpose. schemars copies `///`
// doc comments into the generated JSON schema as `description`, so doc
// comments would change the schema derived for this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractionResult {
    // Extracted entities; `GraphExtractor::extract` truncates this to `max_entities`.
    pub entities: Vec<ExtractedEntity>,
    // Extracted relationships; `GraphExtractor::extract` truncates this to `max_edges`.
    pub edges: Vec<ExtractedEdge>,
}
43
// One entity reported by the extractor.
//
// NOTE: plain `//` comments are used here on purpose. schemars copies `///`
// doc comments into the generated JSON schema as `description`, so doc
// comments would change the schema derived for this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEntity {
    // Entity name as returned by the model.
    pub name: String,
    // Entity category. Serialized as `type` in JSON; the Rust field is named
    // `entity_type` because `type` is a reserved keyword. The system prompt
    // lists the allowed values, but this struct accepts any string.
    #[serde(rename = "type")]
    pub entity_type: String,
    // Optional short description of the entity; `None` when the model sends `null`.
    pub summary: Option<String>,
}
51
// One directed relationship between two entities reported by the extractor.
//
// NOTE: plain `//` comments are used here on purpose. schemars copies `///`
// doc comments into the generated JSON schema as `description`, so doc
// comments would change the schema derived for this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEdge {
    // Name of the entity the relationship starts from.
    pub source: String,
    // Name of the entity the relationship points to.
    pub target: String,
    // Short verb phrase naming the relationship (the prompt suggests e.g. "uses").
    pub relation: String,
    // Human-readable sentence summarizing the relationship.
    pub fact: String,
    // Optional free-text note about a temporal change; `None` when the model sends `null`.
    pub temporal_hint: Option<String>,
}
60
/// Extracts entities and relationships from conversation messages by asking
/// an LLM provider for structured JSON output.
pub struct GraphExtractor {
    /// LLM backend that runs the typed chat request.
    provider: AnyProvider,
    /// Upper bound on the number of entities kept from a single extraction.
    max_entities: usize,
    /// Upper bound on the number of edges kept from a single extraction.
    max_edges: usize,
}
66
67impl GraphExtractor {
68    #[must_use]
69    pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
70        Self {
71            provider,
72            max_entities,
73            max_edges,
74        }
75    }
76
77    /// Extract entities and relations from a message with surrounding context.
78    ///
79    /// Returns `None` if the message is empty, extraction fails, or the LLM returns
80    /// unparseable output. Callers should treat `None` as a graceful degradation.
81    ///
82    /// # Errors
83    ///
84    /// Returns an error only for transport-level failures (network, auth).
85    /// JSON parse failures are logged and return `Ok(None)`.
86    pub async fn extract(
87        &self,
88        message: &str,
89        context_messages: &[&str],
90    ) -> Result<Option<ExtractionResult>, MemoryError> {
91        if message.trim().is_empty() {
92            return Ok(None);
93        }
94
95        let user_prompt = build_user_prompt(message, context_messages);
96        let messages = [
97            Message::from_legacy(Role::System, SYSTEM_PROMPT),
98            Message::from_legacy(Role::User, user_prompt),
99        ];
100
101        match self
102            .provider
103            .chat_typed_erased::<ExtractionResult>(&messages)
104            .await
105        {
106            Ok(mut result) => {
107                result.entities.truncate(self.max_entities);
108                result.edges.truncate(self.max_edges);
109                Ok(Some(result))
110            }
111            Err(LlmError::StructuredParse(msg)) => {
112                tracing::warn!(
113                    "graph extraction: LLM returned unparseable output (len={}): {:.200}",
114                    msg.len(),
115                    msg
116                );
117                Ok(None)
118            }
119            Err(other) => Err(MemoryError::Llm(other)),
120        }
121    }
122}
123
/// Assembles the user-role prompt for one extraction request.
///
/// With no context messages the prompt contains only the current message
/// followed by the extraction instruction. Otherwise the context messages are
/// joined with newlines and placed, together with their count, ahead of the
/// current message.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    match context_messages {
        [] => {
            format!("Current message:\n{message}\n\nExtract entities and relationships as JSON.")
        }
        history => {
            let n = history.len();
            let context = history.join("\n");
            format!(
                "Context (last {n} messages):\n{context}\n\nCurrent message:\n{message}\n\nExtract entities and relationships as JSON."
            )
        }
    }
}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`ExtractedEntity`] from borrowed parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    /// Builds an [`ExtractedEdge`] from borrowed parts.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        // Optional fields sent as explicit `null` map to `None`.
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);

        // Optional fields that are absent altogether must behave the same way;
        // models frequently omit optional keys instead of emitting `null`.
        let json = r#"{"entities":[{"name":"Alice","type":"person"}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust"}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].name, "Alice");
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].fact, "Alice uses Rust");
        assert_eq!(result.edges[0].temporal_hint, None);
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        // The Rust field name must never leak into the wire format.
        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    /// Tests that drive [`GraphExtractor::extract`] through the mock provider.
    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// Renders an extraction payload with `count` entities named
        /// `entity0`, `entity1`, ... and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        /// Renders an extraction payload with `count` edges from `A` to
        /// `B0`, `B1`, ... and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 5, 100);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
            // Truncation must keep the leading entries in their original order.
            assert_eq!(result.entities[0].name, "entity0");
            assert_eq!(result.entities[4].name, "entity4");
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 100, 3);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
            // Truncation must keep the leading entries in their original order.
            assert_eq!(result.edges[0].target, "B0");
            assert_eq!(result.edges[2].target, "B2");
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract("   \t\n  ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}