Skip to main content

zeph_memory/
tiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! AOI three-layer memory tier promotion.
5//!
6//! Provides a background sweep loop that promotes frequently-accessed episodic messages
7//! to the semantic tier by:
8//! 1. Finding candidates with `session_count >= promotion_min_sessions`.
9//! 2. Grouping near-duplicate candidates by cosine similarity (greedy nearest-neighbor).
10//! 3. For each cluster with >= 2 messages, calling the LLM to distill a merged fact.
11//! 4. Validating the merge output (non-empty, similarity >= 0.7 to at least one original).
12//! 5. Inserting the semantic fact and soft-deleting the originals.
13//!
14//! The sweep respects a `CancellationToken` for graceful shutdown, following the
15//! same pattern as `eviction.rs`.
16
17use std::sync::Arc;
18use std::time::Duration;
19
20use tokio_util::sync::CancellationToken;
21use tracing::Instrument as _;
22use zeph_llm::any::AnyProvider;
23use zeph_llm::provider::LlmProvider as _;
24
25use crate::error::MemoryError;
26use crate::store::SqliteStore;
27use crate::store::messages::PromotionCandidate;
28use crate::types::ConversationId;
29use zeph_common::math::cosine_similarity;
30
/// Minimum cosine similarity between the merged result and at least one original for the
/// merge to be accepted. Prevents the LLM from producing semantically unrelated output.
///
/// The comparison is `max_sim < MERGE_VALIDATION_MIN_SIMILARITY` → reject. The check is
/// fail-open: it only runs when the provider supports embeddings, at least one cluster
/// member has an embedding, and embedding the merged text succeeds within the timeout.
const MERGE_VALIDATION_MIN_SIMILARITY: f32 = 0.7;
34
/// Configuration for the tier promotion sweep, passed from `zeph-config::TierPromotionConfig`.
///
/// Defined locally to avoid a direct dependency from `zeph-memory` on `zeph-config`.
#[derive(Debug, Clone)]
pub struct TierPromotionConfig {
    /// Enable or disable the tier promotion loop. When `false`,
    /// `start_tier_promotion_loop` returns immediately without running any sweep.
    pub enabled: bool,
    /// Minimum number of distinct sessions in which a message must appear
    /// before it becomes a promotion candidate.
    pub promotion_min_sessions: u32,
    /// Minimum cosine similarity for two messages to be considered duplicates
    /// eligible for merging into one semantic fact. The comparison is inclusive (`>=`)
    /// and is made against the first member of each cluster, not a centroid.
    pub similarity_threshold: f32,
    /// How often to run a promotion sweep, in seconds. No sweep runs at startup;
    /// the first one happens after one full interval.
    pub sweep_interval_secs: u64,
    /// Maximum number of candidates to process per sweep (forwarded to
    /// `find_promotion_candidates`).
    pub sweep_batch_size: usize,
    /// Per-call timeout for every `embed()` invocation, in seconds. Default: `5`.
    pub embed_timeout_secs: u64,
}
55
56/// Start the background tier promotion loop.
57///
58/// Each sweep cycle:
59/// 1. Fetches episodic candidates with `session_count >= config.promotion_min_sessions`.
60/// 2. Embeds candidates and clusters near-duplicates (cosine similarity >= threshold).
61/// 3. For each cluster, calls the LLM to merge into a single semantic fact.
62/// 4. Validates the merged output; skips the cluster on failure.
63/// 5. Promotes validated clusters to semantic tier.
64///
65/// The loop exits immediately if `config.enabled = false`.
66///
67/// Database and LLM errors are logged but do not stop the loop.
68pub async fn start_tier_promotion_loop(
69    store: Arc<SqliteStore>,
70    provider: AnyProvider,
71    config: TierPromotionConfig,
72    cancel: CancellationToken,
73) {
74    if !config.enabled {
75        tracing::debug!("tier promotion disabled (tiers.enabled = false)");
76        return;
77    }
78
79    let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
80    // Skip the first immediate tick so we don't run at startup.
81    ticker.tick().await;
82
83    loop {
84        tokio::select! {
85            () = cancel.cancelled() => {
86                tracing::debug!("tier promotion loop shutting down");
87                return;
88            }
89            _ = ticker.tick() => {}
90        }
91
92        tracing::debug!("tier promotion: starting sweep");
93        let start = std::time::Instant::now();
94
95        let result = run_promotion_sweep(&store, &provider, &config).await;
96
97        let elapsed_ms = start.elapsed().as_millis();
98
99        match result {
100            Ok(stats) => {
101                tracing::info!(
102                    candidates = stats.candidates_evaluated,
103                    clusters = stats.clusters_formed,
104                    promoted = stats.promotions_completed,
105                    merge_failures = stats.merge_failures,
106                    elapsed_ms,
107                    "tier promotion: sweep complete"
108                );
109            }
110            Err(e) => {
111                tracing::warn!(error = %e, elapsed_ms, "tier promotion: sweep failed, will retry");
112            }
113        }
114    }
115}
116
/// Stats collected during a single promotion sweep.
///
/// Returned by `run_promotion_sweep` and logged by `start_tier_promotion_loop`.
#[derive(Debug, Default)]
struct SweepStats {
    /// Number of candidates fetched from the store for this sweep.
    candidates_evaluated: usize,
    /// Number of clusters with at least two members (singletons are not counted).
    clusters_formed: usize,
    /// Clusters that were merged and written to the semantic tier.
    promotions_completed: usize,
    /// Clusters for which `merge_cluster_and_promote` returned an error (LLM call,
    /// validation, or database write).
    merge_failures: usize,
}
125
126/// Execute one full promotion sweep cycle.
127#[tracing::instrument(name = "memory.tiers.promotion_sweep", skip_all)]
128async fn run_promotion_sweep(
129    store: &SqliteStore,
130    provider: &AnyProvider,
131    config: &TierPromotionConfig,
132) -> Result<SweepStats, MemoryError> {
133    let candidates = store
134        .find_promotion_candidates(config.promotion_min_sessions, config.sweep_batch_size)
135        .await?;
136
137    if candidates.is_empty() {
138        return Ok(SweepStats::default());
139    }
140
141    let mut stats = SweepStats {
142        candidates_evaluated: candidates.len(),
143        ..SweepStats::default()
144    };
145
146    // Embed all candidates in a single batch call, then zip back with candidates.
147    let embedded: Vec<(PromotionCandidate, Vec<f32>)> = if provider.supports_embeddings() {
148        let texts: Vec<&str> = candidates.iter().map(|c| c.content.as_str()).collect();
149        let span = tracing::info_span!("memory.tiers.embed_batch", count = texts.len());
150        let vecs = provider.embed_batch(&texts).instrument(span).await;
151        match vecs {
152            Ok(vecs) => {
153                if vecs.len() != texts.len() {
154                    tracing::warn!(
155                        expected = texts.len(),
156                        got = vecs.len(),
157                        "tier promotion: embed_batch length mismatch, skipping sweep"
158                    );
159                    return Ok(stats);
160                }
161                candidates.into_iter().zip(vecs).collect()
162            }
163            Err(e) => {
164                tracing::warn!(error = %e, "tier promotion: batch embed failed, skipping sweep");
165                return Ok(stats);
166            }
167        }
168    } else {
169        // No embedding support — push all with empty vecs (will become singletons).
170        candidates.into_iter().map(|c| (c, Vec::new())).collect()
171    };
172
173    if embedded.is_empty() {
174        return Ok(stats);
175    }
176
177    // Cluster candidates by cosine similarity (greedy nearest-neighbor).
178    // Each candidate is assigned to the first existing cluster whose centroid
179    // representative has similarity >= threshold with it, or starts a new cluster.
180    let threshold = config.similarity_threshold;
181    let clusters = cluster_by_similarity(embedded, threshold);
182
183    for cluster in clusters {
184        if cluster.len() < 2 {
185            // Single-member cluster — no merge needed, skip to avoid unnecessary LLM calls.
186            tracing::debug!(
187                cluster_size = cluster.len(),
188                "tier promotion: singleton cluster skipped"
189            );
190            continue;
191        }
192
193        stats.clusters_formed += 1;
194
195        let source_conv_id = cluster[0].0.conversation_id;
196
197        match merge_cluster_and_promote(
198            store,
199            provider,
200            &cluster,
201            source_conv_id,
202            Duration::from_secs(config.embed_timeout_secs),
203        )
204        .await
205        {
206            Ok(()) => stats.promotions_completed += 1,
207            Err(e) => {
208                tracing::warn!(
209                    cluster_size = cluster.len(),
210                    error = %e,
211                    "tier promotion: cluster merge failed, skipping"
212                );
213                stats.merge_failures += 1;
214            }
215        }
216    }
217
218    Ok(stats)
219}
220
221/// Cluster candidates by cosine similarity using greedy nearest-neighbor.
222///
223/// Each candidate is compared to the representative (first member) of existing clusters.
224/// If similarity >= threshold, it joins that cluster; otherwise it starts a new one.
225/// This is O(n * k) where k is the number of clusters formed, not O(n^2).
226fn cluster_by_similarity(
227    candidates: Vec<(PromotionCandidate, Vec<f32>)>,
228    threshold: f32,
229) -> Vec<Vec<(PromotionCandidate, Vec<f32>)>> {
230    let mut clusters: Vec<Vec<(PromotionCandidate, Vec<f32>)>> = Vec::new();
231
232    'outer: for candidate in candidates {
233        if candidate.1.is_empty() {
234            // No embedding — own cluster (will be skipped as singleton).
235            clusters.push(vec![candidate]);
236            continue;
237        }
238
239        for cluster in &mut clusters {
240            let rep = &cluster[0].1;
241            if rep.is_empty() {
242                continue;
243            }
244            let sim = cosine_similarity(&candidate.1, rep);
245            if sim >= threshold {
246                cluster.push(candidate);
247                continue 'outer;
248            }
249        }
250
251        clusters.push(vec![candidate]);
252    }
253
254    clusters
255}
256
257/// Call the LLM to merge a cluster and promote the result to semantic tier.
258///
259/// Validates the merged output before promoting. If the output is empty or has
260/// a cosine similarity below `MERGE_VALIDATION_MIN_SIMILARITY` to all originals,
261/// returns an error without promoting.
262#[tracing::instrument(name = "memory.tiers.merge_cluster_and_promote", skip_all)]
263async fn merge_cluster_and_promote(
264    store: &SqliteStore,
265    provider: &AnyProvider,
266    cluster: &[(PromotionCandidate, Vec<f32>)],
267    conversation_id: ConversationId,
268    embed_timeout: Duration,
269) -> Result<(), MemoryError> {
270    let contents: Vec<&str> = cluster.iter().map(|(c, _)| c.content.as_str()).collect();
271    let original_ids: Vec<crate::types::MessageId> = cluster.iter().map(|(c, _)| c.id).collect();
272
273    let merged = call_merge_llm(provider, &contents).await?;
274
275    // Validate: non-empty result required.
276    let merged = merged.trim().to_owned();
277    if merged.is_empty() {
278        return Err(MemoryError::InvalidInput(
279            "LLM merge returned empty result".into(),
280        ));
281    }
282
283    // Validate: merged result must be semantically related to at least one original.
284    // Embed the merged result and compare against original embeddings.
285    if provider.supports_embeddings() {
286        let embeddings_available = cluster.iter().any(|(_, emb)| !emb.is_empty());
287        if embeddings_available {
288            match tokio::time::timeout(embed_timeout, provider.embed(&merged)).await {
289                Ok(Ok(merged_vec)) => {
290                    let max_sim = cluster
291                        .iter()
292                        .filter(|(_, emb)| !emb.is_empty())
293                        .map(|(_, emb)| cosine_similarity(&merged_vec, emb))
294                        .fold(f32::NEG_INFINITY, f32::max);
295
296                    if max_sim < MERGE_VALIDATION_MIN_SIMILARITY {
297                        return Err(MemoryError::InvalidInput(format!(
298                            "LLM merge validation failed: max similarity to originals = {max_sim:.3} < {MERGE_VALIDATION_MIN_SIMILARITY}"
299                        )));
300                    }
301                }
302                Ok(Err(e)) => {
303                    tracing::warn!(
304                        error = %e,
305                        "tier promotion: failed to embed merged result, skipping similarity validation"
306                    );
307                }
308                Err(_) => {
309                    tracing::warn!(
310                        "tier promotion: embed timed out, skipping similarity validation"
311                    );
312                }
313            }
314        }
315    }
316
317    // Retry the DB write up to 3 times with exponential backoff on SQLITE_BUSY.
318    // The LLM merge above is not retried — only the cheap DB write is.
319    let delays_ms = [50u64, 100, 200];
320    for (attempt, &delay_ms) in delays_ms.iter().enumerate() {
321        match store
322            .promote_to_semantic(conversation_id, &merged, &original_ids)
323            .await
324        {
325            Ok(_) => break,
326            Err(e) => {
327                // Detect SQLITE_BUSY via the sqlx::Error::Database error code ("5") when
328                // available; fall back to string matching. String matching is safe here because
329                // the error originates from SQLite internals, not user input. The fallback
330                // handles wrapping layers where downcasting would add disproportionate complexity.
331                let is_busy = if let MemoryError::Sqlx(sqlx::Error::Database(ref db_err)) = e {
332                    db_err.code().as_deref() == Some("5")
333                } else {
334                    e.to_string().contains("database is locked")
335                };
336                if is_busy && attempt < delays_ms.len() - 1 {
337                    tracing::warn!(
338                        attempt = attempt + 1,
339                        delay_ms,
340                        "tier promotion: SQLite busy, retrying"
341                    );
342                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
343                } else {
344                    return Err(e);
345                }
346            }
347        }
348    }
349    tracing::debug!(
350        cluster_size = cluster.len(),
351        merged_len = merged.len(),
352        "tier promotion: cluster promoted to semantic"
353    );
354
355    Ok(())
356}
357
358/// Call the LLM to distill a set of episodic messages into a single semantic fact.
359async fn call_merge_llm(provider: &AnyProvider, contents: &[&str]) -> Result<String, MemoryError> {
360    use zeph_llm::provider::{Message, MessageMetadata, Role};
361
362    let bullet_list: String = contents
363        .iter()
364        .enumerate()
365        .map(|(i, c)| format!("{}. {c}", i + 1))
366        .collect::<Vec<_>>()
367        .join("\n");
368
369    let system_content = "You are a memory consolidation agent. \
370        Merge the following episodic memories into a single concise semantic fact. \
371        Strip timestamps, session context, hedging, and filler. \
372        Output ONLY the distilled fact as a single plain-text sentence or short paragraph. \
373        Do not add prefixes like 'The user...' or 'Fact:'.";
374
375    let user_content =
376        format!("Merge these episodic memories into one semantic fact:\n\n{bullet_list}");
377
378    let messages = vec![
379        Message {
380            role: Role::System,
381            content: system_content.to_owned(),
382            parts: vec![],
383            metadata: MessageMetadata::default(),
384        },
385        Message {
386            role: Role::User,
387            content: user_content,
388            parts: vec![],
389            metadata: MessageMetadata::default(),
390        },
391    ];
392
393    let timeout = Duration::from_secs(15);
394
395    let result = tokio::time::timeout(timeout, provider.chat(&messages))
396        .await
397        .map_err(|_| MemoryError::Timeout("LLM merge timed out after 15s".into()))?
398        .map_err(MemoryError::Llm)?;
399
400    Ok(result)
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical vectors share a cluster; an orthogonal vector starts its own.
    #[test]
    fn cluster_by_similarity_groups_identical() {
        // Two identical unit vectors should cluster together at any threshold <= 1.0.
        let v1 = vec![1.0f32, 0.0, 0.0];
        let v2 = vec![1.0f32, 0.0, 0.0];
        let v3 = vec![0.0f32, 1.0, 0.0]; // orthogonal

        let candidates = vec![
            (make_candidate(1), v1),
            (make_candidate(2), v2),
            (make_candidate(3), v3),
        ];

        let clusters = cluster_by_similarity(candidates, 0.92f32);
        assert_eq!(clusters.len(), 2, "should produce 2 clusters");
        assert_eq!(clusters[0].len(), 2, "first cluster should have 2 members");
        assert_eq!(clusters[1].len(), 1, "second cluster is the orthogonal one");
    }

    /// Candidates without embeddings are never compared, so each becomes its own cluster.
    #[test]
    fn cluster_by_similarity_empty_embeddings_become_singletons() {
        let candidates = vec![(make_candidate(1), vec![]), (make_candidate(2), vec![])];
        let clusters = cluster_by_similarity(candidates, 0.92);
        assert_eq!(clusters.len(), 2);
    }

    /// Builds a test candidate with message id `id`. Every candidate shares
    /// `ConversationId(1)` and the same session count / importance score.
    fn make_candidate(id: i64) -> PromotionCandidate {
        PromotionCandidate {
            id: crate::types::MessageId(id),
            conversation_id: ConversationId(1),
            content: format!("content {id}"),
            session_count: 3,
            importance_score: 0.5,
        }
    }

    /// `embed()` failure during merge validation → fail-open: merge proceeds without rejecting.
    ///
    /// `merge_cluster_and_promote` must return `Ok(())` when the `embed` call for similarity
    /// validation returns a provider error (the `Ok(Err(e))` arm of the timeout match).
    /// The timeout path is a separate `Err(_)` arm. It is also fail-open, but it is NOT
    /// exercised by this test.
    ///
    /// NOTE(review): this only reaches the embed path if the mock reports
    /// `supports_embeddings() == true` — confirm, otherwise the test passes trivially.
    #[tokio::test]
    async fn merge_validation_embed_failure_is_fail_open() {
        let store = crate::store::SqliteStore::new(":memory:").await.unwrap();
        let conv_id = store.create_conversation().await.unwrap();
        let m1 = store
            .save_message(conv_id, "user", "Alice uses Rust")
            .await
            .unwrap();
        let m2 = store
            .save_message(conv_id, "user", "Alice loves Rust")
            .await
            .unwrap();

        // Provider: instant LLM chat reply + embed always errors (InvalidInput).
        // Exercises the provider-error arm only; the timeout arm is separate and untested here.
        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::with_responses(vec!["Alice uses and loves Rust".into()])
                .with_embed_invalid_input(),
        );

        // Candidate ids match the saved messages so the promotion write can find them.
        let cluster = vec![
            (make_candidate(m1.0), vec![1.0_f32, 0.0, 0.0]),
            (make_candidate(m2.0), vec![1.0_f32, 0.0, 0.0]),
        ];

        let result =
            merge_cluster_and_promote(&store, &provider, &cluster, conv_id, Duration::from_secs(5))
                .await;
        assert!(
            result.is_ok(),
            "embed failure during merge validation must be fail-open (Ok), got {result:?}"
        );
    }
}