1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio_util::sync::CancellationToken;
21use tracing::Instrument as _;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::store::SqliteStore;
27use crate::store::messages::PromotionCandidate;
28use crate::types::ConversationId;
29use zeph_common::math::cosine_similarity;
30
/// Minimum cosine similarity that the embedding of an LLM-merged fact must reach against at
/// least one original cluster member's embedding. A merge scoring below this is rejected with
/// `MemoryError::InvalidInput`, which guards against a merge that resembles none of its inputs.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Configuration for the background episodic → semantic tier promotion loop
/// (`start_tier_promotion_loop`).
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// When `false`, `start_tier_promotion_loop` returns immediately without sweeping.
    pub enabled: bool,
    /// First argument to `SqliteStore::find_promotion_candidates`; presumably the minimum
    /// number of sessions a message must have appeared in to be a candidate — verify against
    /// the store query.
    pub promotion_min_sessions: u32,
    /// Cosine-similarity threshold (inclusive, `>=`) for a candidate to join an existing
    /// cluster during `cluster_by_similarity`.
    pub similarity_threshold: f32,
    /// Period, in seconds, between sweeps. The first sweep runs one period after startup.
    pub sweep_interval_secs: u64,
    /// Second argument to `SqliteStore::find_promotion_candidates`; presumably the maximum
    /// number of candidates fetched per sweep — verify against the store query.
    pub sweep_batch_size: usize,
    /// Timeout, in seconds, for embedding calls made while promoting (e.g. embedding the merged
    /// fact for similarity validation, which is skipped if the timeout elapses).
    pub embed_timeout_secs: u64,
}
55
56pub async fn start_tier_promotion_loop(
69 store: Arc<SqliteStore>,
70 provider: AnyProvider,
71 config: TierPromotionConfig,
72 cancel: CancellationToken,
73) {
74 if !config.enabled {
75 tracing::debug!("tier promotion disabled (tiers.enabled = false)");
76 return;
77 }
78
79 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
80 ticker.tick().await;
82
83 loop {
84 tokio::select! {
85 () = cancel.cancelled() => {
86 tracing::debug!("tier promotion loop shutting down");
87 return;
88 }
89 _ = ticker.tick() => {}
90 }
91
92 tracing::debug!("tier promotion: starting sweep");
93 let start = std::time::Instant::now();
94
95 let result = run_promotion_sweep(&store, &provider, &config).await;
96
97 let elapsed_ms = start.elapsed().as_millis();
98
99 match result {
100 Ok(stats) => {
101 tracing::info!(
102 candidates = stats.candidates_evaluated,
103 clusters = stats.clusters_formed,
104 promoted = stats.promotions_completed,
105 merge_failures = stats.merge_failures,
106 elapsed_ms,
107 "tier promotion: sweep complete"
108 );
109 }
110 Err(e) => {
111 tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
112 }
113 }
114 }
115}
116
/// Counters collected during one promotion sweep; logged by `start_tier_promotion_loop` when
/// the sweep succeeds.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates returned by `find_promotion_candidates` for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members (singleton clusters are not counted).
    clusters_formed: usize,
    /// Clusters for which `merge_cluster_and_promote` returned `Ok`.
    promotions_completed: usize,
    /// Clusters for which `merge_cluster_and_promote` returned an error.
    merge_failures: usize,
}
125
126#[tracing::instrument(name = "memory.tiers.promotion_sweep", skip_all)]
128async fn run_promotion_sweep(
129 store: &SqliteStore,
130 provider: &AnyProvider,
131 config: &TierPromotionConfig,
132) -> Result<SweepStats, MemoryError> {
133 let candidates = store
134 .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
135 .await?;
136
137 if candidates.is_empty() {
138 return Ok(SweepStats::default());
139 }
140
141 let mut stats = SweepStats {
142 candidates_evaluated: candidates.len(),
143 ..SweepStats::default()
144 };
145
146 let embedded: Vec<(PromotionCandidate, Vec<f32>)> = if provider.supports_embeddings() {
148 let texts: Vec<&str> = candidates.iter().map(|c| c.content.as_str()).collect();
149 let span = tracing::info_span!("memory.tiers.embed_batch", count = texts.len());
150 let vecs = provider.embed_batch(&texts).instrument(span).await;
151 match vecs {
152 Ok(vecs) => {
153 if vecs.len() != texts.len() {
154 tracing::warn!(
155 expected = texts.len(),
156 got = vecs.len(),
157 "tier promotion: embed_batch length mismatch, skipping sweep"
158 );
159 return Ok(stats);
160 }
161 candidates.into_iter().zip(vecs).collect()
162 }
163 Err(e) => {
164 tracing::warn!(error = %e, "tier promotion: batch embed failed, skipping sweep");
165 return Ok(stats);
166 }
167 }
168 } else {
169 candidates.into_iter().map(|c| (c, Vec::new())).collect()
171 };
172
173 if embedded.is_empty() {
174 return Ok(stats);
175 }
176
177 let threshold = config.similarity_threshold;
181 let clusters = cluster_by_similarity(embedded, threshold);
182
183 for cluster in clusters {
184 if cluster.len() < 2 {
185 tracing::debug!(
187 cluster_size = cluster.len(),
188 "tier promotion: singleton cluster skipped"
189 );
190 continue;
191 }
192
193 stats.clusters_formed += 1;
194
195 let source_conv_id = cluster[0].0.conversation_id;
196
197 match merge_cluster_and_promote(
198 store,
199 provider,
200 &cluster,
201 source_conv_id,
202 Duration::from_secs(config.embed_timeout_secs),
203 )
204 .await
205 {
206 Ok(()) => stats.promotions_completed += 1,
207 Err(e) => {
208 tracing::warn!(
209 cluster_size = cluster.len(),
210 error = %e,
211 "tier promotion: cluster merge failed, skipping"
212 );
213 stats.merge_failures += 1;
214 }
215 }
216 }
217
218 Ok(stats)
219}
220
221fn cluster_by_similarity(
227 candidates: Vec<(PromotionCandidate, Vec<f32>)>,
228 threshold: f32,
229) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
230 let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
231
232 'outer: for candidate in candidates {
233 if candidate.1.is_empty() {
234 clusters.push(vec![candidate]);
236 continue;
237 }
238
239 for cluster in &mut clusters {
240 let rep = &cluster[0].1;
241 if rep.is_empty() {
242 continue;
243 }
244 let sim = cosine_similarity(&candidate.1, rep);
245 if sim >= threshold {
246 cluster.push(candidate);
247 continue 'outer;
248 }
249 }
250
251 clusters.push(vec![candidate]);
252 }
253
254 clusters
255}
256
257#[tracing::instrument(name = "memory.tiers.merge_cluster_and_promote", skip_all)]
263async fn merge_cluster_and_promote(
264 store: &SqliteStore,
265 provider: &AnyProvider,
266 cluster: &[(PromotionCandidate, Vec<f32>)],
267 conversation_id: ConversationId,
268 embed_timeout: Duration,
269) -> Result<(), MemoryError> {
270 let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
271 let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
272
273 let merged = call_merge_llm(provider, &contents).await?;
274
275 let merged = merged.trim().to_owned();
277 if merged.is_empty() {
278 return Err(MemoryError::InvalidInput(
279 "LLM merge returned empty result".into(),
280 ));
281 }
282
283 if provider.supports_embeddings() {
286 let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
287 if embeddings_available {
288 match tokio::time::timeout(embed_timeout, provider.embed(&merged)).await {
289 Ok(Ok(merged_vec)) => {
290 let max_sim = cluster
291 .iter()
292 .filter(|(_, emb)| !emb.is_empty())
293 .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
294 .fold(f32::NEG_INFINITY, f32::max);
295
296 if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
297 return Err(MemoryError::InvalidInput(format!(
298 "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
299 )));
300 }
301 }
302 Ok(Err(e)) => {
303 tracing::warn!(
304 error = %e,
305 "tier promotion: failed to embed merged result, skipping similarity validation"
306 );
307 }
308 Err(_) => {
309 tracing::warn!(
310 "tier promotion: embed timed out, skipping similarity validation"
311 );
312 }
313 }
314 }
315 }
316
317 let delays_ms = [50u64, 100, 200];
320 for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
321 match store
322 .promote_to_semantic(conversation_id, &merged, &original_ids)
323 .await
324 {
325 Ok(_) => break,
326 Err(e) => {
327 let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
332 db_err.code().as_deref() == Some("5")
333 } else {
334 e.to_string().contains("database is locked")
335 };
336 if is_busy && attempt < delays_ms.len() - 1 {
337 tracing::warn!(
338 attempt = attempt + 1,
339 delay_ms,
340 "tier promotion: SQLite busy, retrying"
341 );
342 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
343 } else {
344 return Err(e);
345 }
346 }
347 }
348 }
349 tracing::debug!(
350 cluster_size = cluster.len(),
351 merged_len = merged.len(),
352 "tier promotion: cluster promoted to semantic"
353 );
354
355 Ok(())
356}
357
358async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
360 use zeph_llm::provider::{Message, MessageMetadata, Role};
361
362 let bullet_list: String = contents
363 .iter()
364 .enumerate()
365 .map(|(i, c)| format!("{}. {c}", i + 1))
366 .collect::<Vec<_>>()
367 .join("\n");
368
369 let system_content = "You are a memory consolidation agent. \
370 Merge the following episodic memories into a single concise semantic fact. \
371 Strip timestamps, session context, hedging, and filler. \
372 Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
373 Do not add prefixes like 'The user...' or 'Fact:'.";
374
375 let user_content =
376 format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
377
378 let messages = vec![
379 Message {
380 role: Role::System,
381 content: system_content.to_owned(),
382 parts: vec![],
383 metadata: MessageMetadata::default(),
384 },
385 Message {
386 role: Role::User,
387 content: user_content,
388 parts: vec![],
389 metadata: MessageMetadata::default(),
390 },
391 ];
392
393 let timeout = Duration::from_secs(15);
394
395 let result = tokio::time::timeout(timeout, provider.chat(&messages))
396 .await
397 .map_err(|_| MemoryError::Timeout("LLM merge timed out after 15s".into()))?
398 .map_err(MemoryError::Llm)?;
399
400 Ok(result)
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with the given message id in conversation 1.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    #[test]
    fn cluster_by_similarity_groups_identical() {
        let v1 = vec![1.0f32, 0.0, 0.0];
        let v2 = vec![1.0f32, 0.0, 0.0];
        // Orthogonal to v1/v2, so it must land in its own cluster.
        let v3 = vec![0.0f32, 1.0, 0.0];
        let candidates = vec![
            (make_candidate(1), v1),
            (make_candidate(2), v2),
            (make_candidate(3), v3),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);
        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates = vec![(make_candidate(1), vec![]), (make_candidate(2), vec![])];
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
    }

    #[test]
    fn cluster_by_similarity_empty_input_yields_no_clusters() {
        let clusters = cluster_by_similarity(Vec::new(), 0.92);
        assert!(clusters.is_empty());
    }

    #[test]
    fn cluster_by_similarity_embeddingless_seed_does_not_absorb_others() {
        // A candidate without an embedding forms its own cluster and must never serve as the
        // representative for later, embedded candidates.
        let candidates = vec![
            (make_candidate(1), vec![]),
            (make_candidate(2), vec![1.0_f32, 0.0, 0.0]),
            (make_candidate(3), vec![1.0_f32, 0.0, 0.0]),
        ];
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].len(), 1);
        assert!(clusters[0][0].1.is_empty(), "first cluster is the embedding-less one");
        assert_eq!(clusters[1].len(), 2);
    }

    #[test]
    fn cluster_by_similarity_compares_only_against_representative() {
        // cos(a, b) = 0.8, cos(a, c) = 0.6, cos(b, c) = 0.96. With threshold 0.75, b joins a's
        // cluster. c is compared only with the representative a, so it starts a new cluster
        // even though it is close to b.
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.8_f32, 0.6];
        let c = vec![0.6_f32, 0.8];
        let candidates = vec![
            (make_candidate(1), a),
            (make_candidate(2), b),
            (make_candidate(3), c),
        ];
        let clusters = cluster_by_similarity(candidates, 0.75);
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].len(), 2);
        assert_eq!(clusters[1].len(), 1);
    }

    /// If embedding the merged text fails, similarity validation is skipped and the promotion
    /// still succeeds.
    #[tokio::test]
    async fn merge_validation_embed_failure_is_fail_open() {
        let store = crate::store::SqliteStore::new(":memory:").await.unwrap();
        let conv_id = store.create_conversation().await.unwrap();
        let m1 = store
            .save_message(conv_id, "user", "Alice uses Rust")
            .await
            .unwrap();
        let m2 = store
            .save_message(conv_id, "user", "Alice loves Rust")
            .await
            .unwrap();

        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::with_responses(vec!["Alice uses and loves Rust".into()])
                .with_embed_invalid_input(),
        );

        let cluster = vec![
            (make_candidate(m1.0), vec![1.0_f32, 0.0, 0.0]),
            (make_candidate(m2.0), vec![1.0_f32, 0.0, 0.0]),
        ];

        let result =
            merge_cluster_and_promote(&store, &provider, &cluster, conv_id, Duration::from_secs(5))
                .await;
        assert!(
            result.is_ok(),
            "embed failure during merge validation must be fail-open (Ok), got {result:?}"
        );
    }
}