1use std::sync::Arc;
30use std::time::Duration;
31
32use tracing::Instrument as _;
33use zeph_llm::any::AnyProvider;
34use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
35
36pub use zeph_config::memory::EmGraphConfig;
37
38use zeph_common::SessionId;
39use zeph_db::ActiveDialect;
40
41use crate::error::MemoryError;
42use crate::store::SqliteStore;
43use crate::types::MessageId;
44
/// Size threshold for the visited set used by [`recall_episodic_causal`].
///
/// The traversal stops starting new hops once it has collected this many
/// event ids.
const MAX_CAUSAL_VISITED: usize = 400;
48
/// A single episodic event extracted from a conversation turn.
#[derive(Debug, Clone)]
pub struct EpisodicEvent {
    /// Database row id; `0` on freshly parsed events until [`store_events`] assigns one.
    pub id: i64,
    /// Session the event belongs to.
    pub session_id: SessionId,
    /// Message the event was extracted from.
    pub message_id: MessageId,
    /// Event category as returned by the LLM (the extraction prompt asks for
    /// `decision`, `discovery`, `error`, `tool_use`, `question`, `answer` or `other`).
    pub event_type: String,
    /// Short description of the event (the extraction prompt asks for one sentence).
    pub summary: String,
    /// Optional embedding bytes. Always `None` for events produced by this module.
    // NOTE(review): encoding of the bytes is not visible here — confirm with the writer.
    pub embedding: Option<Vec<u8>>,
    /// Value of the `created_at` column for fetched rows; `0` on events that were
    /// built in memory (the column is filled by the database on insert).
    pub created_at: i64,
}
69
/// A directed cause → effect relationship between two episodic events.
#[derive(Debug, Clone)]
pub struct CausalLink {
    /// Database row id; `0` on links parsed from an LLM response.
    pub id: i64,
    /// Id of the event acting as the cause.
    pub cause_event_id: i64,
    /// Id of the event acting as the effect.
    pub effect_event_id: i64,
    /// Link strength; parsed links are clamped to `0.0..=1.0` and kept only when `>= 0.5`.
    pub strength: f32,
    /// Value of the `created_at` column; `0` on links built in memory
    /// (the column is filled by the database on insert).
    pub created_at: i64,
}
84
85pub async fn extract_events(
100 provider: &Arc<AnyProvider>,
101 content: &str,
102 session_id: &str,
103 message_id: MessageId,
104 config: &EmGraphConfig,
105) -> Vec<EpisodicEvent> {
106 let span = tracing::debug_span!("memory.em_graph.extract_events", message_id = message_id.0);
107
108 async move {
109 if !config.enabled {
110 return vec![];
111 }
112
113 let snippet = content.chars().take(2000).collect::<String>();
114
115 let prompt = format!(
116 "Identify episodic events in the following conversation turn. \
117 An event is a concrete action, decision, discovery, or error. \
118 Return a JSON array of objects with fields: \
119 {{\"event_type\": \"<type>\", \"summary\": \"<one sentence>\"}}. \
120 Types: decision, discovery, error, tool_use, question, answer, other. \
121 Return [] if no notable events. Output JSON only.\n\nTurn:\n{snippet}"
122 );
123
124 let messages = vec![
125 Message {
126 role: Role::System,
127 content: "You are an episodic memory extractor. Extract concrete events from \
128 conversation turns as structured JSON. Output only valid JSON, no preamble."
129 .to_owned(),
130 parts: vec![],
131 metadata: MessageMetadata::default(),
132 },
133 Message {
134 role: Role::User,
135 content: prompt,
136 parts: vec![],
137 metadata: MessageMetadata::default(),
138 },
139 ];
140
141 let raw = match tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages)).await {
142 Ok(Ok(r)) => r,
143 Ok(Err(e)) => {
144 tracing::warn!(error = %e, "em_graph: event extraction LLM call failed");
145 return vec![];
146 }
147 Err(_) => {
148 tracing::warn!("em_graph: event extraction timed out");
149 return vec![];
150 }
151 };
152
153 parse_events_response(&raw, session_id, message_id)
154 }
155 .instrument(span)
156 .await
157}
158
159fn parse_events_response(raw: &str, session_id: &str, message_id: MessageId) -> Vec<EpisodicEvent> {
160 let json_str = raw
161 .find('[')
162 .and_then(|s| raw[s..].rfind(']').map(|e| &raw[s..=s + e]))
163 .unwrap_or("[]");
164
165 let values: Vec<serde_json::Value> = serde_json::from_str(json_str).unwrap_or_default();
166
167 values
168 .into_iter()
169 .filter_map(|v| {
170 let event_type = v.get("event_type")?.as_str()?.to_owned();
171 let summary = v.get("summary")?.as_str()?.to_owned();
172 if summary.is_empty() {
173 return None;
174 }
175 Some(EpisodicEvent {
176 id: 0,
177 session_id: SessionId::new(session_id),
178 message_id,
179 event_type,
180 summary,
181 embedding: None,
182 created_at: 0,
183 })
184 })
185 .collect()
186}
187
188pub async fn link_events(
202 provider: &Arc<AnyProvider>,
203 new_events: &[EpisodicEvent],
204 recent_events: &[EpisodicEvent],
205 config: &EmGraphConfig,
206) -> Vec<CausalLink> {
207 let span = tracing::debug_span!(
208 "memory.em_graph.link_events",
209 new_count = new_events.len(),
210 recent_count = recent_events.len()
211 );
212
213 async move {
214 if !config.enabled || new_events.is_empty() || recent_events.is_empty() {
215 return vec![];
216 }
217
218 let new_desc: Vec<String> = new_events
220 .iter()
221 .enumerate()
222 .map(|(i, e)| {
223 let s: String = e.summary.chars().take(200).collect();
224 format!("NEW[{i}] (id={}): {s}", e.id)
225 })
226 .collect();
227
228 let recent_desc: Vec<String> = recent_events
229 .iter()
230 .enumerate()
231 .map(|(i, e)| {
232 let s: String = e.summary.chars().take(200).collect();
233 format!("RECENT[{i}] (id={}): {s}", e.id)
234 })
235 .collect();
236
237 let prompt = format!(
238 "Given these recent events and new events, identify causal relationships \
239 (cause → effect). Return a JSON array of objects: \
240 {{\"cause_id\": <event_id>, \"effect_id\": <event_id>, \"strength\": 0.0-1.0}}. \
241 Only include strong causal links (strength >= 0.5). Output [] if none.\n\n\
242 Recent events:\n{}\n\nNew events:\n{}",
243 recent_desc.join("\n"),
244 new_desc.join("\n"),
245 );
246
247 let messages = vec![
248 Message {
249 role: Role::System,
250 content: "You are a causal reasoning engine. Identify cause-and-effect \
251 relationships between events. Output only valid JSON."
252 .to_owned(),
253 parts: vec![],
254 metadata: MessageMetadata::default(),
255 },
256 Message {
257 role: Role::User,
258 content: prompt,
259 parts: vec![],
260 metadata: MessageMetadata::default(),
261 },
262 ];
263
264 let raw =
265 match tokio::time::timeout(Duration::from_secs(10), provider.chat(&messages)).await {
266 Ok(Ok(r)) => r,
267 Ok(Err(e)) => {
268 tracing::warn!(error = %e, "em_graph: causal link LLM call failed");
269 return vec![];
270 }
271 Err(_) => {
272 tracing::warn!("em_graph: causal link detection timed out");
273 return vec![];
274 }
275 };
276
277 parse_links_response(&raw)
278 }
279 .instrument(span)
280 .await
281}
282
283fn parse_links_response(raw: &str) -> Vec<CausalLink> {
284 let json_str = raw
285 .find('[')
286 .and_then(|s| raw[s..].rfind(']').map(|e| &raw[s..=s + e]))
287 .unwrap_or("[]");
288
289 let values: Vec<serde_json::Value> = serde_json::from_str(json_str).unwrap_or_default();
290
291 values
292 .into_iter()
293 .filter_map(|v| {
294 let cause_id = v.get("cause_id")?.as_i64()?;
295 let effect_id = v.get("effect_id")?.as_i64()?;
296 #[allow(clippy::cast_possible_truncation)]
297 let strength = v
298 .get("strength")
299 .and_then(serde_json::Value::as_f64)
300 .map_or(0.5, |s| s.clamp(0.0, 1.0) as f32);
301 if strength < 0.5 {
302 return None;
303 }
304 Some(CausalLink {
305 id: 0,
306 cause_event_id: cause_id,
307 effect_event_id: effect_id,
308 strength,
309 created_at: 0,
310 })
311 })
312 .collect()
313}
314
315pub async fn store_events(
326 store: &SqliteStore,
327 events: &mut [EpisodicEvent],
328) -> Result<(), MemoryError> {
329 if events.is_empty() {
330 return Ok(());
331 }
332 let mut tx = store.pool().begin().await?;
333 for event in events.iter_mut() {
334 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
335 let raw = format!(
336 "INSERT INTO episodic_events (session_id, message_id, event_type, summary, created_at) \
337 VALUES (?, ?, ?, ?, {epoch_now}) \
338 RETURNING id"
339 );
340 let sql = zeph_db::rewrite_placeholders(&raw);
341 let id = sqlx::query_scalar::<_, i64>(&sql)
342 .bind(event.session_id.as_str())
343 .bind(event.message_id.0)
344 .bind(&event.event_type)
345 .bind(&event.summary)
346 .fetch_one(&mut *tx)
347 .await?;
348 event.id = id;
349 }
350 tx.commit().await?;
351 Ok(())
352}
353
354pub async fn store_links(store: &SqliteStore, links: &[CausalLink]) -> Result<(), MemoryError> {
365 if links.is_empty() {
366 return Ok(());
367 }
368 let mut tx = store.pool().begin().await?;
369 for link in links {
370 let epoch_now = <ActiveDialect as zeph_db::dialect::Dialect>::EPOCH_NOW;
371 let raw = format!(
372 "INSERT INTO causal_links \
373 (cause_event_id, effect_event_id, strength, created_at) \
374 VALUES (?, ?, ?, {epoch_now}) \
375 ON CONFLICT (cause_event_id, effect_event_id) DO NOTHING"
376 );
377 let sql = zeph_db::rewrite_placeholders(&raw);
378 sqlx::query(&sql)
379 .bind(link.cause_event_id)
380 .bind(link.effect_event_id)
381 .bind(link.strength)
382 .execute(&mut *tx)
383 .await?;
384 }
385 tx.commit().await?;
386 Ok(())
387}
388
389pub async fn fetch_recent_events(
399 store: &SqliteStore,
400 session_id: &str,
401 limit: usize,
402) -> Result<Vec<EpisodicEvent>, MemoryError> {
403 let rows = sqlx::query_as::<_, (i64, String, i64, String, String, i64)>(
404 "SELECT id, session_id, message_id, event_type, summary, created_at
405 FROM episodic_events
406 WHERE session_id = ?
407 ORDER BY created_at DESC
408 LIMIT ?",
409 )
410 .bind(session_id)
411 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
412 .fetch_all(store.pool())
413 .await?;
414
415 Ok(rows
416 .into_iter()
417 .map(
418 |(id, session_id, message_id, event_type, summary, created_at)| EpisodicEvent {
419 id,
420 session_id: SessionId::new(session_id),
421 message_id: MessageId(message_id),
422 event_type,
423 summary,
424 embedding: None,
425 created_at,
426 },
427 )
428 .collect())
429}
430
431pub async fn recall_episodic_causal(
440 store: &SqliteStore,
441 seed_event_id: i64,
442 session_id: &str,
443 max_depth: u32,
444 config: &EmGraphConfig,
445) -> Result<Vec<EpisodicEvent>, MemoryError> {
446 let span = tracing::debug_span!("memory.em_graph.causal_recall", seed_event_id, max_depth);
447
448 if !config.enabled {
449 return Ok(vec![]);
450 }
451
452 let pool = store.pool().clone();
454 let session_id = session_id.to_owned();
455
456 async move {
457 let mut visited: Vec<i64> = vec![seed_event_id];
458 let mut frontier: Vec<i64> = vec![seed_event_id];
459
460 for depth in 0..max_depth {
461 if frontier.is_empty() || visited.len() >= MAX_CAUSAL_VISITED {
462 break;
463 }
464
465 let frontier_ph = frontier.iter().map(|_| "?").collect::<Vec<_>>().join(",");
466 let visited_ph = visited.iter().map(|_| "?").collect::<Vec<_>>().join(",");
467
468 let query = format!(
469 "SELECT DISTINCT effect_event_id FROM causal_links
470 WHERE cause_event_id IN ({frontier_ph})
471 AND effect_event_id NOT IN ({visited_ph})"
472 );
473
474 let mut q = sqlx::query_scalar::<_, i64>(&query);
475 for &id in &frontier {
476 q = q.bind(id);
477 }
478 for &id in &visited {
479 q = q.bind(id);
480 }
481
482 let next: Vec<i64> = q.fetch_all(&pool).await?;
483
484 tracing::debug!(depth, next_count = next.len(), "em_graph: causal hop");
485 visited.extend_from_slice(&next);
486 frontier = next;
487 }
488
489 if visited.is_empty() {
490 return Ok(vec![]);
491 }
492
493 let placeholders = visited.iter().map(|_| "?").collect::<Vec<_>>().join(",");
495
496 let query = format!(
497 "SELECT id, session_id, message_id, event_type, summary, created_at
498 FROM episodic_events
499 WHERE id IN ({placeholders}) AND session_id = ?
500 ORDER BY created_at ASC"
501 );
502
503 let mut q = sqlx::query_as::<_, (i64, String, i64, String, String, i64)>(&query);
504 for &id in &visited {
505 q = q.bind(id);
506 }
507 q = q.bind(session_id.as_str());
508
509 let rows = q.fetch_all(&pool).await?;
510
511 Ok(rows
512 .into_iter()
513 .map(
514 |(id, session_id, message_id, event_type, summary, created_at)| EpisodicEvent {
515 id,
516 session_id: SessionId::new(session_id),
517 message_id: MessageId(message_id),
518 event_type,
519 summary,
520 embedding: None,
521 created_at,
522 },
523 )
524 .collect())
525 }
526 .instrument(span)
527 .await
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::SqliteStore;
    use zeph_config::providers::ProviderName;

    /// Builds an event that has not been persisted yet (zero id/timestamp, no embedding).
    fn unsaved_event(
        session: &str,
        message_id: MessageId,
        event_type: &str,
        summary: &str,
    ) -> EpisodicEvent {
        EpisodicEvent {
            id: 0,
            session_id: SessionId::new(session),
            message_id,
            event_type: event_type.to_owned(),
            summary: summary.to_owned(),
            embedding: None,
            created_at: 0,
        }
    }

    /// Opens an in-memory store containing one conversation with a single user message.
    async fn store_with_message(text: &str) -> (SqliteStore, MessageId) {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");
        let cid = store.create_conversation().await.expect("conversation");
        let mid = store
            .save_message(cid, "user", text)
            .await
            .expect("save_message");
        (store, mid)
    }

    /// Config with the default provider and a chain depth of 3.
    fn graph_config(enabled: bool) -> EmGraphConfig {
        EmGraphConfig {
            enabled,
            extract_provider: ProviderName::default(),
            max_chain_depth: 3,
        }
    }

    // --- response parsing -------------------------------------------------

    #[test]
    fn parse_events_response_valid_json() {
        let raw = r#"[{"event_type":"decision","summary":"User chose approach A"},{"event_type":"discovery","summary":"Found a bug in module X"}]"#;
        let parsed = parse_events_response(raw, "sess-1", MessageId(42));
        assert_eq!(parsed.len(), 2);

        let (first, second) = (&parsed[0], &parsed[1]);
        assert_eq!(first.event_type, "decision");
        assert_eq!(second.summary, "Found a bug in module X");
        assert_eq!(first.message_id, MessageId(42));
        assert_eq!(first.session_id, "sess-1");
    }

    #[test]
    fn parse_events_response_empty_array() {
        assert!(parse_events_response("[]", "sess-1", MessageId(1)).is_empty());
    }

    #[test]
    fn parse_events_response_malformed_json() {
        assert!(parse_events_response("not json", "sess-1", MessageId(1)).is_empty());
    }

    #[test]
    fn parse_events_response_skips_empty_summary() {
        let parsed = parse_events_response(
            r#"[{"event_type":"decision","summary":""}]"#,
            "sess-1",
            MessageId(1),
        );
        assert!(parsed.is_empty(), "empty summary must be skipped");
    }

    #[test]
    fn parse_links_response_valid_json() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.8}]"#);
        assert_eq!(parsed.len(), 1);

        let only = &parsed[0];
        assert_eq!(only.cause_event_id, 1);
        assert_eq!(only.effect_event_id, 2);
        assert!((only.strength - 0.8).abs() < 0.01);
    }

    #[test]
    fn parse_links_response_filters_weak_links() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.3}]"#);
        assert!(
            parsed.is_empty(),
            "weak links (strength < 0.5) must be filtered"
        );
    }

    #[test]
    fn parse_links_response_empty() {
        assert!(parse_links_response("[]").is_empty());
    }

    #[test]
    fn parse_links_response_strength_at_boundary_included() {
        let parsed = parse_links_response(r#"[{"cause_id":1,"effect_id":2,"strength":0.5}]"#);
        assert_eq!(
            parsed.len(),
            1,
            "strength=0.5 must be included (threshold is strict < 0.5)"
        );
        assert!((parsed[0].strength - 0.5).abs() < 0.001);
    }

    // --- configuration ----------------------------------------------------

    #[test]
    fn em_graph_config_defaults() {
        let defaults = EmGraphConfig::default();
        assert!(!defaults.enabled);
        assert_eq!(defaults.max_chain_depth, 3);
    }

    // --- persistence ------------------------------------------------------

    #[tokio::test]
    async fn store_and_fetch_events_in_memory_db() {
        let (store, mid) = store_with_message("hello world").await;

        let mut pending = vec![unsaved_event(
            "test-session",
            mid,
            "decision",
            "User decided to use approach A",
        )];
        store_events(&store, &mut pending)
            .await
            .expect("store_events");
        assert!(pending[0].id > 0, "id must be assigned after insert");

        let loaded = fetch_recent_events(&store, "test-session", 10)
            .await
            .expect("fetch_recent_events");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].summary, "User decided to use approach A");
    }

    #[tokio::test]
    async fn store_and_recall_causal_chain() {
        let (store, mid) = store_with_message("test").await;

        let mut pair = vec![
            unsaved_event("sess", mid, "discovery", "Found a bug"),
            unsaved_event("sess", mid, "decision", "Decided to fix it"),
        ];
        store_events(&store, &mut pair)
            .await
            .expect("store_events");

        let bug_causes_fix = CausalLink {
            id: 0,
            cause_event_id: pair[0].id,
            effect_event_id: pair[1].id,
            strength: 0.9,
            created_at: 0,
        };
        store_links(&store, &[bug_causes_fix])
            .await
            .expect("store_links");

        let recalled = recall_episodic_causal(&store, pair[0].id, "sess", 3, &graph_config(true))
            .await
            .expect("recall_episodic_causal");

        assert_eq!(
            recalled.len(),
            2,
            "chain must include seed and causally-linked event"
        );
    }

    #[tokio::test]
    async fn recall_episodic_causal_disabled_returns_empty() {
        let store = SqliteStore::new(":memory:")
            .await
            .expect("SqliteStore::new");

        let outcome = recall_episodic_causal(&store, 1, "sess", 3, &graph_config(false)).await;
        assert!(outcome.is_ok());
        assert!(
            outcome.unwrap().is_empty(),
            "disabled config must return empty"
        );
    }

    #[tokio::test]
    async fn store_links_is_idempotent_with_unique_constraint() {
        let (store, mid) = store_with_message("test").await;

        let mut pair = vec![
            unsaved_event("sess", mid, "decision", "A"),
            unsaved_event("sess", mid, "discovery", "B"),
        ];
        store_events(&store, &mut pair)
            .await
            .expect("store_events");

        let (cause, effect) = (pair[0].id, pair[1].id);
        let link = CausalLink {
            id: 0,
            cause_event_id: cause,
            effect_event_id: effect,
            strength: 0.8,
            created_at: 0,
        };

        // Store the very same link twice; the second insert must be a no-op.
        for label in ["first store_links", "second store_links (idempotent)"] {
            store_links(&store, std::slice::from_ref(&link))
                .await
                .expect(label);
        }

        let stored: i64 = sqlx::query_scalar(
            "SELECT COUNT(*) FROM causal_links WHERE cause_event_id = ? AND effect_event_id = ?",
        )
        .bind(cause)
        .bind(effect)
        .fetch_one(store.pool())
        .await
        .expect("count query");

        assert_eq!(
            stored, 1,
            "duplicate causal links must be deduplicated by UNIQUE constraint"
        );
    }
}