1use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use zeph_llm::LlmError;
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::{Message, Role};
9
10use crate::error::MemoryError;
11
/// System-role prompt sent ahead of every extraction request.
///
/// It tells the model which entity types and relation phrasings are allowed,
/// which content must be skipped (greetings, PII), and spells out the JSON
/// layout of the expected reply (`entities` and `edges` arrays), which uses
/// the same keys as the serialized form of [`ExtractionResult`].
const SYSTEM_PROMPT: &str = "\
You are an entity and relationship extractor. Given a conversation message and \
its recent context, extract structured knowledge as JSON.

Rules:
1. Extract entities mentioned or implied in the current message.
2. Extract relationships between entities.
3. Entity types must be one of: person, tool, concept, project, language, file, config, organization.
4. Relations should be short verb phrases: \"prefers\", \"uses\", \"works_on\", \"knows\", \"created\", \"depends_on\", \"replaces\", \"configured_with\".
5. The \"fact\" field is a human-readable sentence summarizing the relationship.
6. If a message contains a temporal change (e.g., \"switched from X to Y\"), include a temporal_hint like \"replaced X\" or \"since January 2026\".
7. Do not extract entities from greetings, filler, or meta-conversation (\"hi\", \"thanks\", \"ok\").
8. Do not extract personal identifiable information as entity names: email addresses, phone numbers, physical addresses, SSNs, or API keys. Use generic references instead (e.g., \"User\" instead of \"Alice Smith\").
9. Always output entity names and relation verbs in English, even if the conversation is in another language. Translate entity names if needed.
10. Return empty arrays if no entities or relationships are present.

Output JSON schema:
{
 \"entities\": [
 {\"name\": \"string\", \"type\": \"person|tool|concept|project|language|file|config|organization\", \"summary\": \"optional string\"}
 ],
 \"edges\": [
 {\"source\": \"entity name\", \"target\": \"entity name\", \"relation\": \"verb phrase\", \"fact\": \"human-readable sentence\", \"temporal_hint\": \"optional string\"}
 ]
}";
37
/// Structured output of one extraction call: the entities found in a message
/// and the relationships (edges) between them.
///
/// This is the type the LLM reply is parsed into by [`GraphExtractor::extract`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractionResult {
    /// Entities reported by the model.
    pub entities: Vec<ExtractedEntity>,
    /// Relationships reported by the model.
    pub edges: Vec<ExtractedEdge>,
}
43
/// A single entity reported by the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEntity {
    /// Entity name as returned by the model.
    pub name: String,
    /// Entity category; appears under the key `type` in JSON because
    /// `type` is a reserved word in Rust.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Optional free-text summary of the entity.
    pub summary: Option<String>,
}
51
/// A single relationship between two entities reported by the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ExtractedEdge {
    /// Name of the entity the relationship starts from.
    pub source: String,
    /// Name of the entity the relationship points to.
    pub target: String,
    /// Relation label; the system prompt asks for a short verb phrase
    /// such as `uses` or `depends_on`.
    pub relation: String,
    /// Human-readable sentence summarizing the relationship.
    pub fact: String,
    /// Optional temporal qualifier; the system prompt gives `replaced X`
    /// and `since January 2026` as examples.
    pub temporal_hint: Option<String>,
}
60
/// Extracts entities and relationships from conversation messages by asking
/// an LLM provider for a typed [`ExtractionResult`].
pub struct GraphExtractor {
    /// Provider that receives the extraction request.
    provider: AnyProvider,
    /// Upper bound on the number of entities kept from one reply.
    max_entities: usize,
    /// Upper bound on the number of edges kept from one reply.
    max_edges: usize,
}
66
67impl GraphExtractor {
68 #[must_use]
69 pub fn new(provider: AnyProvider, max_entities: usize, max_edges: usize) -> Self {
70 Self {
71 provider,
72 max_entities,
73 max_edges,
74 }
75 }
76
77 pub async fn extract(
87 &self,
88 message: &str,
89 context_messages: &[&str],
90 ) -> Result<Option<ExtractionResult>, MemoryError> {
91 if message.trim().is_empty() {
92 return Ok(None);
93 }
94
95 let user_prompt = build_user_prompt(message, context_messages);
96 let messages = [
97 Message::from_legacy(Role::System, SYSTEM_PROMPT),
98 Message::from_legacy(Role::User, user_prompt),
99 ];
100
101 match self
102 .provider
103 .chat_typed_erased::<ExtractionResult>(&messages)
104 .await
105 {
106 Ok(mut result) => {
107 result.entities.truncate(self.max_entities);
108 result.edges.truncate(self.max_edges);
109 Ok(Some(result))
110 }
111 Err(LlmError::StructuredParse(msg)) => {
112 tracing::warn!(
113 "graph extraction: LLM returned unparseable output (len={}): {:.200}",
114 msg.len(),
115 msg
116 );
117 Ok(None)
118 }
119 Err(other) => Err(MemoryError::Llm(other)),
120 }
121 }
122}
123
/// Builds the user-role prompt for an extraction request.
///
/// With no context the prompt holds only the current message followed by the
/// extraction instruction. Otherwise it is prefixed with a
/// `Context (last N messages):` section listing the context messages one per
/// line, where `N` is the number of context messages supplied.
fn build_user_prompt(message: &str, context_messages: &[&str]) -> String {
    // Guard clause: no history means no context section at all.
    if context_messages.is_empty() {
        return format!(
            "Current message:\n{message}\n\nExtract entities and relationships as JSON."
        );
    }

    format!(
        "Context (last {n} messages):\n{context}\n\nCurrent message:\n{message}\n\nExtract entities and relationships as JSON.",
        n = context_messages.len(),
        context = context_messages.join("\n"),
    )
}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`ExtractedEntity`] from borrowed test data.
    fn make_entity(name: &str, entity_type: &str, summary: Option<&str>) -> ExtractedEntity {
        ExtractedEntity {
            name: name.into(),
            entity_type: entity_type.into(),
            summary: summary.map(Into::into),
        }
    }

    /// Builds an [`ExtractedEdge`] from borrowed test data.
    fn make_edge(
        source: &str,
        target: &str,
        relation: &str,
        fact: &str,
        temporal_hint: Option<&str>,
    ) -> ExtractedEdge {
        ExtractedEdge {
            source: source.into(),
            target: target.into(),
            relation: relation.into(),
            fact: fact.into(),
            temporal_hint: temporal_hint.map(Into::into),
        }
    }

    #[test]
    fn extraction_result_deserialize_valid_json() {
        let json = r#"{"entities":[{"name":"Rust","type":"language","summary":"A systems language"}],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].name, "Rust");
        assert_eq!(result.entities[0].entity_type, "language");
        assert_eq!(
            result.entities[0].summary.as_deref(),
            Some("A systems language")
        );
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_empty_arrays() {
        let json = r#"{"entities":[],"edges":[]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert!(result.entities.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn extraction_result_deserialize_missing_optional_fields() {
        // Optional fields given as explicit `null` become `None`.
        let json = r#"{"entities":[{"name":"Alice","type":"person","summary":null}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust","temporal_hint":null}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);

        // Optional fields that are absent altogether must be accepted too;
        // this is the case the test name promises.
        let json = r#"{"entities":[{"name":"Alice","type":"person"}],"edges":[{"source":"Alice","target":"Rust","relation":"uses","fact":"Alice uses Rust"}]}"#;
        let result: ExtractionResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.entities[0].summary, None);
        assert_eq!(result.edges[0].temporal_hint, None);
    }

    #[test]
    fn extracted_entity_type_field_rename() {
        let json = r#"{"name":"cargo","type":"tool","summary":null}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.entity_type, "tool");

        // The Rust field name must never leak into the serialized form.
        let serialized = serde_json::to_string(&entity).unwrap();
        assert!(serialized.contains("\"type\""));
        assert!(!serialized.contains("\"entity_type\""));
    }

    #[test]
    fn extraction_result_roundtrip() {
        let original = ExtractionResult {
            entities: vec![make_entity("Rust", "language", Some("A systems language"))],
            edges: vec![make_edge("Alice", "Rust", "uses", "Alice uses Rust", None)],
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ExtractionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn extraction_result_json_schema() {
        let schema = schemars::schema_for!(ExtractionResult);
        let value = serde_json::to_value(&schema).unwrap();
        let schema_obj = value.as_object().unwrap();
        assert!(
            schema_obj.contains_key("title") || schema_obj.contains_key("properties"),
            "schema should have top-level keys"
        );
        let json_str = serde_json::to_string(&schema).unwrap();
        assert!(
            json_str.contains("entities"),
            "schema should contain 'entities'"
        );
        assert!(json_str.contains("edges"), "schema should contain 'edges'");
    }

    #[test]
    fn build_user_prompt_with_context() {
        let prompt = build_user_prompt("Hello Rust", &["prev message 1", "prev message 2"]);
        assert!(prompt.contains("Context (last 2 messages):"));
        assert!(prompt.contains("prev message 1\nprev message 2"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    #[test]
    fn build_user_prompt_without_context() {
        let prompt = build_user_prompt("Hello Rust", &[]);
        assert!(!prompt.contains("Context"));
        assert!(prompt.contains("Current message:\nHello Rust"));
        assert!(prompt.contains("Extract entities and relationships as JSON."));
    }

    mod mock_tests {
        use super::*;
        use zeph_llm::mock::MockProvider;

        /// Wraps `mock` in a [`GraphExtractor`] with the given limits.
        fn extractor_with(
            mock: MockProvider,
            max_entities: usize,
            max_edges: usize,
        ) -> GraphExtractor {
            GraphExtractor::new(
                zeph_llm::any::AnyProvider::Mock(mock),
                max_entities,
                max_edges,
            )
        }

        /// Renders an extraction reply holding `count` entities and no edges.
        fn make_entities_json(count: usize) -> String {
            let entities: Vec<String> = (0..count)
                .map(|i| format!(r#"{{"name":"entity{i}","type":"concept","summary":null}}"#))
                .collect();
            format!(r#"{{"entities":[{}],"edges":[]}}"#, entities.join(","))
        }

        /// Renders an extraction reply holding `count` edges and no entities.
        fn make_edges_json(count: usize) -> String {
            let edges: Vec<String> = (0..count)
                .map(|i| {
                    format!(
                        r#"{{"source":"A","target":"B{i}","relation":"uses","fact":"A uses B{i}","temporal_hint":null}}"#
                    )
                })
                .collect();
            format!(r#"{{"entities":[],"edges":[{}]}}"#, edges.join(","))
        }

        #[tokio::test]
        async fn extract_truncates_to_max_entities() {
            let json = make_entities_json(20);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = extractor_with(mock, 5, 100);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.entities.len(), 5);
        }

        #[tokio::test]
        async fn extract_truncates_to_max_edges() {
            let json = make_edges_json(15);
            let mock = MockProvider::with_responses(vec![json]);
            let extractor = extractor_with(mock, 100, 3);
            let result = extractor.extract("test message", &[]).await.unwrap();
            let result = result.unwrap();
            assert_eq!(result.edges.len(), 3);
        }

        #[tokio::test]
        async fn extract_returns_none_on_parse_failure() {
            let mock = MockProvider::with_responses(vec!["not valid json at all".into()]);
            let extractor = extractor_with(mock, 10, 10);
            let result = extractor.extract("test message", &[]).await.unwrap();
            assert!(result.is_none());
        }

        #[tokio::test]
        async fn extract_returns_err_on_transport_failure() {
            let mock = MockProvider::default()
                .with_errors(vec![zeph_llm::LlmError::Other("connection refused".into())]);
            let extractor = extractor_with(mock, 10, 10);
            let result = extractor.extract("test message", &[]).await;
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err(), MemoryError::Llm(_)));
        }

        #[tokio::test]
        async fn extract_returns_none_on_empty_message() {
            let mock = MockProvider::with_responses(vec!["should not be called".into()]);
            let extractor = extractor_with(mock, 10, 10);

            let result_empty = extractor.extract("", &[]).await.unwrap();
            assert!(result_empty.is_none());

            let result_whitespace = extractor.extract(" \t\n ", &[]).await.unwrap();
            assert!(result_whitespace.is_none());
        }
    }
}