1use std::sync::Arc;
23use std::time::Duration;
24
25use serde::{Deserialize, Serialize};
26use tokio_util::sync::CancellationToken;
27use zeph_common::config::memory::HebbianConsolidationConfig;
28use zeph_db::sql;
29use zeph_llm::any::AnyProvider;
30use zeph_llm::provider::{LlmProvider as _, Message, Role};
31
32use crate::error::MemoryError;
33use crate::store::SqliteStore;
34
35struct ClearStatusOnDrop(Option<tokio::sync::mpsc::UnboundedSender<String>>);
41
42impl Drop for ClearStatusOnDrop {
43 fn drop(&mut self) {
44 if let Some(ref tx) = self.0 {
45 let _ = tx.send(String::new());
46 }
47 }
48}
49
/// Result of distilling an entity cluster with the LLM.
///
/// Deserialized directly from the JSON object the model is instructed to
/// return in [`distill_cluster`] (`{"summary": …, "trigger_hint": …,
/// "confidence": …}`), and persisted into `graph_rules` by
/// [`insert_graph_rule_and_mark`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HebbianConsolidationOutcome {
    /// Distilled strategy or pattern describing the cluster.
    pub summary: String,
    /// Optional short retrieval phrase; `None` when the model returns null.
    pub trigger_hint: Option<String>,
    /// Model-reported confidence, expected in [0.0, 1.0] per the prompt
    /// schema (not re-validated here).
    pub confidence: f64,
}
62
/// An entity whose active-edge neighbourhood scored above the consolidation
/// threshold, as computed by [`find_candidates`].
#[derive(Debug, Clone)]
pub struct HebbianConsolidationCandidate {
    /// Row id in `graph_entities`.
    pub entity_id: i64,
    /// Number of active (valid_to IS NULL) edges touching the entity.
    pub degree: u64,
    /// Mean weight over those active edges.
    pub avg_weight: f64,
    /// Hebbian score: `degree * avg_weight` (see the SQL in `find_candidates`).
    pub score: f64,
}
75
/// A consolidated rule persisted in the `graph_rules` table.
///
/// One row is written per successful consolidation of an anchor entity
/// (see [`insert_graph_rule_and_mark`]). NOTE(review): read paths for this
/// type are not visible in this file — presumably used by retrieval; confirm
/// against callers.
#[derive(Debug, Clone)]
pub struct GraphRule {
    /// Primary key in `graph_rules`.
    pub id: i64,
    /// The entity this rule was distilled around.
    pub anchor_entity_id: i64,
    /// LLM-distilled summary of the cluster.
    pub summary: String,
    /// Optional retrieval hint produced alongside the summary.
    pub trigger_hint: Option<String>,
    /// Model-reported confidence for the rule.
    pub confidence: f64,
    /// Unix timestamp (seconds) of insertion.
    pub created_at: i64,
}
92
93pub async fn find_candidates(
105 pool: &zeph_db::DbPool,
106 threshold: f64,
107 cooldown_before: i64,
108 limit: usize,
109) -> Result<Vec<HebbianConsolidationCandidate>, MemoryError> {
110 let rows: Vec<(i64, i64, f64, f64)> = zeph_db::query_as(sql!(
112 "SELECT e.id,
113 COUNT(ed.id) AS degree,
114 AVG(ed.weight) AS avg_weight,
115 COUNT(ed.id) * AVG(ed.weight) AS score
116 FROM graph_entities e
117 JOIN graph_edges ed
118 ON (ed.source_entity_id = e.id OR ed.target_entity_id = e.id)
119 AND ed.valid_to IS NULL
120 WHERE (e.consolidated_at IS NULL OR e.consolidated_at < ?)
121 GROUP BY e.id
122 HAVING score > ?
123 ORDER BY score DESC
124 LIMIT ?"
125 ))
126 .bind(cooldown_before)
127 .bind(threshold)
128 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
129 .fetch_all(pool)
130 .await?;
131
132 Ok(rows
133 .into_iter()
134 .map(
135 |(entity_id, degree, avg_weight, score)| HebbianConsolidationCandidate {
136 entity_id,
137 degree: u64::try_from(degree).unwrap_or(0),
139 avg_weight,
140 score,
141 },
142 )
143 .collect())
144}
145
146pub async fn collect_neighbors(
155 pool: &zeph_db::DbPool,
156 entity_id: i64,
157 max_neighbors: usize,
158) -> Result<Vec<String>, MemoryError> {
159 let query_fut = zeph_db::query_as(sql!(
161 "SELECT DISTINCT e.summary
162 FROM graph_entities e
163 JOIN graph_edges ed
164 ON (ed.source_entity_id = ? AND ed.target_entity_id = e.id)
165 OR (ed.target_entity_id = ? AND ed.source_entity_id = e.id)
166 WHERE ed.valid_to IS NULL
167 AND e.summary IS NOT NULL
168 LIMIT ?"
169 ))
170 .bind(entity_id)
171 .bind(entity_id)
172 .bind(i64::try_from(max_neighbors).unwrap_or(i64::MAX))
173 .fetch_all(pool);
174
175 let rows: Vec<(Option<String>,)> = tokio::time::timeout(Duration::from_secs(10), query_fut)
176 .await
177 .map_err(|_| {
178 tracing::warn!(
179 entity_id,
180 "hebbian_consolidation: collect_neighbors timed out after 10s"
181 );
182 MemoryError::Timeout("collect_neighbors".into())
183 })??;
184
185 Ok(rows.into_iter().filter_map(|(s,)| s).collect())
186}
187
188pub async fn distill_cluster(
193 provider: &AnyProvider,
194 neighbors: &[String],
195 timeout_secs: u64,
196) -> Option<HebbianConsolidationOutcome> {
197 if neighbors.is_empty() {
198 return None;
199 }
200
201 let cluster_text = neighbors
202 .iter()
203 .enumerate()
204 .map(|(i, s)| format!(" [{}] {s}", i + 1))
205 .collect::<Vec<_>>()
206 .join("\n");
207
208 let system = "You are a memory strategy analyst. \
209 Given a cluster of related entity summaries from an agent's knowledge graph, \
210 produce a single JSON object with this exact schema:\n\
211 {\"summary\":\"<distilled strategy or pattern>\",\
212 \"trigger_hint\":\"<short retrieval phrase, or null>\",\
213 \"confidence\":<0.0-1.0>}\n\
214 Return ONLY the JSON object — no markdown, no explanation.";
215
216 let user = format!("Entity cluster:\n{cluster_text}");
217
218 let messages = vec![
219 Message::from_legacy(Role::System, system),
220 Message::from_legacy(Role::User, &user),
221 ];
222
223 let chat_future = provider.chat(&messages);
224 let text = match tokio::time::timeout(Duration::from_secs(timeout_secs), chat_future).await {
225 Ok(Ok(t)) => t,
226 Ok(Err(e)) => {
227 tracing::warn!(error = %e, "hebbian_consolidation: LLM call failed");
228 return None;
229 }
230 Err(_) => {
231 tracing::warn!(timeout_secs, "hebbian_consolidation: LLM call timed out");
232 return None;
233 }
234 };
235
236 let start = text.find('{')?;
237 let end = text.rfind('}')?;
238 let json_slice = &text[start..=end];
239
240 match serde_json::from_str::<HebbianConsolidationOutcome>(json_slice) {
241 Ok(outcome) => Some(outcome),
242 Err(e) => {
243 tracing::debug!(
244 error = %e,
245 response = %json_slice,
246 "hebbian_consolidation: failed to parse LLM response"
247 );
248 None
249 }
250 }
251}
252
253pub async fn insert_graph_rule_and_mark(
263 pool: &zeph_db::DbPool,
264 anchor_id: i64,
265 outcome: &HebbianConsolidationOutcome,
266) -> Result<(), MemoryError> {
267 let now = chrono::Utc::now().timestamp();
268
269 let begin_fut = pool.begin();
270 let mut tx = tokio::time::timeout(Duration::from_secs(10), begin_fut)
271 .await
272 .map_err(|_| {
273 tracing::warn!(
274 anchor_id,
275 "hebbian_consolidation: begin transaction timed out after 10s"
276 );
277 MemoryError::Timeout("insert_graph_rule_and_mark: begin".into())
278 })??;
279
280 let insert_fut = zeph_db::query(sql!(
281 "INSERT INTO graph_rules (anchor_entity_id, summary, trigger_hint, confidence, created_at)
282 VALUES (?, ?, ?, ?, ?)"
283 ))
284 .bind(anchor_id)
285 .bind(&outcome.summary)
286 .bind(outcome.trigger_hint.as_deref())
287 .bind(outcome.confidence)
288 .bind(now)
289 .execute(&mut *tx);
290
291 tokio::time::timeout(Duration::from_secs(10), insert_fut)
292 .await
293 .map_err(|_| {
294 tracing::warn!(
295 anchor_id,
296 "hebbian_consolidation: INSERT graph_rules timed out after 10s"
297 );
298 MemoryError::Timeout("insert_graph_rule_and_mark: insert".into())
299 })??;
300
301 let update_fut = zeph_db::query(sql!(
302 "UPDATE graph_entities SET consolidated_at = ? WHERE id = ?"
303 ))
304 .bind(now)
305 .bind(anchor_id)
306 .execute(&mut *tx);
307
308 tokio::time::timeout(Duration::from_secs(10), update_fut)
309 .await
310 .map_err(|_| {
311 tracing::warn!(
312 anchor_id,
313 "hebbian_consolidation: UPDATE graph_entities timed out after 10s"
314 );
315 MemoryError::Timeout("insert_graph_rule_and_mark: update".into())
316 })??;
317
318 tx.commit().await?;
319 Ok(())
320}
321
322#[tracing::instrument(skip_all)]
332pub async fn run_consolidation_sweep(
333 store: &SqliteStore,
334 config: &HebbianConsolidationConfig,
335 provider: &AnyProvider,
336 status_tx: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
337 cancel: &CancellationToken,
338) -> Result<u32, MemoryError> {
339 let _clear_status = ClearStatusOnDrop(status_tx.cloned());
341
342 if let Some(tx) = status_tx {
343 let _ = tx.send("Consolidating memory clusters\u{2026}".to_owned());
344 }
345
346 let now = chrono::Utc::now().timestamp();
347 let cooldown_secs = i64::try_from(config.consolidation_cooldown_secs).unwrap_or(i64::MAX);
348 let cooldown_before = now.saturating_sub(cooldown_secs);
349
350 let candidates = find_candidates(
351 store.pool(),
352 config.consolidation_threshold,
353 cooldown_before,
354 config.max_candidates_per_sweep,
355 )
356 .await?;
357
358 let mut consolidated = 0u32;
359
360 use tracing::Instrument as _;
362
363 for candidate in &candidates {
364 if cancel.is_cancelled() {
365 tracing::debug!("hebbian consolidation sweep cancelled mid-sweep");
366 break;
367 }
368
369 let neighbors = {
370 match collect_neighbors(
371 store.pool(),
372 candidate.entity_id,
373 config.consolidation_max_neighbors,
374 )
375 .instrument(tracing::debug_span!("memory.hebbian.collect_neighbors"))
376 .await
377 {
378 Ok(n) => n,
379 Err(e) => {
380 tracing::warn!(
381 entity_id = candidate.entity_id,
382 error = %e,
383 "hebbian_consolidation: failed to collect neighbours, skipping"
384 );
385 continue;
386 }
387 }
388 };
389
390 if neighbors.is_empty() {
391 tracing::debug!(
392 entity_id = candidate.entity_id,
393 "hebbian_consolidation: no summaries in neighbourhood, skipping"
394 );
395 continue;
396 }
397
398 let outcome = {
399 distill_cluster(
400 provider,
401 &neighbors,
402 config.consolidation_prompt_timeout_secs,
403 )
404 .instrument(tracing::debug_span!("memory.hebbian.distill"))
405 .await
406 };
407
408 let Some(outcome) = outcome else {
409 tracing::debug!(
410 entity_id = candidate.entity_id,
411 "hebbian_consolidation: LLM returned no outcome, skipping"
412 );
413 continue;
414 };
415
416 let insert_result = {
417 insert_graph_rule_and_mark(store.pool(), candidate.entity_id, &outcome)
418 .instrument(tracing::debug_span!("memory.hebbian.insert"))
419 .await
420 };
421
422 match insert_result {
423 Ok(()) => {
424 consolidated += 1;
425 tracing::info!(
426 entity_id = candidate.entity_id,
427 score = candidate.score,
428 confidence = outcome.confidence,
429 "hebbian_consolidation: rule inserted"
430 );
431 }
432 Err(e) => {
433 tracing::warn!(
434 entity_id = candidate.entity_id,
435 error = %e,
436 "hebbian_consolidation: failed to insert rule"
437 );
438 }
439 }
440 }
441
442 Ok(consolidated)
443}
444
445pub async fn spawn_consolidation_loop(
451 store: Arc<SqliteStore>,
452 config: HebbianConsolidationConfig,
453 provider: AnyProvider,
454 status_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
455 cancel: CancellationToken,
456) {
457 if config.consolidation_interval_secs == 0 {
458 tracing::debug!("hebbian_consolidation: loop disabled (consolidation_interval_secs = 0)");
459 return;
460 }
461
462 let mut ticker = tokio::time::interval(Duration::from_secs(config.consolidation_interval_secs));
463 ticker.tick().await;
465
466 loop {
467 tokio::select! {
468 () = cancel.cancelled() => {
469 tracing::debug!("hebbian_consolidation: loop shutting down");
470 return;
471 }
472 _ = ticker.tick() => {}
473 }
474
475 let start = std::time::Instant::now();
476 tracing::debug!("hebbian_consolidation: starting sweep");
477
478 match run_consolidation_sweep(&store, &config, &provider, status_tx.as_ref(), &cancel).await
479 {
480 Ok(n) => {
481 tracing::info!(
482 consolidated = n,
483 elapsed_ms = start.elapsed().as_millis(),
484 "hebbian_consolidation: sweep complete"
485 );
486 }
487 Err(e) => {
488 tracing::warn!(
489 error = %e,
490 elapsed_ms = start.elapsed().as_millis(),
491 "hebbian_consolidation: sweep failed, will retry"
492 );
493 }
494 }
495 }
496}
497
#[cfg(test)]
mod tests {
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use super::*;
    use crate::store::SqliteStore;

    /// Fresh in-memory store so each test gets an isolated schema.
    async fn make_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Seeds an anchor entity plus `edge_count` sink entities, each connected
    /// to the anchor by an active edge of the given `weight`, and returns the
    /// anchor's id.
    ///
    /// The anchor's Hebbian score is `edge_count * weight`, letting tests
    /// position it above or below a consolidation threshold.
    async fn seed_entity_with_edges(
        store: &SqliteStore,
        name: &str,
        edge_count: usize,
        weight: f64,
    ) -> i64 {
        let entity_id: i64 = zeph_db::query_scalar(sql!(
            "INSERT INTO graph_entities (name, canonical_name, entity_type)
             VALUES (?, ?, 'concept')
             RETURNING id"
        ))
        .bind(name)
        .bind(name.to_lowercase())
        .fetch_one(store.pool())
        .await
        .unwrap();

        for i in 0..edge_count {
            let target_name = format!("{name}_sink_{i}");
            // Consistency fix: these two statements previously bypassed the
            // sql! macro that every other query in this module goes through.
            let target_id: i64 = zeph_db::query_scalar(sql!(
                "INSERT INTO graph_entities (name, canonical_name, entity_type)
                 VALUES (?, ?, 'concept')
                 RETURNING id"
            ))
            .bind(&target_name)
            .bind(&target_name)
            .fetch_one(store.pool())
            .await
            .unwrap();

            zeph_db::query(sql!(
                "INSERT INTO graph_edges
                 (source_entity_id, target_entity_id, relation, fact, confidence, weight)
                 VALUES (?, ?, 'related', 'test fact', 1.0, ?)"
            ))
            .bind(entity_id)
            .bind(target_id)
            .bind(weight)
            .execute(store.pool())
            .await
            .unwrap();
        }

        entity_id
    }

    #[tokio::test]
    async fn test_find_candidates_empty_db() {
        let store = make_store().await;
        let candidates = find_candidates(store.pool(), 5.0, 0, 10).await.unwrap();
        assert!(candidates.is_empty(), "empty DB must return no candidates");
    }

    #[tokio::test]
    async fn test_find_candidates_below_threshold() {
        let store = make_store().await;
        // Score = 1 edge * weight 1.0 = 1.0, below the 5.0 threshold.
        seed_entity_with_edges(&store, "low", 1, 1.0).await;
        let candidates = find_candidates(store.pool(), 5.0, 0, 10).await.unwrap();
        assert!(
            candidates.is_empty(),
            "entity below threshold must not be returned"
        );
    }

    #[tokio::test]
    async fn test_find_candidates_above_threshold() {
        let store = make_store().await;
        // Score = 3 edges * weight 2.0 = 6.0 > 5.0.
        let entity_id = seed_entity_with_edges(&store, "hot", 3, 2.0).await;
        let candidates = find_candidates(store.pool(), 5.0, 0, 10).await.unwrap();
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].entity_id, entity_id);
        assert!(candidates[0].score > 5.0);
    }

    #[tokio::test]
    async fn test_cooldown_respected() {
        let store = make_store().await;
        let entity_id = seed_entity_with_edges(&store, "hot", 3, 2.0).await;

        // Mark the entity as consolidated just now…
        let now = chrono::Utc::now().timestamp();
        zeph_db::query(sql!(
            "UPDATE graph_entities SET consolidated_at = ? WHERE id = ?"
        ))
        .bind(now)
        .bind(entity_id)
        .execute(store.pool())
        .await
        .unwrap();

        // …then query with a cutoff a day in the past: `now` is not < cutoff.
        let cooldown_before = now - 86_400;
        let candidates = find_candidates(store.pool(), 5.0, cooldown_before, 10)
            .await
            .unwrap();
        assert!(
            candidates.is_empty(),
            "entity within cooldown window must be skipped"
        );
    }

    #[tokio::test]
    async fn test_distill_cluster_parse_failure() {
        let mock = MockProvider::with_responses(vec!["not valid json at all".to_owned()]);
        let provider = AnyProvider::Mock(mock);
        let neighbors = vec!["Entity A uses Rust".to_owned()];
        let result = distill_cluster(&provider, &neighbors, 30).await;
        assert!(
            result.is_none(),
            "unparseable LLM response must return None"
        );
    }

    #[tokio::test]
    async fn test_insert_graph_rule_marks_consolidated_at() {
        let store = make_store().await;
        let entity_id = seed_entity_with_edges(&store, "anchor", 3, 2.0).await;

        let outcome = HebbianConsolidationOutcome {
            summary: "Agent frequently uses Rust for systems programming".to_owned(),
            trigger_hint: Some("Rust systems".to_owned()),
            confidence: 0.9,
        };

        insert_graph_rule_and_mark(store.pool(), entity_id, &outcome)
            .await
            .unwrap();

        // One rule row must exist for the anchor…
        let rule_count: (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM graph_rules WHERE anchor_entity_id = ?"
        ))
        .bind(entity_id)
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(rule_count.0, 1, "one rule must be inserted");

        // …and the cooldown stamp must be set in the same transaction.
        let ts: (Option<i64>,) = zeph_db::query_as(sql!(
            "SELECT consolidated_at FROM graph_entities WHERE id = ?"
        ))
        .bind(entity_id)
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert!(
            ts.0.is_some(),
            "consolidated_at must be set after insert_graph_rule_and_mark"
        );
    }

    #[tokio::test]
    async fn test_enabled_false_skips_sweep() {
        let store = Arc::new(make_store().await);
        // Seed a candidate that WOULD be consolidated if the loop ran.
        seed_entity_with_edges(&store, "hot", 3, 2.0).await;

        // interval = 0 means "disabled": the loop must return immediately.
        let config = HebbianConsolidationConfig {
            consolidation_interval_secs: 0,
            ..HebbianConsolidationConfig::default()
        };

        let mock = MockProvider::default();
        let provider = AnyProvider::Mock(mock);

        let cancel = CancellationToken::new();
        let handle = tokio::spawn(spawn_consolidation_loop(
            store.clone(),
            config,
            provider,
            None,
            cancel.clone(),
        ));
        tokio::time::timeout(Duration::from_millis(100), handle)
            .await
            .expect("loop must exit immediately when interval=0")
            .unwrap();

        let count: (i64,) = zeph_db::query_as(sql!("SELECT COUNT(*) FROM graph_rules"))
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert_eq!(
            count.0, 0,
            "no rules must be inserted when loop is disabled"
        );
    }

    #[tokio::test]
    async fn test_sweep_cancelled_mid_loop() {
        let store = Arc::new(make_store().await);
        seed_entity_with_edges(&store, "hot1", 3, 2.0).await;
        seed_entity_with_edges(&store, "hot2", 4, 2.0).await;

        let config = HebbianConsolidationConfig {
            consolidation_threshold: 5.0,
            max_candidates_per_sweep: 10,
            ..HebbianConsolidationConfig::default()
        };

        // Cancel up front: the sweep must bail out before the first candidate.
        let cancel = CancellationToken::new();
        cancel.cancel();

        let mock = MockProvider::default();
        let provider = AnyProvider::Mock(mock);
        let result = run_consolidation_sweep(&store, &config, &provider, None, &cancel).await;

        assert!(result.is_ok(), "cancelled sweep must not return error");
        assert_eq!(result.unwrap(), 0, "cancelled sweep must insert zero rules");
    }
}
738}