// zeph_memory/semantic/tree_consolidation.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `TiMem` temporal-hierarchical memory tree consolidation (#2262).
5//!
6//! Background loop that clusters unconsolidated leaf nodes by cosine similarity and merges
7//! each cluster into a parent node via LLM summarization.
8//!
9//! # Transaction safety (critic S2)
10//!
11//! Each cluster merge runs in its own transaction via `mark_nodes_consolidated`.
12//! The full sweep never holds a write lock across multiple clusters.
13
14use std::sync::Arc;
15use std::time::Duration;
16
17use tokio_util::sync::CancellationToken;
18use zeph_llm::any::AnyProvider;
19use zeph_llm::provider::{LlmProvider as _, Message, Role};
20
21use crate::error::MemoryError;
22use crate::store::SqliteStore;
23use crate::store::memory_tree::MemoryTreeRow;
24use zeph_common::math::cosine_similarity;
25
/// System prompt for the LLM call in `merge_via_llm`: instructs the model to
/// collapse several related memory nodes into a single 2-4 sentence summary,
/// returned as plain text (no JSON wrapper).
const MERGE_SYSTEM_PROMPT: &str = "\
You are a memory consolidation assistant. Given several related memory nodes, produce a single \
concise summary that captures the essential information from all of them. \
Keep it to 2-4 sentences. Do not repeat details already captured in a single sentence. \
Return only the summary text — no JSON, no preamble.";
31
/// Configuration for the tree consolidation loop.
#[derive(Clone)]
pub struct TreeConsolidationConfig {
    /// Master switch; when `false` the background loop exits immediately.
    pub enabled: bool,
    /// Seconds between consolidation sweeps.
    pub sweep_interval_secs: u64,
    /// Maximum number of candidate nodes loaded per level per sweep.
    pub batch_size: usize,
    /// Minimum cosine similarity for two nodes to join the same cluster.
    pub similarity_threshold: f32,
    /// Upper bound on swept levels: the sweep iterates levels `0..max_level`,
    /// writing parents at `level + 1`.
    pub max_level: u32,
    /// Minimum number of nodes required before a cluster is merged.
    pub min_cluster_size: usize,
}
42
/// Result of one consolidation sweep.
#[derive(Debug, Default)]
pub struct TreeConsolidationResult {
    /// Clusters successfully merged during this sweep.
    pub clusters_merged: u32,
    /// Parent nodes created during this sweep (one per merged cluster).
    pub nodes_created: u32,
}
49
50/// Start the background tree consolidation loop.
51///
52/// The loop exits immediately when `config.enabled = false` or `cancel` fires.
53pub async fn start_tree_consolidation_loop(
54    store: Arc<SqliteStore>,
55    provider: AnyProvider,
56    config: TreeConsolidationConfig,
57    cancel: CancellationToken,
58) {
59    if !config.enabled {
60        tracing::debug!("tree consolidation disabled (tree.enabled = false)");
61        return;
62    }
63
64    let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
65    // Skip the first immediate tick to avoid running at startup.
66    ticker.tick().await;
67
68    loop {
69        tokio::select! {
70            () = cancel.cancelled() => {
71                tracing::debug!("tree consolidation loop shutting down");
72                return;
73            }
74            _ = ticker.tick() => {}
75        }
76
77        tracing::debug!("tree consolidation: starting sweep");
78        let start = std::time::Instant::now();
79
80        let result = run_tree_consolidation_sweep(&store, &provider, &config).await;
81        let elapsed_ms = start.elapsed().as_millis();
82
83        match result {
84            Ok(r) => tracing::info!(
85                clusters_merged = r.clusters_merged,
86                nodes_created = r.nodes_created,
87                elapsed_ms,
88                "tree consolidation: sweep complete"
89            ),
90            Err(e) => tracing::warn!(
91                error = %e,
92                elapsed_ms,
93                "tree consolidation: sweep failed, will retry"
94            ),
95        }
96    }
97}
98
99/// Execute one full consolidation sweep: leaves → level 1, then level 1 → level 2, etc.
100///
101/// Each cluster runs inside its own transaction (critic S2).
102///
103/// # Errors
104///
105/// Returns an error if a database query fails.
106pub async fn run_tree_consolidation_sweep(
107    store: &SqliteStore,
108    provider: &AnyProvider,
109    config: &TreeConsolidationConfig,
110) -> Result<TreeConsolidationResult, MemoryError> {
111    let mut result = TreeConsolidationResult::default();
112
113    for level in 0..config.max_level {
114        let candidates = if level == 0 {
115            store
116                .load_tree_leaves_unconsolidated(config.batch_size)
117                .await?
118        } else {
119            store
120                .load_tree_level(i64::from(level), config.batch_size)
121                .await?
122        };
123
124        if candidates.len() < config.min_cluster_size {
125            continue;
126        }
127
128        if !provider.supports_embeddings() {
129            tracing::debug!(
130                "tree consolidation: provider has no embedding support, skipping level {level}"
131            );
132            continue;
133        }
134
135        let embedded = embed_candidates(provider, &candidates).await;
136        if embedded.len() < config.min_cluster_size {
137            continue;
138        }
139
140        let clusters = cluster_by_similarity(
141            &embedded,
142            config.similarity_threshold,
143            config.min_cluster_size,
144        );
145
146        for cluster in clusters {
147            if cluster.len() < config.min_cluster_size {
148                continue;
149            }
150
151            let child_ids: Vec<i64> = cluster.iter().map(|(id, _, _)| *id).collect();
152            let contents: Vec<&str> = cluster
153                .iter()
154                .map(|(_, content, _)| content.as_str())
155                .collect();
156
157            let summary = match merge_via_llm(provider, &contents).await {
158                Ok(s) => s,
159                Err(e) => {
160                    tracing::warn!(
161                        error = %e,
162                        level,
163                        child_count = cluster.len(),
164                        "tree consolidation: LLM merge failed, skipping cluster"
165                    );
166                    continue;
167                }
168            };
169
170            if summary.is_empty() {
171                continue;
172            }
173
174            let token_count = i64::try_from(summary.split_whitespace().count()).unwrap_or(i64::MAX);
175            let source_ids_json =
176                serde_json::to_string(&child_ids).unwrap_or_else(|_| "[]".to_string());
177
178            // Atomic cluster consolidation: INSERT parent + UPDATE children in one transaction.
179            match store
180                .consolidate_cluster(
181                    i64::from(level + 1),
182                    &summary,
183                    &source_ids_json,
184                    token_count,
185                    &child_ids,
186                )
187                .await
188            {
189                Ok(_) => {}
190                Err(e) => {
191                    tracing::warn!(
192                        error = %e,
193                        level,
194                        child_count = cluster.len(),
195                        "tree consolidation: cluster persist failed, skipping"
196                    );
197                    continue;
198                }
199            }
200
201            result.clusters_merged += 1;
202            result.nodes_created += 1;
203        }
204    }
205
206    if result.nodes_created > 0 {
207        let _ = store.increment_tree_consolidation_count().await;
208    }
209
210    Ok(result)
211}
212
213/// Concurrency cap for embed calls — matches `embed_concurrency` default (#2677).
214const EMBED_CONCURRENCY: usize = 8;
215
216async fn embed_candidates(
217    provider: &AnyProvider,
218    candidates: &[MemoryTreeRow],
219) -> Vec<(i64, String, Vec<f32>)> {
220    let mut embedded = Vec::with_capacity(candidates.len());
221
222    // Process in bounded batches to avoid saturating the embed provider (#2677).
223    for chunk in candidates.chunks(EMBED_CONCURRENCY) {
224        let futures: Vec<_> = chunk
225            .iter()
226            .map(|row| {
227                let id = row.id;
228                let content = row.content.clone();
229                async move { (id, content.clone(), provider.embed(&content).await) }
230            })
231            .collect();
232
233        let results = futures::future::join_all(futures).await;
234        for (id, content, result) in results {
235            match result {
236                Ok(vec) => embedded.push((id, content, vec)),
237                Err(e) => tracing::warn!(
238                    node_id = id,
239                    error = %e,
240                    "tree consolidation: failed to embed node, skipping"
241                ),
242            }
243        }
244    }
245    embedded
246}
247
248// INVARIANT: `embedded` must be ordered by `created_at ASC` (as returned by
249// `load_tree_leaves_unconsolidated` / `load_tree_level`).  The greedy leader-based algorithm
250// is deterministic only when the input order is stable across sweeps.  Do not sort or shuffle
251// the slice before calling this function.
252fn cluster_by_similarity(
253    embedded: &[(i64, String, Vec<f32>)],
254    threshold: f32,
255    min_cluster_size: usize,
256) -> Vec<Vec<(i64, String, Vec<f32>)>> {
257    let n = embedded.len();
258    let mut assigned = vec![false; n];
259    let mut clusters: Vec<Vec<(i64, String, Vec<f32>)>> = Vec::new();
260
261    for i in 0..n {
262        if assigned[i] {
263            continue;
264        }
265        let mut cluster = vec![embedded[i].clone()];
266        assigned[i] = true;
267
268        for j in (i + 1)..n {
269            if assigned[j] {
270                continue;
271            }
272            let sim = cosine_similarity(&embedded[i].2, &embedded[j].2);
273            if sim >= threshold {
274                cluster.push(embedded[j].clone());
275                assigned[j] = true;
276            }
277        }
278
279        if cluster.len() >= min_cluster_size {
280            clusters.push(cluster);
281        }
282    }
283
284    clusters
285}
286
287async fn merge_via_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
288    let mut user_prompt = String::from("Memory nodes to consolidate:\n");
289    for (i, content) in contents.iter().enumerate() {
290        use std::fmt::Write as _;
291        let _ = writeln!(user_prompt, "[{}] {}", i + 1, content);
292    }
293    user_prompt.push_str("\nProduce a concise summary.");
294
295    let llm_messages = [
296        Message::from_legacy(Role::System, MERGE_SYSTEM_PROMPT),
297        Message::from_legacy(Role::User, user_prompt),
298    ];
299
300    let response = provider
301        .chat(&llm_messages)
302        .await
303        .map_err(MemoryError::Llm)?;
304
305    Ok(response.trim().to_string())
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `(id, label, vector)` triple for cluster fixtures.
    fn node(id: i64, label: &str, vector: Vec<f32>) -> (i64, String, Vec<f32>) {
        (id, label.to_string(), vector)
    }

    #[test]
    fn cluster_by_similarity_groups_identical_vectors() {
        let embedded = vec![
            node(1, "a", vec![1.0f32, 0.0, 0.0]),
            node(2, "b", vec![1.0f32, 0.0, 0.0]),
            node(3, "c", vec![0.0f32, 1.0, 0.0]), // orthogonal to the first two
        ];

        let clusters = cluster_by_similarity(&embedded, 0.9, 2);
        assert_eq!(
            clusters.len(),
            1,
            "identical vectors should form one cluster"
        );
        assert_eq!(clusters[0].len(), 2);
    }

    #[test]
    fn cluster_by_similarity_min_cluster_size_gate() {
        let embedded = vec![node(1, "a", vec![1.0f32, 0.0]), node(2, "b", vec![1.0f32, 0.0])];

        // Require 3 — no cluster should form.
        let clusters = cluster_by_similarity(&embedded, 0.9, 3);
        assert!(clusters.is_empty());
    }

    #[test]
    fn cluster_by_similarity_no_duplicates_across_clusters() {
        let embedded: Vec<_> = (1..=3)
            .map(|id| node(id, ["a", "b", "c"][(id - 1) as usize], vec![1.0f32, 0.0]))
            .collect();

        let clusters = cluster_by_similarity(&embedded, 0.9, 2);
        // All items across all clusters are unique (no double-assignment).
        let total_items: usize = clusters.iter().map(Vec::len).sum();
        assert_eq!(total_items, 3);
    }
}
359}