1use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt sent as the `Role::System` message by `GraphExtractor::extract`.
///
/// It lists the extraction rules (what to extract, what to skip, allowed entity
/// types and edge types) and ends with the JSON shape the model is asked to
/// return: an object with an `entities` array and an `edges` array, matching the
/// fields of `ExtractionResult`.
///
/// The text is a single string literal; a trailing `\` joins a source line to
/// the next one, while lines without it contribute a real newline to the prompt.
12const SYSTEM_PROMPT: &str = "\
13You are an entity and relationship extractor. Given a conversation message and \
14its recent context, extract structured knowledge as JSON.
15
16Rules:
171. Extract only entities that appear in natural conversational text — user statements, \
18preferences, opinions, or factual claims made by a person.
192. Do NOT extract entities from: tool outputs, command results, file contents, \
20configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
21If the message is structured data or raw command output, return empty arrays.
223. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
23programming keywords, or single-letter identifiers.
244. Entity types must be one of: person, project, tool, language, organization, concept. \
25\"tool\" covers frameworks, software tools, and libraries. \
26\"language\" covers programming and natural languages. \
27\"concept\" covers abstract ideas, methodologies, and practices.
285. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
296. Entity names must be at least 3 characters long. Reject single characters, two-letter \
30tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
317. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
32\"created\", \"depends_on\", \"replaces\", \"configured_with\".
338. The \"fact\" field is a human-readable sentence summarizing the relationship.
349. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
35temporal_hint like \"replaced X\" or \"since January 2026\".
3610. Each edge must include an \"edge_type\" field classifying the relationship:
37 - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
38 - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
39 - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
40 - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
41 Default to \"semantic\" if the relationship type is uncertain.
4211. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
4312. Do not extract personal identifiable information as entity names: email addresses, \
44phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
4513. Always output entity names and relation verbs in English. Translate if needed.
4614. Return empty arrays if no entities or relationships are present.
47
48Output JSON schema:
49{
50 \"entities\": [
51 {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
52 ],
53 \"edges\": [
54 {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\"}
55 ]
56}";
57
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
59pub struct ExtractionResult {
60 pub entities: Vec<ExtractedEntity>,
61 pub edges: Vec<ExtractedEdge>,
62}
63
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
65pub struct ExtractedEntity {
66 pub name: String,
67 #[serde(rename = "type")]
68 pub entity_type: String,
69 pub summary: Option<String>,
70}
71
/// Serde default for `ExtractedEdge::edge_type`: the `"semantic"` edge kind.
fn default_semantic() -> String {
    String::from("semantic")
}
75
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
77pub struct ExtractedEdge {
78 pub source: String,
79 pub target: String,
80 pub relation: String,
81 pub fact: String,
82 pub temporal_hint: Option<String>,
83 #[serde(default = "default_semantic")]
85 pub edge_type: String,
86}
87
88pub struct GraphExtractor {
89 provider: AnyProvider,
90 max_entities: usize,
91 max_edges: usize,
92}
93
94impl GraphExtractor {
95 #[must_use]
96 pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
97 Self {
98 provider,
99 max_entities,
100 max_edges,
101 }
102 }
103
104 pub async fn extract(
114 &self,
115 message: &str,
116 context_messages: &[&str],
117 ) -> Result<Option<ExtractionResult>, MemoryError> {
118 if message.trim().is_empty() {
119 return Ok(None);
120 }
121
122 let user_prompt = build_user_prompt(message, context_messages);
123 let messages = [
124 Message::from_legacy(Role::System, SYSTEM_PROMPT),
125 Message::from_legacy(Role::User, user_prompt),
126 ];
127
128 match self
129 .provider
130 .chat_typed_erased::<ExtractionResult>(&messages)
131 .await
132 {
133 Ok(mut result) => {
134 result.entities.truncate(self.max_entities);
135 result.edges.truncate(self.max_edges);
136 Ok(Some(result))
137 }
138 Err(LlmError::StructuredParse(msg)) => {
139 tracing::warn!(
140 "graph extraction: LLM returned unparseable output (len={}): {:.200}",
141 msg.len(),
142 msg
143 );
144 Ok(None)
145 }
146 Err(other) => Err(MemoryError::Llm(other)),
147 }
148 }
149}
150
/// Builds the user-role prompt for an extraction request.
///
/// When `context_messages` is non-empty, the prompt opens with a
/// `Context (last N messages):` section containing the messages joined by
/// newlines. The current `message` and the closing instruction always follow.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    let preamble = match context_messages {
        [] => String::new(),
        earlier => format!(
            "Context (last {} messages):\n{}\n\n",
            earlier.len(),
            earlier.join("\n")
        ),
    };
    format!("{preamble}Current message:\n{message}\n\nExtract entities and relationships as JSON.")
}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entity fixture from borrowed strings.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: entity_type.to_owned(),
            summary: summary.map(str::to_owned),
        }
    }

    /// Builds an edge fixture whose `edge_type` is `"semantic"`.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.to_owned(),
            target: target.to_owned(),
            relation: relation.to_owned(),
            fact: fact.to_owned(),
            temporal_hint: temporal_hint.map(str::to_owned),
            edge_type: "semantic".to_owned(),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let payload = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let parsed: ExtractionResult = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.entities.len(), 1);
        let entity = &parsed.entities[0];
        assert_eq!(entity.name, "Rust");
        assert_eq!(entity.entity_type, "language");
        assert_eq!(entity.summary.as_deref(), Some("A systems language"));
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let parsed: ExtractionResult =
            serde_json::from_str(r#"{"entities":[],"edges":[]}"#).unwrap();
        assert!(parsed.entities.is_empty());
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        // Optional fields are given as explicit nulls; the `edge_type` key is absent.
        let payload = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let parsed: ExtractionResult = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.entities[0].summary, None);
        let edge = &parsed.edges[0];
        assert_eq!(edge.temporal_hint, None);
        assert_eq!(edge.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        let payload = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let parsed: ExtractedEdge = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for expected in ["semantic", "temporal", "causal", "entity"] {
            let payload = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{expected}"}}"#
            );
            let parsed: ExtractedEdge = serde_json::from_str(&payload).unwrap();
            assert_eq!(parsed.edge_type, expected);
        }
    }

    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let causal = ExtractedEdge {
            edge_type: "causal".to_owned(),
            ..make_edge("A", "B", "caused", "A caused B", None)
        };
        let temporal = ExtractedEdge {
            edge_type: "temporal".to_owned(),
            ..make_edge("B", "C", "preceded_by", "B preceded_by C", None)
        };
        let original = ExtractionResult {
            entities: Vec::new(),
            edges: vec![causal, temporal],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExtractionResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
        assert_eq!(decoded.edges[0].edge_type, "causal");
        assert_eq!(decoded.edges[1].edge_type, "temporal");
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let parsed: ExtractedEntity =
            serde_json::from_str(r#"{"name":"cargo","type":"tool","summary":null}"#).unwrap();
        assert_eq!(parsed.entity_type, "tool");

        // The Rust field name must not leak into the serialized form.
        let encoded = serde_json::to_string(&parsed).unwrap();
        assert!(encoded.contains("\"type\""));
        assert!(!encoded.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExtractionResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);

        let as_value = serde_json::to_value(&schema).unwrap();
        let top_level = as_value.as_object().unwrap();
        assert!(
            top_level.contains_key("title") || top_level.contains_key("properties"),
            "schema should have top-level keys"
        );

        let as_text = serde_json::to_string(&schema).unwrap();
        assert!(
            as_text.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(as_text.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        for needle in [
            "Context (last 2 messages):",
            "prev message 1\nprev message 2",
            "Current message:\nHello Rust",
            "Extract entities and relationships as JSON.",
        ] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// Wraps a mock provider in an extractor with the given limits.
        fn extractor_with(
            mock: MockProvider,
            max_entities: usize,
            max_edges: usize,
        ) -> GraphExtractor {
            GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(mock),
                max_entities,
                max_edges,
            )
        }

        /// JSON reply containing `count` entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let items = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[{items}],"edges":[]}}"#)
        }

        /// JSON reply containing `count` edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let items = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[],"edges":[{items}]}}"#)
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let mock = MockProvider::with_responses(vec![make_entities_json(20)]);
            let extractor = extractor_with(mock, 5, 100);

            let extracted = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let mock = MockProvider::with_responses(vec![make_edges_json(15)]);
            let extractor = extractor_with(mock, 100, 3);

            let extracted = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extracted.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = extractor_with(mock, 10, 10);

            let outcome = extractor.extract("test message", &[]).await.unwrap();
            assert!(outcome.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = extractor_with(mock, 10, 10);

            let outcome = extractor.extract("test message", &[]).await;
            assert!(outcome.is_err());
            assert!(matches!(outcome, Err(MemoryError::Llm(_))));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = extractor_with(mock, 10, 10);

            let for_empty = extractor.extract("", &[]).await.unwrap();
            assert!(for_empty.is_none());

            let for_whitespace = extractor.extract(" \t\n ", &[]).await.unwrap();
            assert!(for_whitespace.is_none());
        }
    }
}