Skip to main content

zeph_memory/graph/
extractor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System-turn prompt sent by `GraphExtractor::extract` ahead of the user prompt.
///
/// A line ending in `\` is a Rust string continuation: the newline (and any leading
/// whitespace on the next line) is dropped. Each numbered rule therefore reaches the
/// model as a single line. The JSON block at the end mirrors the field names of
/// `ExtractionResult`, `ExtractedEntity` (with `type` for `entity_type`) and
/// `ExtractedEdge`.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, technology, organization, concept. \
\"technology\" covers programming languages, frameworks, tools, and software. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
11. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
12. Always output entity names and relation verbs in English. Translate if needed.
13. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
  \"entities\": [
    {\"name\": \"string\", \"type\": \"person|project|technology|organization|concept\", \"summary\": \"optional string\"}
  ],
  \"edges\": [
    {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\"}
  ]
}";
50
// Structured output requested from the LLM: the entities and edges found in one message.
//
// NOTE: these are plain `//` comments on purpose. schemars copies `///` doc comments
// into the derived JSON schema as `description` fields. That would alter the schema
// that `chat_typed_erased::<ExtractionResult>` presumably derives from this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractionResult {
    // Extracted entities; `GraphExtractor::extract` truncates this to `max_entities`.
    pub entities: Vec<ExtractedEntity>,
    // Extracted relationships; `GraphExtractor::extract` truncates this to `max_edges`.
    pub edges: Vec<ExtractedEdge>,
}
56
// A single entity reported by the LLM.
//
// NOTE: plain `//` comments on purpose. schemars turns `///` doc comments into
// schema `description` fields, which would change the derived JSON schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEntity {
    // Entity name as returned by the model. The system prompt asks for English names of
    // at least 3 characters, but nothing in this type enforces that.
    pub name: String,
    // Serialized under the JSON key `"type"` (a Rust keyword, hence the rename). The
    // prompt limits it to person|project|technology|organization|concept, but any
    // string deserializes.
    #[serde(rename = "type")]
    pub entity_type: String,
    // Optional free-text description of the entity.
    pub summary: Option<String>,
}
64
// A directed relationship between two entities, as reported by the LLM.
//
// NOTE: plain `//` comments on purpose. schemars turns `///` doc comments into
// schema `description` fields, which would change the derived JSON schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEdge {
    // Name of the entity the relation starts from. It is not checked against the
    // `entities` list; it may even name an entity dropped by truncation.
    pub source: String,
    // Name of the entity the relation points to (same caveat as `source`).
    pub target: String,
    // Short verb phrase such as "uses" or "depends_on" (per the system prompt).
    pub relation: String,
    // Human-readable sentence summarizing the relationship.
    pub fact: String,
    // Optional temporal qualifier, e.g. "replaced X" or "since January 2026".
    pub temporal_hint: Option<String>,
}
73
74pub struct GraphExtractor {
75    provider: AnyProvider,
76    max_entities: usize,
77    max_edges: usize,
78}
79
80impl GraphExtractor {
81    #[must_use]
82    pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
83        Self {
84            provider,
85            max_entities,
86            max_edges,
87        }
88    }
89
90    /// Extract entities and relations from a message with surrounding context.
91    ///
92    /// Returns `None` if the message is empty, extraction fails, or the LLM returns
93    /// unparseable output. Callers should treat `None` as a graceful degradation.
94    ///
95    /// # Errors
96    ///
97    /// Returns an error only for transport-level failures (network, auth).
98    /// JSON parse failures are logged and return `Ok(None)`.
99    pub async fn extract(
100        &self,
101        message: &str,
102        context_messages: &[&str],
103    ) -> Result<Option<ExtractionResult>, MemoryError> {
104        if message.trim().is_empty() {
105            return Ok(None);
106        }
107
108        let user_prompt = build_user_prompt(message, context_messages);
109        let messages = [
110            Message::from_legacy(Role::System, SYSTEM_PROMPT),
111            Message::from_legacy(Role::User, user_prompt),
112        ];
113
114        match self
115            .provider
116            .chat_typed_erased::<ExtractionResult>(&messages)
117            .await
118        {
119            Ok(mut result) => {
120                result.entities.truncate(self.max_entities);
121                result.edges.truncate(self.max_edges);
122                Ok(Some(result))
123            }
124            Err(LlmError::StructuredParse(msg)) => {
125                tracing::warn!(
126                    "graph extraction: LLM returned unparseable output (len={}): {:.200}",
127                    msg.len(),
128                    msg
129                );
130                Ok(None)
131            }
132            Err(other) => Err(MemoryError::Llm(other)),
133        }
134    }
135}
136
/// Builds the user-turn prompt for extraction.
///
/// The prompt always ends with the current `message` followed by the JSON
/// instruction. When `context_messages` is non-empty, they are prepended
/// newline-joined under a `Context (last N messages):` header, where `N` is the
/// number of context entries.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    let tail =
        format!("Current message:\n{message}\n\nExtract entities and relationships as JSON.");
    if context_messages.is_empty() {
        return tail;
    }

    let count = context_messages.len();
    let joined = context_messages.join("\n");
    format!("Context (last {count} messages):\n{joined}\n\n{tail}")
}
148
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an `ExtractedEntity` from borrowed parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    // Builds an `ExtractedEdge` from borrowed parts.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        // Optional fields supplied as explicit `null`.
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);

        // Optional keys omitted entirely: serde's derive treats an absent
        // `Option` field as `None`, which is what this test's name promises.
        let json = r#"{"entities":[{"name":"Alice","type":"person"}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust"}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        // JSON response containing `count` distinct concept entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        // JSON response containing `count` distinct edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 5, 100);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 100, 3);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            // Prime the mock with a transport error rather than a text reply. If
            // `extract` ever reached the provider for blank input, the call would
            // surface as `Err` and the `unwrap()`s below would panic. An unparseable
            // text reply would NOT catch that regression, because it also yields
            // `Ok(None)`.
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("must not be called".into())]);
            let extractor = GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), 10, 10);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract("   \t\n  ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}