1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio::task::JoinHandle;
21use tokio_util::sync::CancellationToken;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::store::SqliteStore;
27use crate::store::messages::PromotionCandidate;
28use crate::types::ConversationId;
29use zeph_common::math::cosine_similarity;
30
/// Minimum cosine similarity that the embedding of an LLM-merged text must reach
/// against at least one original cluster member; below this the merge is rejected.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Settings consumed by the background tier promotion loop.
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// When `false`, the spawned loop task logs a debug message and exits immediately.
    pub enabled: bool,
    /// Forwarded as the first argument to `SqliteStore::find_promotion_candidates`.
    /// NOTE(review): presumably the minimum number of sessions a message must
    /// appear in to be a candidate — confirm against the store implementation.
    pub promotion_min_sessions: u32,
    /// Cosine similarity at or above which a candidate joins an existing cluster.
    pub similarity_threshold: f32,
    /// Period, in seconds, of the sweep ticker.
    pub sweep_interval_secs: u64,
    /// Forwarded as the second argument to `SqliteStore::find_promotion_candidates`.
    /// NOTE(review): presumably the maximum number of candidates fetched per
    /// sweep — confirm against the store implementation.
    pub sweep_batch_size: usize,
}
53
54pub fn start_tier_promotion_loop(
67 store: Arc<SqliteStore>,
68 provider: AnyProvider,
69 config: TierPromotionConfig,
70 cancel: CancellationToken,
71) -> JoinHandle<()> {
72 tokio::spawn(async move {
73 if !config.enabled {
74 tracing::debug!("tier promotion disabled (tiers.enabled = false)");
75 return;
76 }
77
78 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
79 ticker.tick().await;
81
82 loop {
83 tokio::select! {
84 () = cancel.cancelled() => {
85 tracing::debug!("tier promotion loop shutting down");
86 return;
87 }
88 _ = ticker.tick() => {}
89 }
90
91 tracing::debug!("tier promotion: starting sweep");
92 let start = std::time::Instant::now();
93
94 let result = run_promotion_sweep(&store, &provider, &config).await;
95
96 let elapsed_ms = start.elapsed().as_millis();
97
98 match result {
99 Ok(stats) => {
100 tracing::info!(
101 candidates = stats.candidates_evaluated,
102 clusters = stats.clusters_formed,
103 promoted = stats.promotions_completed,
104 merge_failures = stats.merge_failures,
105 elapsed_ms,
106 "tier promotion: sweep complete"
107 );
108 }
109 Err(e) => {
110 tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
111 }
112 }
113 }
114 })
115}
116
/// Counters gathered during a single promotion sweep.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates returned by the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters that contained at least two members.
    clusters_formed: usize,
    /// Number of clusters that were merged and promoted without error.
    promotions_completed: usize,
    /// Number of clusters whose merge or promotion returned an error.
    merge_failures: usize,
}
125
126async fn run_promotion_sweep(
128 store: &SqliteStore,
129 provider: &AnyProvider,
130 config: &TierPromotionConfig,
131) -> Result<SweepStats, MemoryError> {
132 let candidates = store
133 .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
134 .await?;
135
136 if candidates.is_empty() {
137 return Ok(SweepStats::default());
138 }
139
140 let mut stats = SweepStats {
141 candidates_evaluated: candidates.len(),
142 ..SweepStats::default()
143 };
144
145 let mut embedded: Vec<(PromotionCandidate, Vec<f32>)> = Vec::with_capacity(candidates.len());
147 for candidate in candidates {
148 if !provider.supports_embeddings() {
149 embedded.push((candidate, Vec::new()));
151 continue;
152 }
153 match provider.embed(&candidate.content).await {
154 Ok(vec) => embedded.push((candidate, vec)),
155 Err(e) => {
156 tracing::warn!(
157 message_id = candidate.id.0,
158 error = %e,
159 "tier promotion: failed to embed candidate, skipping"
160 );
161 }
162 }
163 }
164
165 if embedded.is_empty() {
166 return Ok(stats);
167 }
168
169 let threshold = config.similarity_threshold;
173 let clusters = cluster_by_similarity(embedded, threshold);
174
175 for cluster in clusters {
176 if cluster.len() < 2 {
177 tracing::debug!(
179 cluster_size = cluster.len(),
180 "tier promotion: singleton cluster skipped"
181 );
182 continue;
183 }
184
185 stats.clusters_formed += 1;
186
187 let source_conv_id = cluster[0].0.conversation_id;
188
189 match merge_cluster_and_promote(store, provider, &cluster, source_conv_id).await {
190 Ok(()) => stats.promotions_completed += 1,
191 Err(e) => {
192 tracing::warn!(
193 cluster_size = cluster.len(),
194 error = %e,
195 "tier promotion: cluster merge failed, skipping"
196 );
197 stats.merge_failures += 1;
198 }
199 }
200 }
201
202 Ok(stats)
203}
204
205fn cluster_by_similarity(
211 candidates: Vec<(PromotionCandidate, Vec<f32>)>,
212 threshold: f32,
213) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
214 let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
215
216 'outer: for candidate in candidates {
217 if candidate.1.is_empty() {
218 clusters.push(vec![candidate]);
220 continue;
221 }
222
223 for cluster in &mut clusters {
224 let rep = &cluster[0].1;
225 if rep.is_empty() {
226 continue;
227 }
228 let sim = cosine_similarity(&candidate.1, rep);
229 if sim >= threshold {
230 cluster.push(candidate);
231 continue 'outer;
232 }
233 }
234
235 clusters.push(vec![candidate]);
236 }
237
238 clusters
239}
240
241async fn merge_cluster_and_promote(
247 store: &SqliteStore,
248 provider: &AnyProvider,
249 cluster: &[(PromotionCandidate, Vec<f32>)],
250 conversation_id: ConversationId,
251) -> Result<(), MemoryError> {
252 let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
253 let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
254
255 let merged = call_merge_llm(provider, &contents).await?;
256
257 let merged = merged.trim().to_owned();
259 if merged.is_empty() {
260 return Err(MemoryError::Other("LLM merge returned empty result".into()));
261 }
262
263 if provider.supports_embeddings() {
266 let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
267 if embeddings_available {
268 match provider.embed(&merged).await {
269 Ok(merged_vec) => {
270 let max_sim = cluster
271 .iter()
272 .filter(|(_, emb)| !emb.is_empty())
273 .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
274 .fold(f32::NEG_INFINITY, f32::max);
275
276 if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
277 return Err(MemoryError::Other(format!(
278 "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
279 )));
280 }
281 }
282 Err(e) => {
283 tracing::warn!(
284 error = %e,
285 "tier promotion: failed to embed merged result, skipping similarity validation"
286 );
287 }
288 }
289 }
290 }
291
292 let delays_ms = [50u64, 100, 200];
295 for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
296 match store
297 .promote_to_semantic(conversation_id, &merged, &original_ids)
298 .await
299 {
300 Ok(_) => break,
301 Err(e) => {
302 let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
307 db_err.code().as_deref() == Some("5")
308 } else {
309 e.to_string().contains("database is locked")
310 };
311 if is_busy && attempt < delays_ms.len() - 1 {
312 tracing::warn!(
313 attempt = attempt + 1,
314 delay_ms,
315 "tier promotion: SQLite busy, retrying"
316 );
317 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
318 } else {
319 return Err(e);
320 }
321 }
322 }
323 }
324 tracing::debug!(
325 cluster_size = cluster.len(),
326 merged_len = merged.len(),
327 "tier promotion: cluster promoted to semantic"
328 );
329
330 Ok(())
331}
332
333async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
335 use zeph_llm::provider::{Message, MessageMetadata, Role};
336
337 let bullet_list: String = contents
338 .iter()
339 .enumerate()
340 .map(|(i, c)| format!("{}. {c}", i + 1))
341 .collect::<Vec<_>>()
342 .join("\n");
343
344 let system_content = "You are a memory consolidation agent. \
345 Merge the following episodic memories into a single concise semantic fact. \
346 Strip timestamps, session context, hedging, and filler. \
347 Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
348 Do not add prefixes like 'The user...' or 'Fact:'.";
349
350 let user_content =
351 format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
352
353 let messages = vec![
354 Message {
355 role: Role::System,
356 content: system_content.to_owned(),
357 parts: vec![],
358 metadata: MessageMetadata::default(),
359 },
360 Message {
361 role: Role::User,
362 content: user_content,
363 parts: vec![],
364 metadata: MessageMetadata::default(),
365 },
366 ];
367
368 let timeout = Duration::from_secs(15);
369
370 let result = tokio::time::timeout(timeout, provider.chat(&messages))
371 .await
372 .map_err(|_| MemoryError::Other("LLM merge timed out after 15s".into()))?
373 .map_err(MemoryError::Llm)?;
374
375 Ok(result)
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture candidate whose id and content are derived from `id`.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    /// Two identical vectors share a cluster; an orthogonal one stands alone.
    #[test]
    fn cluster_by_similarity_groups_identical() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];

        let candidates = vec![
            (make_candidate(1), x_axis.to_vec()),
            (make_candidate(2), x_axis.to_vec()),
            (make_candidate(3), y_axis.to_vec()),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);
        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    /// Candidates without embeddings are never grouped together.
    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates: Vec<(PromotionCandidate, Vec<f32>)> = (1..=2)
            .map(|id| (make_candidate(id), Vec::new()))
            .collect();
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
    }
}