1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio_util::sync::CancellationToken;
21use tracing::Instrument as _;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::store::SqliteStore;
27use crate::store::messages::PromotionCandidate;
28use crate::types::ConversationId;
29use zeph_common::math::cosine_similarity;
30
/// Minimum cosine similarity the embedding of a merged fact must reach against at least one
/// embedded cluster member; a merge whose best similarity is below this value is rejected.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Settings for the background tier-promotion loop.
#[derive(Clone, Debug)]
pub struct TierPromotionConfig {
    /// Master switch: when `false`, the promotion loop returns immediately without sweeping.
    pub enabled: bool,
    /// Passed to the store's candidate query as its first argument.
    /// Presumably the minimum number of sessions a message must appear in — confirm against
    /// `SqliteStore::find_promotion_candidates`.
    pub promotion_min_sessions: u32,
    /// Cosine-similarity threshold for clustering: a candidate joins a cluster when its
    /// similarity to the cluster's first member is greater than or equal to this value.
    pub similarity_threshold: f32,
    /// Period of the sweep timer, in seconds.
    pub sweep_interval_secs: u64,
    /// Passed to the store's candidate query as its second argument.
    /// Presumably the maximum number of candidates fetched per sweep — confirm against
    /// `SqliteStore::find_promotion_candidates`.
    pub sweep_batch_size: usize,
}
53
54pub async fn start_tier_promotion_loop(
67 store: Arc<SqliteStore>,
68 provider: AnyProvider,
69 config: TierPromotionConfig,
70 cancel: CancellationToken,
71) {
72 if !config.enabled {
73 tracing::debug!("tier promotion disabled (tiers.enabled = false)");
74 return;
75 }
76
77 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
78 ticker.tick().await;
80
81 loop {
82 tokio::select! {
83 () = cancel.cancelled() => {
84 tracing::debug!("tier promotion loop shutting down");
85 return;
86 }
87 _ = ticker.tick() => {}
88 }
89
90 tracing::debug!("tier promotion: starting sweep");
91 let start = std::time::Instant::now();
92
93 let result = run_promotion_sweep(&store, &provider, &config).await;
94
95 let elapsed_ms = start.elapsed().as_millis();
96
97 match result {
98 Ok(stats) => {
99 tracing::info!(
100 candidates = stats.candidates_evaluated,
101 clusters = stats.clusters_formed,
102 promoted = stats.promotions_completed,
103 merge_failures = stats.merge_failures,
104 elapsed_ms,
105 "tier promotion: sweep complete"
106 );
107 }
108 Err(e) => {
109 tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
110 }
111 }
112 }
113}
114
/// Counters collected during a single promotion sweep.
#[derive(Default, Debug)]
struct SweepStats {
    /// Number of candidates returned by the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members (singletons are not counted).
    clusters_formed: usize,
    /// Number of clusters that were merged and promoted successfully.
    promotions_completed: usize,
    /// Number of clusters whose merge or promotion returned an error.
    merge_failures: usize,
}
123
124#[tracing::instrument(name = "memory.tiers.promotion_sweep", skip_all)]
126async fn run_promotion_sweep(
127 store: &SqliteStore,
128 provider: &AnyProvider,
129 config: &TierPromotionConfig,
130) -> Result<SweepStats, MemoryError> {
131 let candidates = store
132 .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
133 .await?;
134
135 if candidates.is_empty() {
136 return Ok(SweepStats::default());
137 }
138
139 let mut stats = SweepStats {
140 candidates_evaluated: candidates.len(),
141 ..SweepStats::default()
142 };
143
144 let embedded: Vec<(PromotionCandidate, Vec<f32>)> = if provider.supports_embeddings() {
146 let texts: Vec<&str> = candidates.iter().map(|c| c.content.as_str()).collect();
147 let span = tracing::info_span!("memory.tiers.embed_batch", count = texts.len());
148 let vecs = provider.embed_batch(&texts).instrument(span).await;
149 match vecs {
150 Ok(vecs) => {
151 if vecs.len() != texts.len() {
152 tracing::warn!(
153 expected = texts.len(),
154 got = vecs.len(),
155 "tier promotion: embed_batch length mismatch, skipping sweep"
156 );
157 return Ok(stats);
158 }
159 candidates.into_iter().zip(vecs).collect()
160 }
161 Err(e) => {
162 tracing::warn!(error = %e, "tier promotion: batch embed failed, skipping sweep");
163 return Ok(stats);
164 }
165 }
166 } else {
167 candidates.into_iter().map(|c| (c, Vec::new())).collect()
169 };
170
171 if embedded.is_empty() {
172 return Ok(stats);
173 }
174
175 let threshold = config.similarity_threshold;
179 let clusters = cluster_by_similarity(embedded, threshold);
180
181 for cluster in clusters {
182 if cluster.len() < 2 {
183 tracing::debug!(
185 cluster_size = cluster.len(),
186 "tier promotion: singleton cluster skipped"
187 );
188 continue;
189 }
190
191 stats.clusters_formed += 1;
192
193 let source_conv_id = cluster[0].0.conversation_id;
194
195 match merge_cluster_and_promote(store, provider, &cluster, source_conv_id).await {
196 Ok(()) => stats.promotions_completed += 1,
197 Err(e) => {
198 tracing::warn!(
199 cluster_size = cluster.len(),
200 error = %e,
201 "tier promotion: cluster merge failed, skipping"
202 );
203 stats.merge_failures += 1;
204 }
205 }
206 }
207
208 Ok(stats)
209}
210
211fn cluster_by_similarity(
217 candidates: Vec<(PromotionCandidate, Vec<f32>)>,
218 threshold: f32,
219) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
220 let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
221
222 'outer: for candidate in candidates {
223 if candidate.1.is_empty() {
224 clusters.push(vec![candidate]);
226 continue;
227 }
228
229 for cluster in &mut clusters {
230 let rep = &cluster[0].1;
231 if rep.is_empty() {
232 continue;
233 }
234 let sim = cosine_similarity(&candidate.1, rep);
235 if sim >= threshold {
236 cluster.push(candidate);
237 continue 'outer;
238 }
239 }
240
241 clusters.push(vec![candidate]);
242 }
243
244 clusters
245}
246
247#[tracing::instrument(name = "memory.tiers.merge_cluster_and_promote", skip_all)]
253async fn merge_cluster_and_promote(
254 store: &SqliteStore,
255 provider: &AnyProvider,
256 cluster: &[(PromotionCandidate, Vec<f32>)],
257 conversation_id: ConversationId,
258) -> Result<(), MemoryError> {
259 let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
260 let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
261
262 let merged = call_merge_llm(provider, &contents).await?;
263
264 let merged = merged.trim().to_owned();
266 if merged.is_empty() {
267 return Err(MemoryError::InvalidInput(
268 "LLM merge returned empty result".into(),
269 ));
270 }
271
272 if provider.supports_embeddings() {
275 let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
276 if embeddings_available {
277 match tokio::time::timeout(Duration::from_secs(5), provider.embed(&merged)).await {
278 Ok(Ok(merged_vec)) => {
279 let max_sim = cluster
280 .iter()
281 .filter(|(_, emb)| !emb.is_empty())
282 .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
283 .fold(f32::NEG_INFINITY, f32::max);
284
285 if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
286 return Err(MemoryError::InvalidInput(format!(
287 "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
288 )));
289 }
290 }
291 Ok(Err(e)) => {
292 tracing::warn!(
293 error = %e,
294 "tier promotion: failed to embed merged result, skipping similarity validation"
295 );
296 }
297 Err(_) => {
298 tracing::warn!(
299 "tier promotion: embed timed out, skipping similarity validation"
300 );
301 }
302 }
303 }
304 }
305
306 let delays_ms = [50u64, 100, 200];
309 for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
310 match store
311 .promote_to_semantic(conversation_id, &merged, &original_ids)
312 .await
313 {
314 Ok(_) => break,
315 Err(e) => {
316 let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
321 db_err.code().as_deref() == Some("5")
322 } else {
323 e.to_string().contains("database is locked")
324 };
325 if is_busy && attempt < delays_ms.len() - 1 {
326 tracing::warn!(
327 attempt = attempt + 1,
328 delay_ms,
329 "tier promotion: SQLite busy, retrying"
330 );
331 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
332 } else {
333 return Err(e);
334 }
335 }
336 }
337 }
338 tracing::debug!(
339 cluster_size = cluster.len(),
340 merged_len = merged.len(),
341 "tier promotion: cluster promoted to semantic"
342 );
343
344 Ok(())
345}
346
347async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
349 use zeph_llm::provider::{Message, MessageMetadata, Role};
350
351 let bullet_list: String = contents
352 .iter()
353 .enumerate()
354 .map(|(i, c)| format!("{}. {c}", i + 1))
355 .collect::<Vec<_>>()
356 .join("\n");
357
358 let system_content = "You are a memory consolidation agent. \
359 Merge the following episodic memories into a single concise semantic fact. \
360 Strip timestamps, session context, hedging, and filler. \
361 Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
362 Do not add prefixes like 'The user...' or 'Fact:'.";
363
364 let user_content =
365 format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
366
367 let messages = vec![
368 Message {
369 role: Role::System,
370 content: system_content.to_owned(),
371 parts: vec![],
372 metadata: MessageMetadata::default(),
373 },
374 Message {
375 role: Role::User,
376 content: user_content,
377 parts: vec![],
378 metadata: MessageMetadata::default(),
379 },
380 ];
381
382 let timeout = Duration::from_secs(15);
383
384 let result = tokio::time::timeout(timeout, provider.chat(&messages))
385 .await
386 .map_err(|_| MemoryError::Timeout("LLM merge timed out after 15s".into()))?
387 .map_err(MemoryError::Llm)?;
388
389 Ok(result)
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with the given message id and fixed values for every other field.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    /// Two identical vectors share a cluster; an orthogonal vector gets its own.
    #[test]
    fn cluster_by_similarity_groups_identical() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];

        let candidates = vec![
            (make_candidate(1), x_axis.to_vec()),
            (make_candidate(2), x_axis.to_vec()),
            (make_candidate(3), y_axis.to_vec()),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);

        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    /// Candidates without an embedding are never grouped together.
    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let without_embeddings = vec![
            (make_candidate(1), Vec::new()),
            (make_candidate(2), Vec::new()),
        ];

        let clusters = cluster_by_similarity(without_embeddings, 0.92);

        assert_eq!(clusters.len(), 2);
    }

    /// A failing embed call during merge validation must not fail the promotion.
    #[tokio::test]
    async fn merge_validation_embed_failure_is_fail_open() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let conv_id = store.create_conversation().await.unwrap();

        let first = store
            .save_message(conv_id, "user", "Alice uses Rust")
            .await
            .unwrap();
        let second = store
            .save_message(conv_id, "user", "Alice loves Rust")
            .await
            .unwrap();

        // The mock answers the merge chat call but is configured via
        // `with_embed_invalid_input`, which this test relies on to make `embed` fail.
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![
            "Alice uses and loves Rust".into(),
        ])
        .with_embed_invalid_input();
        let provider = AnyProvider::Mock(mock);

        let shared_embedding = vec![1.0_f32, 0.0, 0.0];
        let cluster = vec![
            (make_candidate(first.0), shared_embedding.clone()),
            (make_candidate(second.0), shared_embedding),
        ];

        let result = merge_cluster_and_promote(&store, &provider, &cluster, conv_id).await;

        assert!(
            result.is_ok(),
            "embed failure during merge validation must be fail-open (Ok), got {result:?}"
        );
    }
}