1use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt for the entity/relationship extraction call.
///
/// It lists the rules the model must follow when choosing entities and relations,
/// and spells out the JSON shape that is later deserialized into
/// [`ExtractionResult`].
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, technology, organization, concept. \
\"technology\" covers programming languages, frameworks, tools, and software. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
11. Do not extract personally identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
12. Always output entity names and relation verbs in English. Translate if needed.
13. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
 \"entities\": [
 {\"name\": \"string\", \"type\": \"person|project|technology|organization|concept\", \"summary\": \"optional string\"}
 ],
 \"edges\": [
 {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\"}
 ]
}";
50
51#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
52pub struct ExtractionResult {
53 pub entities: Vec<ExtractedEntity>,
54 pub edges: Vec<ExtractedEdge>,
55}
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
58pub struct ExtractedEntity {
59 pub name: String,
60 #[serde(rename = "type")]
61 pub entity_type: String,
62 pub summary: Option<String>,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
66pub struct ExtractedEdge {
67 pub source: String,
68 pub target: String,
69 pub relation: String,
70 pub fact: String,
71 pub temporal_hint: Option<String>,
72}
73
74pub struct GraphExtractor {
75 provider: AnyProvider,
76 max_entities: usize,
77 max_edges: usize,
78}
79
80impl GraphExtractor {
81 #[must_use]
82 pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
83 Self {
84 provider,
85 max_entities,
86 max_edges,
87 }
88 }
89
90 pub async fn extract(
100 &self,
101 message: &str,
102 context_messages: &[&str],
103 ) -> Result<Option<ExtractionResult>, MemoryError> {
104 if message.trim().is_empty() {
105 return Ok(None);
106 }
107
108 let user_prompt = build_user_prompt(message, context_messages);
109 let messages = [
110 Message::from_legacy(Role::System, SYSTEM_PROMPT),
111 Message::from_legacy(Role::User, user_prompt),
112 ];
113
114 match self
115 .provider
116 .chat_typed_erased::<ExtractionResult>(&messages)
117 .await
118 {
119 Ok(mut result) => {
120 result.entities.truncate(self.max_entities);
121 result.edges.truncate(self.max_edges);
122 Ok(Some(result))
123 }
124 Err(LlmError::StructuredParse(msg)) => {
125 tracing::warn!(
126 "graph extraction: LLM returned unparseable output (len={}): {:.200}",
127 msg.len(),
128 msg
129 );
130 Ok(None)
131 }
132 Err(other) => Err(MemoryError::Llm(other)),
133 }
134 }
135}
136
/// Builds the user-turn prompt for the extraction call.
///
/// When `context_messages` is non-empty, they are listed newline-separated under a
/// `Context (last N message(s)):` header ahead of the current message. Otherwise
/// only the current message is included. Both forms end with the same extraction
/// instruction.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    const INSTRUCTION: &str = "Extract entities and relationships as JSON.";

    if context_messages.is_empty() {
        return format!("Current message:\n{message}\n\n{INSTRUCTION}");
    }

    let n = context_messages.len();
    // Avoid the ungrammatical "last 1 messages" header for a single context entry.
    let noun = if n == 1 { "message" } else { "messages" };
    let context = context_messages.join("\n");
    format!(
        "Context (last {n} {noun}):\n{context}\n\nCurrent message:\n{message}\n\n{INSTRUCTION}"
    )
}
148
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an `ExtractedEntity` from borrowed parts.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: entity_type.to_owned(),
            summary: summary.map(str::to_owned),
        }
    }

    // Builds an `ExtractedEdge` from borrowed parts.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.to_owned(),
            target: target.to_owned(),
            relation: relation.to_owned(),
            fact: fact.to_owned(),
            temporal_hint: temporal_hint.map(str::to_owned),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let raw = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let parsed: ExtractionResult = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.entities.len(), 1);
        let entity = &parsed.entities[0];
        assert_eq!(entity.name, "Rust");
        assert_eq!(entity.entity_type, "language");
        assert_eq!(entity.summary.as_deref(), Some("A systems language"));
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let parsed: ExtractionResult =
            serde_json::from_str(r#"{"entities":[],"edges":[]}"#).unwrap();
        assert!(parsed.entities.is_empty());
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        let raw = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let parsed: ExtractionResult = serde_json::from_str(raw).unwrap();
        assert!(parsed.entities[0].summary.is_none());
        assert!(parsed.edges[0].temporal_hint.is_none());
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let entity: ExtractedEntity =
            serde_json::from_str(r#"{"name":"cargo","type":"tool","summary":null}"#).unwrap();
        assert_eq!(entity.entity_type, "tool");

        // Serialization must use the renamed key, never the Rust field name.
        let encoded = serde_json::to_string(&entity).unwrap();
        assert!(encoded.contains("\"type\""));
        assert!(!encoded.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExtractionResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);

        let as_value = serde_json::to_value(&schema).unwrap();
        let top_level = as_value.as_object().unwrap();
        assert!(
            top_level.contains_key("title") || top_level.contains_key("properties"),
            "schema should have top-level keys"
        );

        let rendered = serde_json::to_string(&schema).unwrap();
        assert!(rendered.contains("entities"), "schema should contain 'entities'");
        assert!(rendered.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        for expected in [
            "Context (last 2 messages):",
            "prev message 1\nprev message 2",
            "Current message:\nHello Rust",
            "Extract entities and relationships as JSON.",
        ] {
            assert!(prompt.contains(expected), "prompt missing {expected:?}: {prompt}");
        }
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        for expected in [
            "Current message:\nHello Rust",
            "Extract entities and relationships as JSON.",
        ] {
            assert!(prompt.contains(expected), "prompt missing {expected:?}: {prompt}");
        }
    }

    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        // JSON payload with `count` concept entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let body = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[{body}],"edges":[]}}"#)
        }

        // JSON payload with `count` edges from `A` and no entities.
        fn make_edges_json(count: usize) -> String {
            let body = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[],"edges":[{body}]}}"#)
        }

        // Wraps `mock` in the extractor under test with the given limits.
        fn extractor_with(
            mock: MockProvider,
            max_entities: usize,
            max_edges: usize,
        ) -> GraphExtractor {
            GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(mock),
                max_entities,
                max_edges,
            )
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let mock = MockProvider::with_responses(vec![make_entities_json(20)]);
            let extracted = extractor_with(mock, 5, 100)
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let mock = MockProvider::with_responses(vec![make_edges_json(15)]);
            let extracted = extractor_with(mock, 100, 3)
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let outcome = extractor_with(mock, 10, 10)
                .extract("test message", &[])
                .await
                .unwrap();
            assert!(outcome.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let outcome = extractor_with(mock, 10, 10)
                .extract("test message", &[])
                .await;
            assert!(matches!(outcome, Err(MemoryError::Llm(_))));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = extractor_with(mock, 10, 10);

            for blank in ["", " \t\n "] {
                let outcome = extractor.extract(blank, &[]).await.unwrap();
                assert!(outcome.is_none(), "expected None for {blank:?}");
            }
        }
    }
}