1use std::sync::Arc;
15use std::time::Duration;
16
17use tokio_util::sync::CancellationToken;
18use zeph_llm::any::AnyProvider;
19use zeph_llm::provider::{LlmProvider as _, Message, Role};
20
21use crate::error::MemoryError;
22use crate::store::SqliteStore;
23use crate::store::memory_tree::MemoryTreeRow;
24use zeph_common::math::cosine_similarity;
25
/// System prompt sent to the LLM when merging a cluster of related memory
/// nodes; instructs it to return plain summary text only (no JSON/preamble).
const MERGE_SYSTEM_PROMPT: &str = "\
You are a memory consolidation assistant. Given several related memory nodes, produce a single \
concise summary that captures the essential information from all of them. \
Keep it to 2-4 sentences. Do not repeat details already captured in a single sentence. \
Return only the summary text — no JSON, no preamble.";
31
/// Configuration for the background memory-tree consolidation task.
#[derive(Clone)]
pub struct TreeConsolidationConfig {
    /// Master switch; when false the consolidation loop exits immediately.
    pub enabled: bool,
    /// Seconds between consolidation sweeps.
    pub sweep_interval_secs: u64,
    /// Maximum number of candidate nodes loaded per tree level per sweep.
    pub batch_size: usize,
    /// Minimum cosine similarity for two nodes to join the same cluster.
    pub similarity_threshold: f32,
    /// Levels `0..max_level` are swept; merged summaries land one level up.
    pub max_level: u32,
    /// Clusters smaller than this are discarded rather than merged.
    pub min_cluster_size: usize,
}
42
/// Counters reported after a single consolidation sweep.
#[derive(Debug, Default)]
pub struct TreeConsolidationResult {
    /// Number of clusters successfully merged during the sweep.
    pub clusters_merged: u32,
    /// Number of new summary nodes persisted (one per merged cluster).
    pub nodes_created: u32,
}
49
50pub async fn start_tree_consolidation_loop(
54 store: Arc<SqliteStore>,
55 provider: AnyProvider,
56 config: TreeConsolidationConfig,
57 cancel: CancellationToken,
58) {
59 if !config.enabled {
60 tracing::debug!("tree consolidation disabled (tree.enabled = false)");
61 return;
62 }
63
64 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
65 ticker.tick().await;
67
68 loop {
69 tokio::select! {
70 () = cancel.cancelled() => {
71 tracing::debug!("tree consolidation loop shutting down");
72 return;
73 }
74 _ = ticker.tick() => {}
75 }
76
77 tracing::debug!("tree consolidation: starting sweep");
78 let start = std::time::Instant::now();
79
80 let result = run_tree_consolidation_sweep(&store, &provider, &config).await;
81 let elapsed_ms = start.elapsed().as_millis();
82
83 match result {
84 Ok(r) => tracing::info!(
85 clusters_merged = r.clusters_merged,
86 nodes_created = r.nodes_created,
87 elapsed_ms,
88 "tree consolidation: sweep complete"
89 ),
90 Err(e) => tracing::warn!(
91 error = %e,
92 elapsed_ms,
93 "tree consolidation: sweep failed, will retry"
94 ),
95 }
96 }
97}
98
99pub async fn run_tree_consolidation_sweep(
107 store: &SqliteStore,
108 provider: &AnyProvider,
109 config: &TreeConsolidationConfig,
110) -> Result<TreeConsolidationResult, MemoryError> {
111 let mut result = TreeConsolidationResult::default();
112
113 for level in 0..config.max_level {
114 let candidates = if level == 0 {
115 store
116 .load_tree_leaves_unconsolidated(config.batch_size)
117 .await?
118 } else {
119 store
120 .load_tree_level(i64::from(level), config.batch_size)
121 .await?
122 };
123
124 if candidates.len() < config.min_cluster_size {
125 continue;
126 }
127
128 if !provider.supports_embeddings() {
129 tracing::debug!(
130 "tree consolidation: provider has no embedding support, skipping level {level}"
131 );
132 continue;
133 }
134
135 let embedded = embed_candidates(provider, &candidates).await;
136 if embedded.len() < config.min_cluster_size {
137 continue;
138 }
139
140 let clusters = cluster_by_similarity(
141 &embedded,
142 config.similarity_threshold,
143 config.min_cluster_size,
144 );
145
146 for cluster in clusters {
147 if cluster.len() < config.min_cluster_size {
148 continue;
149 }
150
151 let child_ids: Vec<i64> = cluster.iter().map(|(id, _, _)| *id).collect();
152 let contents: Vec<&str> = cluster
153 .iter()
154 .map(|(_, content, _)| content.as_str())
155 .collect();
156
157 let summary = match merge_via_llm(provider, &contents).await {
158 Ok(s) => s,
159 Err(e) => {
160 tracing::warn!(
161 error = %e,
162 level,
163 child_count = cluster.len(),
164 "tree consolidation: LLM merge failed, skipping cluster"
165 );
166 continue;
167 }
168 };
169
170 if summary.is_empty() {
171 continue;
172 }
173
174 let token_count = i64::try_from(summary.split_whitespace().count()).unwrap_or(i64::MAX);
175 let source_ids_json =
176 serde_json::to_string(&child_ids).unwrap_or_else(|_| "[]".to_string());
177
178 match store
180 .consolidate_cluster(
181 i64::from(level + 1),
182 &summary,
183 &source_ids_json,
184 token_count,
185 &child_ids,
186 )
187 .await
188 {
189 Ok(_) => {}
190 Err(e) => {
191 tracing::warn!(
192 error = %e,
193 level,
194 child_count = cluster.len(),
195 "tree consolidation: cluster persist failed, skipping"
196 );
197 continue;
198 }
199 }
200
201 result.clusters_merged += 1;
202 result.nodes_created += 1;
203 }
204 }
205
206 if result.nodes_created > 0 {
207 let _ = store.increment_tree_consolidation_count().await;
208 }
209
210 Ok(result)
211}
212
/// Number of embedding requests issued concurrently per chunk of candidates.
const EMBED_CONCURRENCY: usize = 8;
215
216async fn embed_candidates(
217 provider: &AnyProvider,
218 candidates: &[MemoryTreeRow],
219) -> Vec<(i64, String, Vec<f32>)> {
220 let mut embedded = Vec::with_capacity(candidates.len());
221
222 for chunk in candidates.chunks(EMBED_CONCURRENCY) {
224 let futures: Vec<_> = chunk
225 .iter()
226 .map(|row| {
227 let id = row.id;
228 let content = row.content.clone();
229 async move { (id, content.clone(), provider.embed(&content).await) }
230 })
231 .collect();
232
233 let results = futures::future::join_all(futures).await;
234 for (id, content, result) in results {
235 match result {
236 Ok(vec) => embedded.push((id, content, vec)),
237 Err(e) => tracing::warn!(
238 node_id = id,
239 error = %e,
240 "tree consolidation: failed to embed node, skipping"
241 ),
242 }
243 }
244 }
245 embedded
246}
247
248fn cluster_by_similarity(
253 embedded: &[(i64, String, Vec<f32>)],
254 threshold: f32,
255 min_cluster_size: usize,
256) -> Vec<Vec<(i64, String, Vec<f32>)>> {
257 let n = embedded.len();
258 let mut assigned = vec![false; n];
259 let mut clusters: Vec<Vec<(i64, String, Vec<f32>)>> = Vec::new();
260
261 for i in 0..n {
262 if assigned[i] {
263 continue;
264 }
265 let mut cluster = vec![embedded[i].clone()];
266 assigned[i] = true;
267
268 for j in (i + 1)..n {
269 if assigned[j] {
270 continue;
271 }
272 let sim = cosine_similarity(&embedded[i].2, &embedded[j].2);
273 if sim >= threshold {
274 cluster.push(embedded[j].clone());
275 assigned[j] = true;
276 }
277 }
278
279 if cluster.len() >= min_cluster_size {
280 clusters.push(cluster);
281 }
282 }
283
284 clusters
285}
286
287async fn merge_via_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
288 let mut user_prompt = String::from("Memory nodes to consolidate:\n");
289 for (i, content) in contents.iter().enumerate() {
290 use std::fmt::Write as _;
291 let _ = writeln!(user_prompt, "[{}] {}", i + 1, content);
292 }
293 user_prompt.push_str("\nProduce a concise summary.");
294
295 let llm_messages = [
296 Message::from_legacy(Role::System, MERGE_SYSTEM_PROMPT),
297 Message::from_legacy(Role::User, user_prompt),
298 ];
299
300 let response = provider
301 .chat(&llm_messages)
302 .await
303 .map_err(MemoryError::Llm)?;
304
305 Ok(response.trim().to_string())
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(id, content, embedding)` entry for clustering tests.
    fn row(id: i64, label: &str, v: Vec<f32>) -> (i64, String, Vec<f32>) {
        (id, label.to_string(), v)
    }

    #[test]
    fn cluster_by_similarity_groups_identical_vectors() {
        let embedded = vec![
            row(1, "a", vec![1.0, 0.0, 0.0]),
            row(2, "b", vec![1.0, 0.0, 0.0]),
            row(3, "c", vec![0.0, 1.0, 0.0]),
        ];

        let clusters = cluster_by_similarity(&embedded, 0.9, 2);
        assert_eq!(
            clusters.len(),
            1,
            "identical vectors should form one cluster"
        );
        assert_eq!(clusters[0].len(), 2);
    }

    #[test]
    fn cluster_by_similarity_min_cluster_size_gate() {
        let embedded = vec![row(1, "a", vec![1.0, 0.0]), row(2, "b", vec![1.0, 0.0])];

        // Two similar items cannot satisfy a minimum cluster size of three.
        let clusters = cluster_by_similarity(&embedded, 0.9, 3);
        assert!(clusters.is_empty());
    }

    #[test]
    fn cluster_by_similarity_no_duplicates_across_clusters() {
        let unit = vec![1.0f32, 0.0];
        let embedded = vec![
            row(1, "a", unit.clone()),
            row(2, "b", unit.clone()),
            row(3, "c", unit),
        ];

        // Every item must be assigned exactly once across all clusters.
        let clusters = cluster_by_similarity(&embedded, 0.9, 2);
        let total_items: usize = clusters.iter().map(Vec::len).sum();
        assert_eq!(total_items, 3);
    }
}
359}