Skip to main content

zeph_memory/
tiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! AOI three-layer memory tier promotion.
5//!
6//! Provides a background sweep loop that promotes frequently-accessed episodic messages
7//! to the semantic tier by:
8//! 1. Finding candidates with `session_count >= promotion_min_sessions`.
9//! 2. Grouping near-duplicate candidates by cosine similarity (greedy nearest-neighbor).
10//! 3. For each cluster with >= 2 messages, calling the LLM to distill a merged fact.
11//! 4. Validating the merge output (non-empty, similarity >= 0.7 to at least one original).
12//! 5. Inserting the semantic fact and soft-deleting the originals.
13//!
14//! The sweep respects a `CancellationToken` for graceful shutdown, following the
15//! same pattern as `eviction.rs`.
16
17use std::sync::Arc;
18use std::time::Duration;
19
20use tokio_util::sync::CancellationToken;
21use tracing::Instrument as _;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::store::SqliteStore;
27use crate::store::messages::PromotionCandidate;
28use crate::types::ConversationId;
29use zeph_common::math::cosine_similarity;
30
31/// Minimum cosine similarity between the merged result and at least one original for the
32/// merge to be accepted. Prevents the LLM from producing semantically unrelated output.
33const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Settings that drive the tier promotion sweep.
///
/// Mirrors `zeph-config::TierPromotionConfig`; it is re-declared here so that
/// `zeph-memory` does not need a direct dependency on `zeph-config`.
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// Master switch: when `false` the promotion loop returns immediately.
    pub enabled: bool,
    /// A message becomes a promotion candidate once it has been seen in at
    /// least this many distinct sessions.
    pub promotion_min_sessions: u32,
    /// Cosine similarity at or above which two messages are treated as
    /// duplicates and merged into one semantic fact.
    pub similarity_threshold: f32,
    /// Pause between two promotion sweeps, in seconds.
    pub sweep_interval_secs: u64,
    /// Upper bound on the number of candidates fetched per sweep.
    pub sweep_batch_size: usize,
}
53
54/// Start the background tier promotion loop.
55///
56/// Each sweep cycle:
57/// 1. Fetches episodic candidates with `session_count >= config.promotion_min_sessions`.
58/// 2. Embeds candidates and clusters near-duplicates (cosine similarity >= threshold).
59/// 3. For each cluster, calls the LLM to merge into a single semantic fact.
60/// 4. Validates the merged output; skips the cluster on failure.
61/// 5. Promotes validated clusters to semantic tier.
62///
63/// The loop exits immediately if `config.enabled = false`.
64///
65/// Database and LLM errors are logged but do not stop the loop.
66pub async fn start_tier_promotion_loop(
67    store: Arc<SqliteStore>,
68    provider: AnyProvider,
69    config: TierPromotionConfig,
70    cancel: CancellationToken,
71) {
72    if !config.enabled {
73        tracing::debug!("tier promotion disabled (tiers.enabled = false)");
74        return;
75    }
76
77    let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
78    // Skip the first immediate tick so we don't run at startup.
79    ticker.tick().await;
80
81    loop {
82        tokio::select! {
83            () = cancel.cancelled() => {
84                tracing::debug!("tier promotion loop shutting down");
85                return;
86            }
87            _ = ticker.tick() => {}
88        }
89
90        tracing::debug!("tier promotion: starting sweep");
91        let start = std::time::Instant::now();
92
93        let result = run_promotion_sweep(&store, &provider, &config).await;
94
95        let elapsed_ms = start.elapsed().as_millis();
96
97        match result {
98            Ok(stats) => {
99                tracing::info!(
100                    candidates = stats.candidates_evaluated,
101                    clusters = stats.clusters_formed,
102                    promoted = stats.promotions_completed,
103                    merge_failures = stats.merge_failures,
104                    elapsed_ms,
105                    "tier promotion: sweep complete"
106                );
107            }
108            Err(e) => {
109                tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
110            }
111        }
112    }
113}
114
/// Counters gathered over one promotion sweep and reported in the sweep log line.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates returned by the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members (singletons are not counted).
    clusters_formed: usize,
    /// Number of clusters that were merged and promoted successfully.
    promotions_completed: usize,
    /// Number of clusters whose merge or promotion returned an error.
    merge_failures: usize,
}
123
124/// Execute one full promotion sweep cycle.
125#[tracing::instrument(name = "memory.tiers.promotion_sweep", skip_all)]
126async fn run_promotion_sweep(
127    store: &SqliteStore,
128    provider: &AnyProvider,
129    config: &TierPromotionConfig,
130) -> Result<SweepStats, MemoryError> {
131    let candidates = store
132        .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
133        .await?;
134
135    if candidates.is_empty() {
136        return Ok(SweepStats::default());
137    }
138
139    let mut stats = SweepStats {
140        candidates_evaluated: candidates.len(),
141        ..SweepStats::default()
142    };
143
144    // Embed all candidates in a single batch call, then zip back with candidates.
145    let embedded: Vec<(PromotionCandidate, Vec<f32>)> = if provider.supports_embeddings() {
146        let texts: Vec<&str> = candidates.iter().map(|c| c.content.as_str()).collect();
147        let span = tracing::info_span!("memory.tiers.embed_batch", count = texts.len());
148        let vecs = provider.embed_batch(&texts).instrument(span).await;
149        match vecs {
150            Ok(vecs) => {
151                if vecs.len() != texts.len() {
152                    tracing::warn!(
153                        expected = texts.len(),
154                        got = vecs.len(),
155                        "tier promotion: embed_batch length mismatch, skipping sweep"
156                    );
157                    return Ok(stats);
158                }
159                candidates.into_iter().zip(vecs).collect()
160            }
161            Err(e) => {
162                tracing::warn!(error = %e, "tier promotion: batch embed failed, skipping sweep");
163                return Ok(stats);
164            }
165        }
166    } else {
167        // No embedding support — push all with empty vecs (will become singletons).
168        candidates.into_iter().map(|c| (c, Vec::new())).collect()
169    };
170
171    if embedded.is_empty() {
172        return Ok(stats);
173    }
174
175    // Cluster candidates by cosine similarity (greedy nearest-neighbor).
176    // Each candidate is assigned to the first existing cluster whose centroid
177    // representative has similarity >= threshold with it, or starts a new cluster.
178    let threshold = config.similarity_threshold;
179    let clusters = cluster_by_similarity(embedded, threshold);
180
181    for cluster in clusters {
182        if cluster.len() < 2 {
183            // Single-member cluster — no merge needed, skip to avoid unnecessary LLM calls.
184            tracing::debug!(
185                cluster_size = cluster.len(),
186                "tier promotion: singleton cluster skipped"
187            );
188            continue;
189        }
190
191        stats.clusters_formed += 1;
192
193        let source_conv_id = cluster[0].0.conversation_id;
194
195        match merge_cluster_and_promote(store, provider, &cluster, source_conv_id).await {
196            Ok(()) => stats.promotions_completed += 1,
197            Err(e) => {
198                tracing::warn!(
199                    cluster_size = cluster.len(),
200                    error = %e,
201                    "tier promotion: cluster merge failed, skipping"
202                );
203                stats.merge_failures += 1;
204            }
205        }
206    }
207
208    Ok(stats)
209}
210
211/// Cluster candidates by cosine similarity using greedy nearest-neighbor.
212///
213/// Each candidate is compared to the representative (first member) of existing clusters.
214/// If similarity >= threshold, it joins that cluster; otherwise it starts a new one.
215/// This is O(n * k) where k is the number of clusters formed, not O(n^2).
216fn cluster_by_similarity(
217    candidates: Vec<(PromotionCandidate, Vec<f32>)>,
218    threshold: f32,
219) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
220    let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
221
222    'outer: for candidate in candidates {
223        if candidate.1.is_empty() {
224            // No embedding — own cluster (will be skipped as singleton).
225            clusters.push(vec![candidate]);
226            continue;
227        }
228
229        for cluster in &mut clusters {
230            let rep = &cluster[0].1;
231            if rep.is_empty() {
232                continue;
233            }
234            let sim = cosine_similarity(&candidate.1, rep);
235            if sim >= threshold {
236                cluster.push(candidate);
237                continue 'outer;
238            }
239        }
240
241        clusters.push(vec![candidate]);
242    }
243
244    clusters
245}
246
247/// Call the LLM to merge a cluster and promote the result to semantic tier.
248///
249/// Validates the merged output before promoting. If the output is empty or has
250/// a cosine similarity below `MERGE_VALIDATION_MIN_SIMILARITY` to all originals,
251/// returns an error without promoting.
252#[tracing::instrument(name = "memory.tiers.merge_cluster_and_promote", skip_all)]
253async fn merge_cluster_and_promote(
254    store: &SqliteStore,
255    provider: &AnyProvider,
256    cluster: &[(PromotionCandidate, Vec<f32>)],
257    conversation_id: ConversationId,
258) -> Result<(), MemoryError> {
259    let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
260    let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
261
262    let merged = call_merge_llm(provider, &contents).await?;
263
264    // Validate: non-empty result required.
265    let merged = merged.trim().to_owned();
266    if merged.is_empty() {
267        return Err(MemoryError::InvalidInput(
268            "LLM merge returned empty result".into(),
269        ));
270    }
271
272    // Validate: merged result must be semantically related to at least one original.
273    // Embed the merged result and compare against original embeddings.
274    if provider.supports_embeddings() {
275        let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
276        if embeddings_available {
277            match tokio::time::timeout(Duration::from_secs(5), provider.embed(&merged)).await {
278                Ok(Ok(merged_vec)) => {
279                    let max_sim = cluster
280                        .iter()
281                        .filter(|(_, emb)| !emb.is_empty())
282                        .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
283                        .fold(f32::NEG_INFINITY, f32::max);
284
285                    if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
286                        return Err(MemoryError::InvalidInput(format!(
287                            "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
288                        )));
289                    }
290                }
291                Ok(Err(e)) => {
292                    tracing::warn!(
293                        error = %e,
294                        "tier promotion: failed to embed merged result, skipping similarity validation"
295                    );
296                }
297                Err(_) => {
298                    tracing::warn!(
299                        "tier promotion: embed timed out, skipping similarity validation"
300                    );
301                }
302            }
303        }
304    }
305
306    // Retry the DB write up to 3 times with exponential backoff on SQLITE_BUSY.
307    // The LLM merge above is not retried — only the cheap DB write is.
308    let delays_ms = [50u64, 100, 200];
309    for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
310        match store
311            .promote_to_semantic(conversation_id, &merged, &original_ids)
312            .await
313        {
314            Ok(_) => break,
315            Err(e) => {
316                // Detect SQLITE_BUSY via the sqlx::Error::Database error code ("5") when
317                // available; fall back to string matching. String matching is safe here because
318                // the error originates from SQLite internals, not user input. The fallback
319                // handles wrapping layers where downcasting would add disproportionate complexity.
320                let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
321                    db_err.code().as_deref() == Some("5")
322                } else {
323                    e.to_string().contains("database is locked")
324                };
325                if is_busy && attempt < delays_ms.len() - 1 {
326                    tracing::warn!(
327                        attempt = attempt + 1,
328                        delay_ms,
329                        "tier promotion: SQLite busy, retrying"
330                    );
331                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
332                } else {
333                    return Err(e);
334                }
335            }
336        }
337    }
338    tracing::debug!(
339        cluster_size = cluster.len(),
340        merged_len = merged.len(),
341        "tier promotion: cluster promoted to semantic"
342    );
343
344    Ok(())
345}
346
347/// Call the LLM to distill a set of episodic messages into a single semantic fact.
348async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
349    use zeph_llm::provider::{Message, MessageMetadata, Role};
350
351    let bullet_list: String = contents
352        .iter()
353        .enumerate()
354        .map(|(i, c)| format!("{}. {c}", i + 1))
355        .collect::<Vec<_>>()
356        .join("\n");
357
358    let system_content = "You are a memory consolidation agent. \
359        Merge the following episodic memories into a single concise semantic fact. \
360        Strip timestamps, session context, hedging, and filler. \
361        Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
362        Do not add prefixes like 'The user...' or 'Fact:'.";
363
364    let user_content =
365        format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
366
367    let messages = vec![
368        Message {
369            role: Role::System,
370            content: system_content.to_owned(),
371            parts: vec![],
372            metadata: MessageMetadata::default(),
373        },
374        Message {
375            role: Role::User,
376            content: user_content,
377            parts: vec![],
378            metadata: MessageMetadata::default(),
379        },
380    ];
381
382    let timeout = Duration::from_secs(15);
383
384    let result = tokio::time::timeout(timeout, provider.chat(&messages))
385        .await
386        .map_err(|_| MemoryError::Timeout("LLM merge timed out after 15s".into()))?
387        .map_err(MemoryError::Llm)?;
388
389    Ok(result)
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cluster_by_similarity_groups_identical() {
        // Two identical unit vectors should cluster together at any threshold <= 1.0.
        let v1 = vec![1.0f32, 0.0, 0.0];
        let v2 = vec![1.0f32, 0.0, 0.0];
        let v3 = vec![0.0f32, 1.0, 0.0]; // orthogonal

        let candidates = vec![
            (make_candidate(1), v1),
            (make_candidate(2), v2),
            (make_candidate(3), v3),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);
        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates = vec![(make_candidate(1), vec![]), (make_candidate(2), vec![])];
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
    }

    /// A cluster whose representative has no embedding must never absorb later
    /// candidates; embedded candidates still group among themselves.
    #[test]
    fn cluster_by_similarity_skips_clusters_without_embeddings() {
        let v = vec![1.0f32, 0.0, 0.0];
        let candidates = vec![
            (make_candidate(1), vec![]),
            (make_candidate(2), v.clone()),
            (make_candidate(3), v),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2, "un-embedded candidate stays alone");
        assert_eq!(clusters[0].len(), 1, "first cluster is the un-embedded one");
        assert_eq!(clusters[1].len(), 2, "identical vectors cluster together");
    }

    /// Build a candidate with the given message id and fixed filler values.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    /// `embed()` failure during merge validation → fail-open: merge proceeds without rejecting.
    ///
    /// `merge_cluster_and_promote` must return `Ok(())` when the `embed` call for similarity
    /// validation returns a provider error (the `Ok(Err(e))` arm in the production code).
    /// The embed timeout is handled by a separate `Err(_)` arm that is also fail-open, but
    /// that arm is not exercised by this test.
    #[tokio::test]
    async fn merge_validation_embed_failure_is_fail_open() {
        let store = crate::store::SqliteStore::new(":memory:").await.unwrap();
        let conv_id = store.create_conversation().await.unwrap();
        let m1 = store
            .save_message(conv_id, "user", "Alice uses Rust")
            .await
            .unwrap();
        let m2 = store
            .save_message(conv_id, "user", "Alice loves Rust")
            .await
            .unwrap();

        // Provider: instant LLM chat reply + embed always errors (InvalidInput).
        // Simulates a non-timeout embed failure only.
        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::with_responses(vec!["Alice uses and loves Rust".into()])
                .with_embed_invalid_input(),
        );

        let cluster = vec![
            (make_candidate(m1.0), vec![1.0_f32, 0.0, 0.0]),
            (make_candidate(m2.0), vec![1.0_f32, 0.0, 0.0]),
        ];

        let result = merge_cluster_and_promote(&store, &provider, &cluster, conv_id).await;
        assert!(
            result.is_ok(),
            "embed failure during merge validation must be fail-open (Ok), got {result:?}"
        );
    }
}