1use std::sync::Arc;
30use std::time::Duration;
31
32use tracing::Instrument as _;
33use zeph_llm::any::AnyProvider;
34use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
35
36pub use zeph_config::memory::EmGraphConfig;
37
38use zeph_db::ActiveDialect;
39
40use crate::error::MemoryError;
41use crate::store::SqliteStore;
42use crate::types::MessageId;
43
/// Size threshold for the visited set during causal traversal: once this many event ids have
/// been collected, `recall_episodic_causal` stops expanding further hops.
const MAX_CAUSAL_VISITED: usize = 400;
47
48#[derive(Debug, Clone)]
52pub struct EpisodicEvent {
53 pub id: i64,
55 pub session_id: String,
57 pub message_id: MessageId,
59 pub event_type: String,
61 pub summary: String,
63 pub embedding: Option<Vec<u8>>,
65 pub created_at: i64,
67}
68
/// A directed cause → effect edge between two episodic events.
#[derive(Clone, Debug)]
pub struct CausalLink {
    /// Row id of the link; the response parser sets it to `0`.
    pub id: i64,
    /// Id of the event acting as the cause.
    pub cause_event_id: i64,
    /// Id of the event acting as the effect.
    pub effect_event_id: i64,
    /// Link confidence; the parser clamps it to `0.0..=1.0` and drops values below `0.5`.
    pub strength: f32,
    /// Creation timestamp; the response parser sets it to `0`.
    pub created_at: i64,
}
83
84pub async fn extract_events(
99 provider: &Arc<AnyProvider>,
100 content: &str,
101 session_id: &str,
102 message_id: MessageId,
103 config: &EmGraphConfig,
104) -> Vec<EpisodicEvent> {
105 let span = tracing::debug_span!("memory.em_graph.extract_events", message_id = message_id.0);
106
107 async move {
108 if !config.enabled {
109 return vec![];
110 }
111
112 let snippet = content.chars().take(2000).collect::<String>();
113
114 let prompt = format!(
115 "Identify episodic events in the following conversation turn. \
116 An event is a concrete action, decision, discovery, or error. \
117 Return a JSON array of objects with fields: \
118 {{\"event_type\": \"<type>\", \"summary\": \"<one sentence>\"}}. \
119 Types: decision, discovery, error, tool_use, question, answer, other. \
120 Return [] if no notable events. Output JSON only.\n\nTurn:\n{snippet}"
121 );
122
123 let messages = vec![
124 Message {
125 role: Role::System,
126 content: "You are an episodic memory extractor. Extract concrete events from \
127 conversation turns as structured JSON. Output only valid JSON, no preamble."
128 .to_owned(),
129 parts: vec![],
130 metadata: MessageMetadata::default(),
131 },
132 Message {
133 role: Role::User,
134 content: prompt,
135 parts: vec![],
136 metadata: MessageMetadata::default(),
137 },
138 ];
139
140 let raw = match tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages)).await {
141 Ok(Ok(r)) => r,
142 Ok(Err(e)) => {
143 tracing::warn!(error = %e, "em_graph: event extraction LLM call failed");
144 return vec![];
145 }
146 Err(_) => {
147 tracing::warn!("em_graph: event extraction timed out");
148 return vec![];
149 }
150 };
151
152 parse_events_response(&raw, session_id, message_id)
153 }
154 .instrument(span)
155 .await
156}
157
158fn parse_events_response(raw: &str, session_id: &str, message_id: MessageId) -> Vec<EpisodicEvent> {
159 let json_str = raw
160 .find('[')
161 .and_then(|s| raw[s..].rfind(']').map(|e| &raw[s..=s + e]))
162 .unwrap_or("[]");
163
164 let values: Vec<serde_json::Value> = serde_json::from_str(json_str).unwrap_or_default();
165
166 values
167 .into_iter()
168 .filter_map(|v| {
169 let event_type = v.get("event_type")?.as_str()?.to_owned();
170 let summary = v.get("summary")?.as_str()?.to_owned();
171 if summary.is_empty() {
172 return None;
173 }
174 Some(EpisodicEvent {
175 id: 0,
176 session_id: session_id.to_owned(),
177 message_id,
178 event_type,
179 summary,
180 embedding: None,
181 created_at: 0,
182 })
183 })
184 .collect()
185}
186
187pub async fn link_events(
201 provider: &Arc<AnyProvider>,
202 new_events: &[EpisodicEvent],
203 recent_events: &[EpisodicEvent],
204 config: &EmGraphConfig,
205) -> Vec<CausalLink> {
206 let span = tracing::debug_span!(
207 "memory.em_graph.link_events",
208 new_count = new_events.len(),
209 recent_count = recent_events.len()
210 );
211
212 async move {
213 if !config.enabled || new_events.is_empty() || recent_events.is_empty() {
214 return vec![];
215 }
216
217 let new_desc: Vec<String> = new_events
219 .iter()
220 .enumerate()
221 .map(|(i, e)| {
222 let s: String = e.summary.chars().take(200).collect();
223 format!("NEW[{i}] (id={}): {s}", e.id)
224 })
225 .collect();
226
227 let recent_desc: Vec<String> = recent_events
228 .iter()
229 .enumerate()
230 .map(|(i, e)| {
231 let s: String = e.summary.chars().take(200).collect();
232 format!("RECENT[{i}] (id={}): {s}", e.id)
233 })
234 .collect();
235
236 let prompt = format!(
237 "Given these recent events and new events, identify causal relationships \
238 (cause → effect). Return a JSON array of objects: \
239 {{\"cause_id\": <event_id>, \"effect_id\": <event_id>, \"strength\": 0.0-1.0}}. \
240 Only include strong causal links (strength >= 0.5). Output [] if none.\n\n\
241 Recent events:\n{}\n\nNew events:\n{}",
242 recent_desc.join("\n"),
243 new_desc.join("\n"),
244 );
245
246 let messages = vec![
247 Message {
248 role: Role::System,
249 content: "You are a causal reasoning engine. Identify cause-and-effect \
250 relationships between events. Output only valid JSON."
251 .to_owned(),
252 parts: vec![],
253 metadata: MessageMetadata::default(),
254 },
255 Message {
256 role: Role::User,
257 content: prompt,
258 parts: vec![],
259 metadata: MessageMetadata::default(),
260 },
261 ];
262
263 let raw =
264 match tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages)).await {
265 Ok(Ok(r)) => r,
266 Ok(Err(e)) => {
267 tracing::warn!(error = %e, "em_graph: causal link LLM call failed");
268 return vec![];
269 }
270 Err(_) => {
271 tracing::warn!("em_graph: causal link detection timed out");
272 return vec![];
273 }
274 };
275
276 parse_links_response(&raw)
277 }
278 .instrument(span)
279 .await
280}
281
282fn parse_links_response(raw: &str) -> Vec<CausalLink> {
283 let json_str = raw
284 .find('[')
285 .and_then(|s| raw[s..].rfind(']').map(|e| &raw[s..=s + e]))
286 .unwrap_or("[]");
287
288 let values: Vec<serde_json::Value> = serde_json::from_str(json_str).unwrap_or_default();
289
290 values
291 .into_iter()
292 .filter_map(|v| {
293 let cause_id = v.get("cause_id")?.as_i64()?;
294 let effect_id = v.get("effect_id")?.as_i64()?;
295 #[allow(clippy::cast_possible_truncation)]
296 let strength = v
297 .get("strength")
298 .and_then(serde_json::Value::as_f64)
299 .map_or(0.5, |s| s.clamp(0.0, 1.0) as f32);
300 if strength < 0.5 {
301 return None;
302 }
303 Some(CausalLink {
304 id: 0,
305 cause_event_id: cause_id,
306 effect_event_id: effect_id,
307 strength,
308 created_at: 0,
309 })
310 })
311 .collect()
312}
313
314pub async fn store_events(
325 store: &SqliteStore,
326 events: &mut [EpisodicEvent],
327) -> Result<(), MemoryError> {
328 if events.is_empty() {
329 return Ok(());
330 }
331 let mut tx = store.pool().begin().await?;
332 for event in events.iter_mut() {
333 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
334 let raw = format!(
335 "INSERT INTO episodic_events (session_id, message_id, event_type, summary, created_at) \
336 VALUES (?, ?, ?, ?, {epoch_now}) \
337 RETURNING id"
338 );
339 let sql = zeph_db::rewrite_placeholders(&raw);
340 let id = sqlx::query_scalar::<_, i64>(&sql)
341 .bind(&event.session_id)
342 .bind(event.message_id.0)
343 .bind(&event.event_type)
344 .bind(&event.summary)
345 .fetch_one(&mut *tx)
346 .await?;
347 event.id = id;
348 }
349 tx.commit().await?;
350 Ok(())
351}
352
353pub async fn store_links(store: &SqliteStore, links: &[CausalLink]) -> Result<(), MemoryError> {
364 if links.is_empty() {
365 return Ok(());
366 }
367 let mut tx = store.pool().begin().await?;
368 for link in links {
369 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
370 let raw = format!(
371 "INSERT INTO causal_links \
372 (cause_event_id, effect_event_id, strength, created_at) \
373 VALUES (?, ?, ?, {epoch_now}) \
374 ON CONFLICT (cause_event_id, effect_event_id) DO NOTHING"
375 );
376 let sql = zeph_db::rewrite_placeholders(&raw);
377 sqlx::query(&sql)
378 .bind(link.cause_event_id)
379 .bind(link.effect_event_id)
380 .bind(link.strength)
381 .execute(&mut *tx)
382 .await?;
383 }
384 tx.commit().await?;
385 Ok(())
386}
387
388pub async fn fetch_recent_events(
398 store: &SqliteStore,
399 session_id: &str,
400 limit: usize,
401) -> Result<Vec<EpisodicEvent>, MemoryError> {
402 let rows = sqlx::query_as::<_, (i64, String, i64, String, String, i64)>(
403 "SELECT id, session_id, message_id, event_type, summary, created_at
404 FROM episodic_events
405 WHERE session_id = ?
406 ORDER BY created_at DESC
407 LIMIT ?",
408 )
409 .bind(session_id)
410 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
411 .fetch_all(store.pool())
412 .await?;
413
414 Ok(rows
415 .into_iter()
416 .map(
417 |(id, session_id, message_id, event_type, summary, created_at)| EpisodicEvent {
418 id,
419 session_id,
420 message_id: MessageId(message_id),
421 event_type,
422 summary,
423 embedding: None,
424 created_at,
425 },
426 )
427 .collect())
428}
429
430pub async fn recall_episodic_causal(
439 store: &SqliteStore,
440 seed_event_id: i64,
441 session_id: &str,
442 max_depth: u32,
443 config: &EmGraphConfig,
444) -> Result<Vec<EpisodicEvent>, MemoryError> {
445 let span = tracing::debug_span!("memory.em_graph.causal_recall", seed_event_id, max_depth);
446
447 if !config.enabled {
448 return Ok(vec![]);
449 }
450
451 let pool = store.pool().clone();
453 let session_id = session_id.to_owned();
454
455 async move {
456 let mut visited: Vec<i64> = vec![seed_event_id];
457 let mut frontier: Vec<i64> = vec![seed_event_id];
458
459 for depth in 0..max_depth {
460 if frontier.is_empty() || visited.len() >= MAX_CAUSAL_VISITED {
461 break;
462 }
463
464 let frontier_ph = frontier.iter().map(|_| "?").collect::<Vec<_>>().join(",");
465 let visited_ph = visited.iter().map(|_| "?").collect::<Vec<_>>().join(",");
466
467 let query = format!(
468 "SELECT DISTINCT effect_event_id FROM causal_links
469 WHERE cause_event_id IN ({frontier_ph})
470 AND effect_event_id NOT IN ({visited_ph})"
471 );
472
473 let mut q = sqlx::query_scalar::<_, i64>(&query);
474 for &id in &frontier {
475 q = q.bind(id);
476 }
477 for &id in &visited {
478 q = q.bind(id);
479 }
480
481 let next: Vec<i64> = q.fetch_all(&pool).await?;
482
483 tracing::debug!(depth, next_count = next.len(), "em_graph: causal hop");
484 visited.extend_from_slice(&next);
485 frontier = next;
486 }
487
488 if visited.is_empty() {
489 return Ok(vec![]);
490 }
491
492 let placeholders = visited.iter().map(|_| "?").collect::<Vec<_>>().join(",");
494
495 let query = format!(
496 "SELECT id, session_id, message_id, event_type, summary, created_at
497 FROM episodic_events
498 WHERE id IN ({placeholders}) AND session_id = ?
499 ORDER BY created_at ASC"
500 );
501
502 let mut q = sqlx::query_as::<_, (i64, String, i64, String, String, i64)>(&query);
503 for &id in &visited {
504 q = q.bind(id);
505 }
506 q = q.bind(session_id);
507
508 let rows = q.fetch_all(&pool).await?;
509
510 Ok(rows
511 .into_iter()
512 .map(
513 |(id, session_id, message_id, event_type, summary, created_at)| EpisodicEvent {
514 id,
515 session_id,
516 message_id: MessageId(message_id),
517 event_type,
518 summary,
519 embedding: None,
520 created_at,
521 },
522 )
523 .collect())
524 }
525 .instrument(span)
526 .await
527}
528
/// Unit tests for response parsing and the SQLite-backed store/recall round trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::SqliteStore;
    use zeph_config::providers::ProviderName;

    // ── fixtures ────────────────────────────────────────────────────────────

    /// Opens a fresh in-memory store.
    async fn memory_store() -> SqliteStore {
        SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new")
    }

    /// Opens a fresh store and saves one user message so events have a message to reference.
    async fn store_with_message(body: &str) -> (SqliteStore, MessageId) {
        let store = memory_store().await;
        let cid = store.create_conversation().await.expect("conversation");
        let mid = store
            .save_message(cid, "user", body)
            .await
            .expect("save_message");
        (store, mid)
    }

    /// Builds an event that has not been stored yet (`id` and `created_at` are zero).
    fn unsaved_event(
        session: &str,
        message_id: MessageId,
        event_type: &str,
        summary: &str,
    ) -> EpisodicEvent {
        EpisodicEvent {
            id: 0,
            session_id: session.to_owned(),
            message_id,
            event_type: event_type.to_owned(),
            summary: summary.to_owned(),
            embedding: None,
            created_at: 0,
        }
    }

    /// Builds a link that has not been stored yet.
    fn unsaved_link(cause: i64, effect: i64, strength: f32) -> CausalLink {
        CausalLink {
            id: 0,
            cause_event_id: cause,
            effect_event_id: effect,
            strength,
            created_at: 0,
        }
    }

    /// Builds a config with the given `enabled` flag and a chain depth of 3.
    fn graph_config(enabled: bool) -> EmGraphConfig {
        EmGraphConfig {
            enabled,
            extract_provider: ProviderName::default(),
            max_chain_depth: 3,
        }
    }

    // ── parse_events_response ───────────────────────────────────────────────

    #[test]
    fn parse_events_response_valid_json() {
        let raw = r#"[{"event_type":"decision","summary":"User chose approach A"},{"event_type":"discovery","summary":"Found a bug in module X"}]"#;
        let parsed = parse_events_response(raw, "sess-1", MessageId(42));
        assert_eq!(parsed.len(), 2);
        let (first, second) = (&parsed[0], &parsed[1]);
        assert_eq!(first.event_type, "decision");
        assert_eq!(second.summary, "Found a bug in module X");
        assert_eq!(first.message_id, MessageId(42));
        assert_eq!(first.session_id, "sess-1");
    }

    #[test]
    fn parse_events_response_empty_array() {
        assert!(parse_events_response("[]", "sess-1", MessageId(1)).is_empty());
    }

    #[test]
    fn parse_events_response_malformed_json() {
        assert!(parse_events_response("not json", "sess-1", MessageId(1)).is_empty());
    }

    #[test]
    fn parse_events_response_skips_empty_summary() {
        let parsed = parse_events_response(
            r#"[{"event_type":"decision","summary":""}]"#,
            "sess-1",
            MessageId(1),
        );
        assert!(parsed.is_empty(), "empty summary must be skipped");
    }

    // ── parse_links_response ────────────────────────────────────────────────

    #[test]
    fn parse_links_response_valid_json() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.8}]"#);
        assert_eq!(parsed.len(), 1);
        let only = &parsed[0];
        assert_eq!(only.cause_event_id, 1);
        assert_eq!(only.effect_event_id, 2);
        assert!((only.strength - 0.8).abs() < 0.01);
    }

    #[test]
    fn parse_links_response_filters_weak_links() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.3}]"#);
        assert!(
            parsed.is_empty(),
            "weak links (strength < 0.5) must be filtered"
        );
    }

    #[test]
    fn parse_links_response_empty() {
        assert!(parse_links_response("[]").is_empty());
    }

    #[test]
    fn parse_links_response_strength_at_boundary_included() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.5}]"#);
        assert_eq!(
            parsed.len(),
            1,
            "strength=0.5 must be included (threshold is strict < 0.5)"
        );
        assert!((parsed[0].strength - 0.5).abs() < 0.001);
    }

    // ── configuration ───────────────────────────────────────────────────────

    #[test]
    fn em_graph_config_defaults() {
        let defaults = EmGraphConfig::default();
        assert!(!defaults.enabled);
        assert_eq!(defaults.max_chain_depth, 3);
    }

    // ── storage round trips ─────────────────────────────────────────────────

    #[tokio::test]
    async fn store_and_fetch_events_in_memory_db() {
        let (store, mid) = store_with_message("hello world").await;

        let mut events = vec![unsaved_event(
            "test-session",
            mid,
            "decision",
            "User decided to use approach A",
        )];
        store_events(&store, &mut events)
            .await
            .expect("store_events");
        assert!(events[0].id > 0, "id must be assigned after insert");

        let fetched = fetch_recent_events(&store, "test-session", 10)
            .await
            .expect("fetch_recent_events");
        assert_eq!(fetched.len(), 1);
        assert_eq!(fetched[0].summary, "User decided to use approach A");
    }

    #[tokio::test]
    async fn store_and_recall_causal_chain() {
        let (store, mid) = store_with_message("test").await;

        let mut events = vec![
            unsaved_event("sess", mid, "discovery", "Found a bug"),
            unsaved_event("sess", mid, "decision", "Decided to fix it"),
        ];
        store_events(&store, &mut events)
            .await
            .expect("store_events");

        let link = unsaved_link(events[0].id, events[1].id, 0.9);
        store_links(&store, &[link]).await.expect("store_links");

        let config = graph_config(true);
        let chain = recall_episodic_causal(&store, events[0].id, "sess", 3, &config)
            .await
            .expect("recall_episodic_causal");

        assert_eq!(
            chain.len(),
            2,
            "chain must include seed and causally-linked event"
        );
    }

    #[tokio::test]
    async fn recall_episodic_causal_disabled_returns_empty() {
        let store = memory_store().await;
        let config = graph_config(false);

        let result = recall_episodic_causal(&store, 1, "sess", 3, &config).await;
        assert!(result.is_ok());
        assert!(
            result.unwrap().is_empty(),
            "disabled config must return empty"
        );
    }

    #[tokio::test]
    async fn store_links_is_idempotent_with_unique_constraint() {
        let (store, mid) = store_with_message("test").await;

        let mut events = vec![
            unsaved_event("sess", mid, "decision", "A"),
            unsaved_event("sess", mid, "discovery", "B"),
        ];
        store_events(&store, &mut events)
            .await
            .expect("store_events");
        let (cause, effect) = (events[0].id, events[1].id);

        // Store the very same link twice; the second insert must be a no-op.
        let link = unsaved_link(cause, effect, 0.8);
        store_links(&store, std::slice::from_ref(&link))
            .await
            .expect("first store_links");
        store_links(&store, &[link])
            .await
            .expect("second store_links (idempotent)");

        let count: i64 = sqlx::query_scalar(
            "SELECT COUNT(*) FROM causal_links WHERE cause_event_id = ? AND effect_event_id = ?",
        )
        .bind(cause)
        .bind(effect)
        .fetch_one(store.pool())
        .await
        .expect("count query");

        assert_eq!(
            count, 1,
            "duplicate causal links must be deduplicated by UNIQUE constraint"
        );
    }
}