1use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System prompt sent with every extraction request.
///
/// It lists the extraction rules (numbered 1 through 15, each number used exactly once so a
/// rule can be referenced unambiguously) followed by the JSON shape the model must produce.
///
/// A trailing `\` inside the literal joins the following source line without inserting a
/// newline; lines without it end with a real newline in the prompt.
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract only entities that appear in natural conversational text — user statements, \
preferences, opinions, or factual claims made by a person.
2. Do NOT extract entities from: tool outputs, command results, file contents, \
configuration files, JSON/TOML/YAML data, code snippets, or error messages. \
If the message is structured data or raw command output, return empty arrays.
3. Do NOT extract structural data: config keys, file paths, tool names, TOML/JSON keys, \
programming keywords, or single-letter identifiers.
4. Entity types must be one of: person, project, tool, language, organization, concept. \
\"tool\" covers frameworks, software tools, and libraries. \
\"language\" covers programming and natural languages. \
\"concept\" covers abstract ideas, methodologies, and practices.
5. Only extract entities with clear semantic meaning about people, projects, or domain knowledge.
6. Entity names must be at least 3 characters long. Reject single characters, two-letter \
tokens (e.g. standalone \"go\", \"cd\"), URLs, and bare file paths.
7. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \
\"created\", \"depends_on\", \"replaces\", \"configured_with\".
8. The \"fact\" field is a human-readable sentence summarizing the relationship.
9. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a \
temporal_hint like \"replaced X\" or \"since January 2026\".
10. Each edge must include an \"edge_type\" field classifying the relationship:
 - \"semantic\": conceptual relationships (uses, prefers, knows, works_on, depends_on, created)
 - \"temporal\": time-ordered events (preceded_by, followed_by, happened_during, started_before)
 - \"causal\": cause-effect chains (caused, triggered, resulted_in, led_to, prevented)
 - \"entity\": identity/structural relationships (is_a, part_of, instance_of, alias_of, replaces)
 Default to \"semantic\" if the relationship type is uncertain.
11. Each edge must include a \"confidence\" field: a float in [0.0, 1.0] reflecting how \
certain you are that this relationship was explicitly stated or strongly implied by the text. \
Use 1.0 only for direct verbatim statements. Use 0.5–0.8 for clear implications. \
Use 0.3–0.5 for weak inferences. Omit or use null if you cannot assign a meaningful score.
12. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
13. Do not extract personal identifiable information as entity names: email addresses, \
phone numbers, physical addresses, SSNs, or API keys. Use generic references instead.
14. Always output entity names and relation verbs in English. Translate if needed.
15. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
 \"entities\": [
 {\"name\": \"string\", \"type\": \"person|project|tool|language|organization|concept\", \"summary\": \"optional string\"}
 ],
 \"edges\": [
 {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\", \"edge_type\": \"semantic|temporal|causal|entity\", \"confidence\": 0.0}
 ]
}";
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
63pub struct ExtractionResult {
64 pub entities: Vec<ExtractedEntity>,
65 pub edges: Vec<ExtractedEdge>,
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
69pub struct ExtractedEntity {
70 pub name: String,
71 #[serde(rename = "type")]
72 pub entity_type: String,
73 pub summary: Option<String>,
74}
75
/// Serde fallback for an edge's `edge_type`: produces `"semantic"` when the field is absent.
fn default_semantic() -> String {
    String::from("semantic")
}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
81pub struct ExtractedEdge {
82 pub source: String,
83 pub target: String,
84 pub relation: String,
85 pub fact: String,
86 pub temporal_hint: Option<String>,
87 #[serde(default = "default_semantic")]
89 pub edge_type: String,
90 #[serde(default)]
97 pub confidence: Option<f32>,
98}
99
100pub struct GraphExtractor {
101 provider: AnyProvider,
102 max_entities: usize,
103 max_edges: usize,
104}
105
106impl GraphExtractor {
107 #[must_use]
108 pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
109 Self {
110 provider,
111 max_entities,
112 max_edges,
113 }
114 }
115
116 pub async fn extract(
126 &self,
127 message: &str,
128 context_messages: &[&str],
129 ) -> Result<Option<ExtractionResult>, MemoryError> {
130 if message.trim().is_empty() {
131 return Ok(None);
132 }
133
134 let user_prompt = build_user_prompt(message, context_messages);
135 let messages = [
136 Message::from_legacy(Role::System, SYSTEM_PROMPT),
137 Message::from_legacy(Role::User, user_prompt),
138 ];
139
140 match self
141 .provider
142 .chat_typed_erased::<ExtractionResult>(&messages)
143 .await
144 {
145 Ok(mut result) => {
146 result.entities.truncate(self.max_entities);
147 result.edges.truncate(self.max_edges);
148 Ok(Some(result))
149 }
150 Err(LlmError::StructuredParse(msg)) => {
151 tracing::warn!(
152 "graph extraction: LLM returned unparseable output (len={}): {:.200}",
153 msg.len(),
154 msg
155 );
156 Ok(None)
157 }
158 Err(other) => Err(MemoryError::Llm(other)),
159 }
160 }
161}
162
/// Renders the user-role prompt for one extraction request.
///
/// When `context_messages` is non-empty, the prompt starts with a `Context (last N messages):`
/// section listing them one per line; the current `message` and the closing instruction
/// always follow.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    // Optional leading section; empty when there is no prior context.
    let preamble = match context_messages.len() {
        0 => String::new(),
        count => {
            let history = context_messages.join("\n");
            format!("Context (last {count} messages):\n{history}\n\n")
        }
    };

    format!(
        "{preamble}Current message:\n{message}\n\nExtract entities and relationships as JSON."
    )
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`ExtractedEntity`] from borrowed strings.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: entity_type.to_owned(),
            summary: summary.map(str::to_owned),
        }
    }

    /// Builds a `"semantic"` [`ExtractedEdge`] with no confidence score.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.to_owned(),
            target: target.to_owned(),
            relation: relation.to_owned(),
            fact: fact.to_owned(),
            temporal_hint: temporal_hint.map(str::to_owned),
            edge_type: String::from("semantic"),
            confidence: None,
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let payload = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let parsed: ExtractionResult = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.entities.len(), 1);

        let entity = &parsed.entities[0];
        assert_eq!(entity.name, "Rust");
        assert_eq!(entity.entity_type, "language");
        assert_eq!(entity.summary.as_deref(), Some("A systems language"));
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let parsed: ExtractionResult =
            serde_json::from_str(r#"{"entities":[],"edges":[]}"#).unwrap();
        assert!(parsed.entities.is_empty());
        assert!(parsed.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        let payload = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let parsed: ExtractionResult = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.entities[0].summary, None);

        let edge = &parsed.edges[0];
        assert_eq!(edge.temporal_hint, None);
        // `edge_type` is absent from the payload, so the serde default applies.
        assert_eq!(edge.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_defaults_to_semantic_when_missing() {
        let payload = r#"{"source":"A","target":"B","relation":"uses","fact":"A uses B"}"#;
        let parsed: ExtractedEdge = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.edge_type, "semantic");
    }

    #[test]
    fn extracted_edge_type_parses_all_variants() {
        for expected in ["semantic", "temporal", "causal", "entity"] {
            let payload = format!(
                r#"{{"source":"A","target":"B","relation":"r","fact":"f","edge_type":"{expected}"}}"#
            );
            let parsed: ExtractedEdge = serde_json::from_str(&payload).unwrap();
            assert_eq!(parsed.edge_type, expected);
        }
    }

    #[test]
    fn extraction_result_with_edge_types_roundtrip() {
        let causal = ExtractedEdge {
            source: "A".into(),
            target: "B".into(),
            relation: "caused".into(),
            fact: "A caused B".into(),
            temporal_hint: None,
            edge_type: "causal".into(),
            confidence: Some(0.9),
        };
        let temporal = ExtractedEdge {
            source: "B".into(),
            target: "C".into(),
            relation: "preceded_by".into(),
            fact: "B preceded_by C".into(),
            temporal_hint: None,
            edge_type: "temporal".into(),
            confidence: None,
        };
        let original = ExtractionResult {
            entities: Vec::new(),
            edges: vec![causal, temporal],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, restored);
        assert_eq!(restored.edges[0].edge_type, "causal");
        assert_eq!(restored.edges[1].edge_type, "temporal");
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let parsed: ExtractedEntity =
            serde_json::from_str(r#"{"name":"cargo","type":"tool","summary":null}"#).unwrap();
        assert_eq!(parsed.entity_type, "tool");

        // The Rust field name must never leak into the serialized form.
        let encoded = serde_json::to_string(&parsed).unwrap();
        assert!(encoded.contains(r#""type""#));
        assert!(!encoded.contains(r#""entity_type""#));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let entity = make_entity("Rust", "language", Some("A systems language"));
        let edge = make_edge("Alice", "Rust", "uses", "Alice uses Rust", None);
        let original = ExtractionResult {
            entities: vec![entity],
            edges: vec![edge],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);

        let as_value = serde_json::to_value(&schema).unwrap();
        let top_level = as_value.as_object().unwrap();
        let has_expected_key = ["title", "properties"]
            .iter()
            .any(|key| top_level.contains_key(*key));
        assert!(has_expected_key, "schema should have top-level keys");

        let as_text = serde_json::to_string(&schema).unwrap();
        assert!(
            as_text.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(as_text.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let rendered = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(rendered.contains("Context (last 2 messages):"));
        assert!(rendered.contains("prev message 1\nprev message 2"));
        assert!(rendered.contains("Current message:\nHello Rust"));
        assert!(rendered.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let rendered = build_user_prompt("Hello Rust", &[]);
        assert!(!rendered.contains("Context"));
        assert!(rendered.contains("Current message:\nHello Rust"));
        assert!(rendered.contains("Extract entities and relationships as JSON."));
    }

    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// Wraps `mock` in a [`GraphExtractor`] with the given limits.
        fn extractor_over(
            mock: MockProvider,
            max_entities: usize,
            max_edges: usize,
        ) -> GraphExtractor {
            GraphExtractor::new(zeph_llm::any::AnyProvider::Mock(mock), max_entities, max_edges)
        }

        /// JSON payload holding `count` entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let joined = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[{}],"edges":[]}}"#, joined)
        }

        /// JSON payload holding `count` edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let joined = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect::<Vec<_>>()
                .join(",");
            format!(r#"{{"entities":[],"edges":[{}]}}"#, joined)
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let mock = MockProvider::with_responses(vec![make_entities_json(20)]);
            let extractor = extractor_over(mock, 5, 100);

            let extraction = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extraction.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let mock = MockProvider::with_responses(vec![make_edges_json(15)]);
            let extractor = extractor_over(mock, 100, 3);

            let extraction = extractor
                .extract("test message", &[])
                .await
                .unwrap()
                .unwrap();
            assert_eq!(extraction.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = extractor_over(mock, 10, 10);

            let outcome = extractor.extract("test message", &[]).await.unwrap();
            assert!(outcome.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = extractor_over(mock, 10, 10);

            let outcome = extractor.extract("test message", &[]).await;
            assert!(matches!(outcome, Err(MemoryError::Llm(_))));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            // The canned response must never be consumed: blank input short-circuits.
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = extractor_over(mock, 10, 10);

            let from_empty = extractor.extract("", &[]).await.unwrap();
            assert!(from_empty.is_none());

            let from_whitespace = extractor.extract(" \t\n ", &[]).await.unwrap();
            assert!(from_whitespace.is_none());
        }
    }
}