1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio_util::sync::CancellationToken;
21use zeph_llm::any::AnyProvider;
22use zeph_llm::provider::LlmProvider as _;
23
24use crate::error::MemoryError;
25use crate::store::SqliteStore;
26use crate::store::messages::PromotionCandidate;
27use crate::types::ConversationId;
28use zeph_common::math::cosine_similarity;
29
/// Lower bound on the best cosine similarity between an LLM-merged fact and the embeddings
/// of the originals it was merged from; merges scoring below this are rejected.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
33
/// Tuning knobs for the background episodic → semantic tier-promotion loop.
///
/// Consumed by `start_tier_promotion_loop`; apart from `enabled`, fields are only read
/// once promotion is enabled.
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// Master switch. When `false` the loop logs a debug message and returns immediately.
    pub enabled: bool,
    /// Forwarded as the first argument of `SqliteStore::find_promotion_candidates`;
    /// presumably the minimum number of sessions a memory must recur in — confirm in store.
    pub promotion_min_sessions: u32,
    /// Minimum cosine similarity to a cluster's first member for a candidate to join it.
    pub similarity_threshold: f32,
    /// Seconds between sweeps; the first sweep runs one full interval after startup.
    pub sweep_interval_secs: u64,
    /// Forwarded as the second argument of `SqliteStore::find_promotion_candidates`;
    /// presumably the maximum number of candidates loaded per sweep — confirm in store.
    pub sweep_batch_size: usize,
}
52
53pub async fn start_tier_promotion_loop(
66 store: Arc<SqliteStore>,
67 provider: AnyProvider,
68 config: TierPromotionConfig,
69 cancel: CancellationToken,
70) {
71 if !config.enabled {
72 tracing::debug!("tier promotion disabled (tiers.enabled = false)");
73 return;
74 }
75
76 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
77 ticker.tick().await;
79
80 loop {
81 tokio::select! {
82 () = cancel.cancelled() => {
83 tracing::debug!("tier promotion loop shutting down");
84 return;
85 }
86 _ = ticker.tick() => {}
87 }
88
89 tracing::debug!("tier promotion: starting sweep");
90 let start = std::time::Instant::now();
91
92 let result = run_promotion_sweep(&store, &provider, &config).await;
93
94 let elapsed_ms = start.elapsed().as_millis();
95
96 match result {
97 Ok(stats) => {
98 tracing::info!(
99 candidates = stats.candidates_evaluated,
100 clusters = stats.clusters_formed,
101 promoted = stats.promotions_completed,
102 merge_failures = stats.merge_failures,
103 elapsed_ms,
104 "tier promotion: sweep complete"
105 );
106 }
107 Err(e) => {
108 tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
109 }
110 }
111 }
112}
113
/// Counters describing the outcome of one promotion sweep; the loop logs them on success.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates returned by the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members, i.e. merge attempts.
    clusters_formed: usize,
    /// Number of clusters that were merged and written as semantic facts.
    promotions_completed: usize,
    /// Number of clusters whose merge, validation, or store write failed.
    merge_failures: usize,
}
122
123#[cfg_attr(
125 feature = "profiling",
126 tracing::instrument(name = "memory.tier_promotion", skip_all)
127)]
128async fn run_promotion_sweep(
129 store: &SqliteStore,
130 provider: &AnyProvider,
131 config: &TierPromotionConfig,
132) -> Result<SweepStats, MemoryError> {
133 let candidates = store
134 .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
135 .await?;
136
137 if candidates.is_empty() {
138 return Ok(SweepStats::default());
139 }
140
141 let mut stats = SweepStats {
142 candidates_evaluated: candidates.len(),
143 ..SweepStats::default()
144 };
145
146 let mut embedded: Vec<(PromotionCandidate, Vec<f32>)> = Vec::with_capacity(candidates.len());
148 for candidate in candidates {
149 if !provider.supports_embeddings() {
150 embedded.push((candidate, Vec::new()));
152 continue;
153 }
154 match provider.embed(&candidate.content).await {
155 Ok(vec) => embedded.push((candidate, vec)),
156 Err(e) => {
157 tracing::warn!(
158 message_id = candidate.id.0,
159 error = %e,
160 "tier promotion: failed to embed candidate, skipping"
161 );
162 }
163 }
164 }
165
166 if embedded.is_empty() {
167 return Ok(stats);
168 }
169
170 let threshold = config.similarity_threshold;
174 let clusters = cluster_by_similarity(embedded, threshold);
175
176 for cluster in clusters {
177 if cluster.len() < 2 {
178 tracing::debug!(
180 cluster_size = cluster.len(),
181 "tier promotion: singleton cluster skipped"
182 );
183 continue;
184 }
185
186 stats.clusters_formed += 1;
187
188 let source_conv_id = cluster[0].0.conversation_id;
189
190 match merge_cluster_and_promote(store, provider, &cluster, source_conv_id).await {
191 Ok(()) => stats.promotions_completed += 1,
192 Err(e) => {
193 tracing::warn!(
194 cluster_size = cluster.len(),
195 error = %e,
196 "tier promotion: cluster merge failed, skipping"
197 );
198 stats.merge_failures += 1;
199 }
200 }
201 }
202
203 Ok(stats)
204}
205
206fn cluster_by_similarity(
212 candidates: Vec<(PromotionCandidate, Vec<f32>)>,
213 threshold: f32,
214) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
215 let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
216
217 'outer: for candidate in candidates {
218 if candidate.1.is_empty() {
219 clusters.push(vec![candidate]);
221 continue;
222 }
223
224 for cluster in &mut clusters {
225 let rep = &cluster[0].1;
226 if rep.is_empty() {
227 continue;
228 }
229 let sim = cosine_similarity(&candidate.1, rep);
230 if sim >= threshold {
231 cluster.push(candidate);
232 continue 'outer;
233 }
234 }
235
236 clusters.push(vec![candidate]);
237 }
238
239 clusters
240}
241
242async fn merge_cluster_and_promote(
248 store: &SqliteStore,
249 provider: &AnyProvider,
250 cluster: &[(PromotionCandidate, Vec<f32>)],
251 conversation_id: ConversationId,
252) -> Result<(), MemoryError> {
253 let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
254 let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
255
256 let merged = call_merge_llm(provider, &contents).await?;
257
258 let merged = merged.trim().to_owned();
260 if merged.is_empty() {
261 return Err(MemoryError::Other("LLM merge returned empty result".into()));
262 }
263
264 if provider.supports_embeddings() {
267 let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
268 if embeddings_available {
269 match provider.embed(&merged).await {
270 Ok(merged_vec) => {
271 let max_sim = cluster
272 .iter()
273 .filter(|(_, emb)| !emb.is_empty())
274 .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
275 .fold(f32::NEG_INFINITY, f32::max);
276
277 if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
278 return Err(MemoryError::Other(format!(
279 "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
280 )));
281 }
282 }
283 Err(e) => {
284 tracing::warn!(
285 error = %e,
286 "tier promotion: failed to embed merged result, skipping similarity validation"
287 );
288 }
289 }
290 }
291 }
292
293 let delays_ms = [50u64, 100, 200];
296 for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
297 match store
298 .promote_to_semantic(conversation_id, &merged, &original_ids)
299 .await
300 {
301 Ok(_) => break,
302 Err(e) => {
303 let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
308 db_err.code().as_deref() == Some("5")
309 } else {
310 e.to_string().contains("database is locked")
311 };
312 if is_busy && attempt < delays_ms.len() - 1 {
313 tracing::warn!(
314 attempt = attempt + 1,
315 delay_ms,
316 "tier promotion: SQLite busy, retrying"
317 );
318 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
319 } else {
320 return Err(e);
321 }
322 }
323 }
324 }
325 tracing::debug!(
326 cluster_size = cluster.len(),
327 merged_len = merged.len(),
328 "tier promotion: cluster promoted to semantic"
329 );
330
331 Ok(())
332}
333
334async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
336 use zeph_llm::provider::{Message, MessageMetadata, Role};
337
338 let bullet_list: String = contents
339 .iter()
340 .enumerate()
341 .map(|(i, c)| format!("{}. {c}", i + 1))
342 .collect::<Vec<_>>()
343 .join("\n");
344
345 let system_content = "You are a memory consolidation agent. \
346 Merge the following episodic memories into a single concise semantic fact. \
347 Strip timestamps, session context, hedging, and filler. \
348 Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
349 Do not add prefixes like 'The user...' or 'Fact:'.";
350
351 let user_content =
352 format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
353
354 let messages = vec![
355 Message {
356 role: Role::System,
357 content: system_content.to_owned(),
358 parts: vec![],
359 metadata: MessageMetadata::default(),
360 },
361 Message {
362 role: Role::User,
363 content: user_content,
364 parts: vec![],
365 metadata: MessageMetadata::default(),
366 },
367 ];
368
369 let timeout = Duration::from_secs(15);
370
371 let result = tokio::time::timeout(timeout, provider.chat(&messages))
372 .await
373 .map_err(|_| MemoryError::Other("LLM merge timed out after 15s".into()))?
374 .map_err(MemoryError::Llm)?;
375
376 Ok(result)
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate in conversation 1 whose content and id are derived from `id`.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    #[test]
    fn cluster_by_similarity_groups_identical() {
        let v1 = vec![1.0f32, 0.0, 0.0];
        let v2 = vec![1.0f32, 0.0, 0.0];
        let v3 = vec![0.0f32, 1.0, 0.0];
        let candidates = vec![
            (make_candidate(1), v1),
            (make_candidate(2), v2),
            (make_candidate(3), v3),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);
        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates = vec![(make_candidate(1), vec![]), (make_candidate(2), vec![])];
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
    }

    #[test]
    fn cluster_by_similarity_empty_seed_does_not_absorb_later_candidates() {
        let v = vec![1.0f32, 0.0, 0.0];
        let candidates = vec![
            (make_candidate(1), vec![]),
            (make_candidate(2), v.clone()),
            (make_candidate(3), v),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].len(), 1, "empty-embedding candidate stays alone");
        assert_eq!(clusters[0][0].0.id.0, 1);
        assert_eq!(clusters[1].len(), 2, "similar candidates cluster together");
        assert_eq!(clusters[1][0].0.id.0, 2, "members keep input order");
        assert_eq!(clusters[1][1].0.id.0, 3);
    }

    #[test]
    fn cluster_by_similarity_respects_threshold() {
        // cos([1,0,0], [1,1,0]) = 1/sqrt(2) ≈ 0.707
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 1.0, 0.0];

        let loose = cluster_by_similarity(
            vec![(make_candidate(1), a.clone()), (make_candidate(2), b.clone())],
            0.5,
        );
        assert_eq!(loose.len(), 1, "0.707 >= 0.5 must merge");

        let strict =
            cluster_by_similarity(vec![(make_candidate(1), a), (make_candidate(2), b)], 0.92);
        assert_eq!(strict.len(), 2, "0.707 < 0.92 must not merge");
    }

    #[test]
    fn cluster_by_similarity_empty_input_yields_no_clusters() {
        let clusters = cluster_by_similarity(Vec::new(), 0.92);
        assert!(clusters.is_empty());
    }
}