Skip to main content

zeph_memory/
scenes.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `MemScene` consolidation (#2332).
5//!
6//! Groups semantically related semantic-tier messages into stable entity profiles (scenes).
7//! Runs as a separate background loop, decoupled from tier promotion timing.
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use tokio_util::sync::CancellationToken;
13use zeph_llm::any::AnyProvider;
14use zeph_llm::provider::LlmProvider as _;
15
16use crate::error::MemoryError;
17use crate::store::SqliteStore;
18use crate::types::{MemSceneId, MessageId};
19use zeph_common::math::cosine_similarity;
20
21/// A `MemScene` groups semantically related semantic-tier messages with a stable entity profile.
22///
23/// Scenes are created and updated by [`start_scene_consolidation_loop`] and can be
24/// listed via [`list_scenes`].
25#[derive(Debug, Clone)]
26pub struct MemScene {
27    /// `SQLite` row ID of the scene.
28    pub id: MemSceneId,
29    /// Short human-readable label for the scene (e.g. `"Rust programming"`).
30    pub label: String,
31    /// LLM-generated entity profile summarising the scene's members.
32    pub profile: String,
33    /// Number of messages currently assigned to this scene.
34    pub member_count: u32,
35    /// Unix timestamp when the scene was first created.
36    pub created_at: i64,
37    /// Unix timestamp of the last profile update.
38    pub updated_at: i64,
39}
40
/// Tunables for the background scene consolidation loop.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// Master switch: when `false`, the consolidation loop returns without sweeping.
    pub enabled: bool,
    /// Lowest cosine similarity at which a message joins an existing cluster,
    /// and therefore ends up in the same scene.
    pub similarity_threshold: f32,
    /// Upper bound on the number of unassigned messages fetched for one sweep.
    pub batch_size: usize,
    /// Period between two sweeps, in seconds.
    pub sweep_interval_secs: u64,
}
53
54/// Start the background scene consolidation loop.
55///
56/// Each sweep clusters unassigned semantic-tier messages into `MemScenes`.
57/// Runs independently from the tier promotion loop.
58pub async fn start_scene_consolidation_loop(
59    store: Arc<SqliteStore>,
60    provider: AnyProvider,
61    config: SceneConfig,
62    cancel: CancellationToken,
63) {
64    if !config.enabled {
65        tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
66        return;
67    }
68
69    let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
70    // Skip first tick to avoid running immediately at startup.
71    ticker.tick().await;
72
73    loop {
74        tokio::select! {
75            () = cancel.cancelled() => {
76                tracing::debug!("scene consolidation loop shutting down");
77                return;
78            }
79            _ = ticker.tick() => {}
80        }
81
82        tracing::debug!("scene consolidation: starting sweep");
83        let start = std::time::Instant::now();
84
85        match consolidate_scenes(&store, &provider, &config).await {
86            Ok(stats) => {
87                tracing::info!(
88                    candidates = stats.candidates,
89                    scenes_created = stats.scenes_created,
90                    messages_assigned = stats.messages_assigned,
91                    elapsed_ms = start.elapsed().as_millis(),
92                    "scene consolidation: sweep complete"
93                );
94            }
95            Err(e) => {
96                tracing::warn!(
97                    error = %e,
98                    elapsed_ms = start.elapsed().as_millis(),
99                    "scene consolidation: sweep failed, will retry"
100                );
101            }
102        }
103    }
104}
105
/// Counters gathered over one scene consolidation sweep.
#[derive(Default, Debug)]
pub struct SceneStats {
    /// Number of unassigned messages fetched as candidates for the sweep.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Total number of messages attached to the scenes inserted during the sweep.
    pub messages_assigned: usize,
}
113
114/// Execute one full scene consolidation sweep.
115///
116/// # Errors
117///
118/// Returns an error if the `SQLite` query fails. LLM and embedding errors are logged but skipped.
119#[cfg_attr(
120    feature = "profiling",
121    tracing::instrument(name = "memory.consolidate_scenes", skip_all)
122)]
123pub async fn consolidate_scenes(
124    store: &SqliteStore,
125    provider: &AnyProvider,
126    config: &SceneConfig,
127) -> Result<SceneStats, MemoryError> {
128    let candidates = store
129        .find_unscened_semantic_messages(config.batch_size)
130        .await?;
131
132    if candidates.len() < 2 {
133        return Ok(SceneStats::default());
134    }
135
136    let mut stats = SceneStats {
137        candidates: candidates.len(),
138        ..SceneStats::default()
139    };
140
141    // Embed all candidates.
142    let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
143    if provider.supports_embeddings() {
144        for (msg_id, content) in candidates {
145            match provider.embed(&content).await {
146                Ok(vec) => embedded.push((msg_id, content, vec)),
147                Err(e) => {
148                    tracing::warn!(
149                        message_id = msg_id.0,
150                        error = %e,
151                        "scene consolidation: failed to embed candidate, skipping"
152                    );
153                }
154            }
155        }
156    } else {
157        return Ok(stats);
158    }
159
160    if embedded.len() < 2 {
161        return Ok(stats);
162    }
163
164    // Cluster by cosine similarity.
165    let clusters = cluster_messages(embedded, config.similarity_threshold);
166
167    for cluster in clusters {
168        if cluster.len() < 2 {
169            continue;
170        }
171
172        let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
173        let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
174
175        match generate_scene_label_and_profile(provider, &contents).await {
176            Ok((label, profile)) => {
177                let label = label.chars().take(100).collect::<String>();
178                let profile = profile.chars().take(2000).collect::<String>();
179                match store.insert_mem_scene(&label, &profile, &msg_ids).await {
180                    Ok(_scene_id) => {
181                        stats.scenes_created += 1;
182                        stats.messages_assigned += msg_ids.len();
183                    }
184                    Err(e) => {
185                        tracing::warn!(
186                            error = %e,
187                            cluster_size = msg_ids.len(),
188                            "scene consolidation: failed to insert scene"
189                        );
190                    }
191                }
192            }
193            Err(e) => {
194                tracing::warn!(
195                    error = %e,
196                    cluster_size = msg_ids.len(),
197                    "scene consolidation: LLM label generation failed, skipping cluster"
198                );
199            }
200        }
201    }
202
203    Ok(stats)
204}
205
206fn cluster_messages(
207    candidates: Vec<(MessageId, String, Vec<f32>)>,
208    threshold: f32,
209) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
210    let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
211
212    'outer: for candidate in candidates {
213        for cluster in &mut clusters {
214            let rep = &cluster[0].2;
215            if cosine_similarity(&candidate.2, rep) >= threshold {
216                cluster.push(candidate);
217                continue 'outer;
218            }
219        }
220        clusters.push(vec![candidate]);
221    }
222
223    clusters
224}
225
226async fn generate_scene_label_and_profile(
227    provider: &AnyProvider,
228    contents: &[&str],
229) -> Result<(String, String), MemoryError> {
230    use zeph_llm::provider::{Message, MessageMetadata, Role};
231
232    let bullet_list: String = contents
233        .iter()
234        .enumerate()
235        .map(|(i, c)| format!("{}. {c}", i + 1))
236        .collect::<Vec<_>>()
237        .join("\n");
238
239    let system_content = "You are a memory scene architect. \
240        Given a set of related semantic facts, generate:\n\
241        1. A short label (5 words max) identifying the core entity or topic.\n\
242        2. A 2-3 sentence entity profile summarizing the key facts.\n\
243        Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
244
245    let user_content =
246        format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
247
248    let messages = vec![
249        Message {
250            role: Role::System,
251            content: system_content.to_owned(),
252            parts: vec![],
253            metadata: MessageMetadata::default(),
254        },
255        Message {
256            role: Role::User,
257            content: user_content,
258            parts: vec![],
259            metadata: MessageMetadata::default(),
260        },
261    ];
262
263    let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
264        .await
265        .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
266        .map_err(MemoryError::Llm)?;
267
268    parse_label_profile(&result)
269}
270
271fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
272    // Try JSON parsing first.
273    if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
274        let label = val
275            .get("label")
276            .and_then(|v| v.as_str())
277            .unwrap_or("")
278            .trim()
279            .to_owned();
280        let profile = val
281            .get("profile")
282            .and_then(|v| v.as_str())
283            .unwrap_or("")
284            .trim()
285            .to_owned();
286        if !label.is_empty() && !profile.is_empty() {
287            return Ok((label, profile));
288        }
289    }
290    // Fallback: treat first line as label, rest as profile.
291    let trimmed = response.trim();
292    let mut lines = trimmed.splitn(2, '\n');
293    let label = lines.next().unwrap_or("").trim().to_owned();
294    let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
295    if label.is_empty() {
296        return Err(MemoryError::Other("scene LLM returned empty label".into()));
297    }
298    let profile = if profile.is_empty() {
299        label.clone()
300    } else {
301        profile
302    };
303    Ok((label, profile))
304}
305
306/// List all `MemScenes` from the store.
307///
308/// # Errors
309///
310/// Returns an error if the `SQLite` query fails.
311pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
312    store.list_mem_scenes().await
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical vectors share a cluster; an orthogonal one gets its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = vec![1.0f32, 0.0, 0.0];
        let y_axis = vec![0.0f32, 1.0, 0.0];

        let candidates = vec![
            (MessageId(1), "a".to_owned(), x_axis.clone()),
            (MessageId(2), "b".to_owned(), x_axis),
            (MessageId(3), "c".to_owned(), y_axis),
        ];

        let sizes: Vec<usize> = cluster_messages(candidates, 0.80)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, vec![2, 1]);
    }

    /// A well-formed JSON reply yields its `label` and `profile` fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).unwrap();
        assert_eq!(
            parsed,
            (
                "Rust Auth JWT".to_owned(),
                "The project uses JWT for auth.".to_owned()
            )
        );
    }

    /// Plain text falls back to "first line is the label, the rest is the profile".
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let (label, profile) = parse_label_profile(reply).unwrap();
        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    /// An empty reply has no label and must be rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        assert!(matches!(parse_label_profile(""), Err(_)));
    }
}