1use std::sync::Arc;
18use std::time::Duration;
19
20use tokio::task::JoinHandle;
21use tokio_util::sync::CancellationToken;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::math::cosine_similarity;
27use crate::sqlite::SqliteStore;
28use crate::sqlite::messages::PromotionCandidate;
29use crate::types::ConversationId;
30
/// Minimum cosine similarity that the embedding of an LLM-merged fact must reach against
/// at least one original cluster member; below this the merge is rejected.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Settings for the background tier promotion loop.
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// When `false`, the promotion task logs a debug message and exits without sweeping.
    pub enabled: bool,
    /// Session-count threshold handed to the store's promotion-candidate query.
    pub promotion_min_sessions: u32,
    /// Cosine similarity at or above which a candidate joins an existing cluster.
    pub similarity_threshold: f32,
    /// Period of the sweep timer, in seconds.
    pub sweep_interval_secs: u64,
    /// Batch-size argument handed to the store's promotion-candidate query.
    pub sweep_batch_size: usize,
}
46
47pub fn start_tier_promotion_loop(
60 store: Arc<SqliteStore>,
61 provider: AnyProvider,
62 config: TierPromotionConfig,
63 cancel: CancellationToken,
64) -> JoinHandle<()> {
65 tokio::spawn(async move {
66 if !config.enabled {
67 tracing::debug!("tier promotion disabled (tiers.enabled = false)");
68 return;
69 }
70
71 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
72 ticker.tick().await;
74
75 loop {
76 tokio::select! {
77 () = cancel.cancelled() => {
78 tracing::debug!("tier promotion loop shutting down");
79 return;
80 }
81 _ = ticker.tick() => {}
82 }
83
84 tracing::debug!("tier promotion: starting sweep");
85 let start = std::time::Instant::now();
86
87 let result = run_promotion_sweep(&store, &provider, &config).await;
88
89 let elapsed_ms = start.elapsed().as_millis();
90
91 match result {
92 Ok(stats) => {
93 tracing::info!(
94 candidates = stats.candidates_evaluated,
95 clusters = stats.clusters_formed,
96 promoted = stats.promotions_completed,
97 merge_failures = stats.merge_failures,
98 elapsed_ms,
99 "tier promotion: sweep complete"
100 );
101 }
102 Err(e) => {
103 tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
104 }
105 }
106 }
107 })
108}
109
/// Counters collected during a single promotion sweep and reported in the sweep log line.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates returned by the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members (singletons are not counted).
    clusters_formed: usize,
    /// Number of clusters that were merged and promoted successfully.
    promotions_completed: usize,
    /// Number of clusters whose merge or promotion returned an error.
    merge_failures: usize,
}
118
119async fn run_promotion_sweep(
121 store: &SqliteStore,
122 provider: &AnyProvider,
123 config: &TierPromotionConfig,
124) -> Result<SweepStats, MemoryError> {
125 let candidates = store
126 .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
127 .await?;
128
129 if candidates.is_empty() {
130 return Ok(SweepStats::default());
131 }
132
133 let mut stats = SweepStats {
134 candidates_evaluated: candidates.len(),
135 ..SweepStats::default()
136 };
137
138 let mut embedded: Vec<(PromotionCandidate, Vec<f32>)> = Vec::with_capacity(candidates.len());
140 for candidate in candidates {
141 if !provider.supports_embeddings() {
142 embedded.push((candidate, Vec::new()));
144 continue;
145 }
146 match provider.embed(&candidate.content).await {
147 Ok(vec) => embedded.push((candidate, vec)),
148 Err(e) => {
149 tracing::warn!(
150 message_id = candidate.id.0,
151 error = %e,
152 "tier promotion: failed to embed candidate, skipping"
153 );
154 }
155 }
156 }
157
158 if embedded.is_empty() {
159 return Ok(stats);
160 }
161
162 let threshold = config.similarity_threshold;
166 let clusters = cluster_by_similarity(embedded, threshold);
167
168 for cluster in clusters {
169 if cluster.len() < 2 {
170 tracing::debug!(
172 cluster_size = cluster.len(),
173 "tier promotion: singleton cluster skipped"
174 );
175 continue;
176 }
177
178 stats.clusters_formed += 1;
179
180 let source_conv_id = cluster[0].0.conversation_id;
181
182 match merge_cluster_and_promote(store, provider, &cluster, source_conv_id).await {
183 Ok(()) => stats.promotions_completed += 1,
184 Err(e) => {
185 tracing::warn!(
186 cluster_size = cluster.len(),
187 error = %e,
188 "tier promotion: cluster merge failed, skipping"
189 );
190 stats.merge_failures += 1;
191 }
192 }
193 }
194
195 Ok(stats)
196}
197
198fn cluster_by_similarity(
204 candidates: Vec<(PromotionCandidate, Vec<f32>)>,
205 threshold: f32,
206) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
207 let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
208
209 'outer: for candidate in candidates {
210 if candidate.1.is_empty() {
211 clusters.push(vec![candidate]);
213 continue;
214 }
215
216 for cluster in &mut clusters {
217 let rep = &cluster[0].1;
218 if rep.is_empty() {
219 continue;
220 }
221 let sim = cosine_similarity(&candidate.1, rep);
222 if sim >= threshold {
223 cluster.push(candidate);
224 continue 'outer;
225 }
226 }
227
228 clusters.push(vec![candidate]);
229 }
230
231 clusters
232}
233
234async fn merge_cluster_and_promote(
240 store: &SqliteStore,
241 provider: &AnyProvider,
242 cluster: &[(PromotionCandidate, Vec<f32>)],
243 conversation_id: ConversationId,
244) -> Result<(), MemoryError> {
245 let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
246 let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
247
248 let merged = call_merge_llm(provider, &contents).await?;
249
250 let merged = merged.trim().to_owned();
252 if merged.is_empty() {
253 return Err(MemoryError::Other("LLM merge returned empty result".into()));
254 }
255
256 if provider.supports_embeddings() {
259 let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
260 if embeddings_available {
261 match provider.embed(&merged).await {
262 Ok(merged_vec) => {
263 let max_sim = cluster
264 .iter()
265 .filter(|(_, emb)| !emb.is_empty())
266 .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
267 .fold(f32::NEG_INFINITY, f32::max);
268
269 if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
270 return Err(MemoryError::Other(format!(
271 "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
272 )));
273 }
274 }
275 Err(e) => {
276 tracing::warn!(
277 error = %e,
278 "tier promotion: failed to embed merged result, skipping similarity validation"
279 );
280 }
281 }
282 }
283 }
284
285 store
286 .promote_to_semantic(conversation_id, &merged, &original_ids)
287 .await?;
288
289 tracing::debug!(
290 cluster_size = cluster.len(),
291 merged_len = merged.len(),
292 "tier promotion: cluster promoted to semantic"
293 );
294
295 Ok(())
296}
297
298async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
300 use zeph_llm::provider::{Message, MessageMetadata, Role};
301
302 let bullet_list: String = contents
303 .iter()
304 .enumerate()
305 .map(|(i, c)| format!("{}. {c}", i + 1))
306 .collect::<Vec<_>>()
307 .join("\n");
308
309 let system_content = "You are a memory consolidation agent. \
310 Merge the following episodic memories into a single concise semantic fact. \
311 Strip timestamps, session context, hedging, and filler. \
312 Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
313 Do not add prefixes like 'The user...' or 'Fact:'.";
314
315 let user_content =
316 format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
317
318 let messages = vec![
319 Message {
320 role: Role::System,
321 content: system_content.to_owned(),
322 parts: vec![],
323 metadata: MessageMetadata::default(),
324 },
325 Message {
326 role: Role::User,
327 content: user_content,
328 parts: vec![],
329 metadata: MessageMetadata::default(),
330 },
331 ];
332
333 let timeout = Duration::from_secs(15);
334
335 let result = tokio::time::timeout(timeout, provider.chat(&messages))
336 .await
337 .map_err(|_| MemoryError::Other("LLM merge timed out after 15s".into()))?
338 .map_err(MemoryError::Llm)?;
339
340 Ok(result)
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal candidate whose id and content are derived from `id`.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    /// Two identical vectors share a cluster; an orthogonal one gets its own.
    #[test]
    fn cluster_by_similarity_groups_identical() {
        let along_x = vec![1.0f32, 0.0, 0.0];
        let along_x_again = vec![1.0f32, 0.0, 0.0];
        let along_y = vec![0.0f32, 1.0, 0.0];

        let candidates = vec![
            (make_candidate(1), along_x),
            (make_candidate(2), along_x_again),
            (make_candidate(3), along_y),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);

        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    /// Candidates without embeddings are never grouped together.
    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates = vec![
            (make_candidate(1), Vec::new()),
            (make_candidate(2), Vec::new()),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92);

        assert_eq!(clusters.len(), 2);
    }
}