Skip to main content

zeph_memory/
scenes.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `MemScene` consolidation (#2332).
5//!
6//! Groups semantically related semantic-tier messages into stable entity profiles (scenes).
7//! Runs as a separate background loop, decoupled from tier promotion timing.
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use tokio_util::sync::CancellationToken;
13use tracing::Instrument as _;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
22/// A `MemScene` groups semantically related semantic-tier messages with a stable entity profile.
23///
24/// Scenes are created and updated by [`start_scene_consolidation_loop`] and can be
25/// listed via [`list_scenes`].
26#[derive(Debug, Clone)]
27pub struct MemScene {
28    /// `SQLite` row ID of the scene.
29    pub id: MemSceneId,
30    /// Short human-readable label for the scene (e.g. `"Rust programming"`).
31    pub label: String,
32    /// LLM-generated entity profile summarising the scene's members.
33    pub profile: String,
34    /// Number of messages currently assigned to this scene.
35    pub member_count: u32,
36    /// Unix timestamp when the scene was first created.
37    pub created_at: i64,
38    /// Unix timestamp of the last profile update.
39    pub updated_at: i64,
40}
41
/// Tunables for the background scene consolidation loop.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// Master switch; when `false` the consolidation loop returns without sweeping.
    pub enabled: bool,
    /// Cosine-similarity cutoff: messages scoring at or above it are grouped into one scene.
    pub similarity_threshold: f32,
    /// Upper bound on the number of unassigned messages handled in a single sweep.
    pub batch_size: usize,
    /// Pause between two sweeps, expressed in seconds.
    pub sweep_interval_secs: u64,
}
54
55/// Start the background scene consolidation loop.
56///
57/// Each sweep clusters unassigned semantic-tier messages into `MemScenes`.
58/// Runs independently from the tier promotion loop.
59pub async fn start_scene_consolidation_loop(
60    store: Arc<SqliteStore>,
61    provider: AnyProvider,
62    config: SceneConfig,
63    cancel: CancellationToken,
64) {
65    if !config.enabled {
66        tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
67        return;
68    }
69
70    let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
71    // Skip first tick to avoid running immediately at startup.
72    ticker.tick().await;
73
74    loop {
75        tokio::select! {
76            () = cancel.cancelled() => {
77                tracing::debug!("scene consolidation loop shutting down");
78                return;
79            }
80            _ = ticker.tick() => {}
81        }
82
83        tracing::debug!("scene consolidation: starting sweep");
84        let start = std::time::Instant::now();
85
86        match consolidate_scenes(&store, &provider, &config).await {
87            Ok(stats) => {
88                tracing::info!(
89                    candidates = stats.candidates,
90                    scenes_created = stats.scenes_created,
91                    messages_assigned = stats.messages_assigned,
92                    elapsed_ms = start.elapsed().as_millis(),
93                    "scene consolidation: sweep complete"
94                );
95            }
96            Err(e) => {
97                tracing::warn!(
98                    error = %e,
99                    elapsed_ms = start.elapsed().as_millis(),
100                    "scene consolidation: sweep failed, will retry"
101                );
102            }
103        }
104    }
105}
106
/// Counters gathered over one scene consolidation sweep.
#[derive(Debug, Default)]
pub struct SceneStats {
    /// Number of unassigned messages considered by the sweep.
    pub candidates: usize,
    /// Number of scenes successfully inserted during the sweep.
    pub scenes_created: usize,
    /// Number of messages attached to the scenes created during the sweep.
    pub messages_assigned: usize,
}
114
115/// Execute one full scene consolidation sweep.
116///
117/// # Errors
118///
119/// Returns an error if the `SQLite` query fails. LLM and embedding errors are logged but skipped.
120#[tracing::instrument(name = "memory.scenes.consolidation_sweep", skip_all)]
121pub async fn consolidate_scenes(
122    store: &SqliteStore,
123    provider: &AnyProvider,
124    config: &SceneConfig,
125) -> Result<SceneStats, MemoryError> {
126    let candidates = store
127        .find_unscened_semantic_messages(config.batch_size)
128        .await?;
129
130    if candidates.len() < 2 {
131        return Ok(SceneStats::default());
132    }
133
134    let mut stats = SceneStats {
135        candidates: candidates.len(),
136        ..SceneStats::default()
137    };
138
139    if !provider.supports_embeddings() {
140        return Ok(stats);
141    }
142
143    // Embed all candidates in a single batch call.
144    let texts: Vec<&str> = candidates.iter().map(|(_, c)| c.as_str()).collect();
145    let span = tracing::info_span!("memory.scenes.embed_batch", count = texts.len());
146    let vecs = provider.embed_batch(&texts).instrument(span).await;
147    let embedded: Vec<(MessageId, String, Vec<f32>)> = match vecs {
148        Ok(vecs) => {
149            if vecs.len() != texts.len() {
150                tracing::warn!(
151                    expected = texts.len(),
152                    got = vecs.len(),
153                    "scene consolidation: embed_batch length mismatch, skipping sweep"
154                );
155                return Ok(stats);
156            }
157            candidates
158                .into_iter()
159                .zip(vecs)
160                .map(|((msg_id, content), vec)| (msg_id, content, vec))
161                .collect()
162        }
163        Err(e) => {
164            tracing::warn!(error = %e, "scene consolidation: batch embed failed, skipping sweep");
165            return Ok(stats);
166        }
167    };
168
169    if embedded.len() < 2 {
170        return Ok(stats);
171    }
172
173    // Cluster by cosine similarity.
174    let clusters = cluster_messages(embedded, config.similarity_threshold);
175
176    for cluster in clusters {
177        if cluster.len() < 2 {
178            continue;
179        }
180
181        let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
182        let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
183
184        match generate_scene_label_and_profile(provider, &contents).await {
185            Ok((label, profile)) => {
186                let label = label.chars().take(100).collect::<String>();
187                let profile = profile.chars().take(2000).collect::<String>();
188                match store.insert_mem_scene(&label, &profile, &msg_ids).await {
189                    Ok(_scene_id) => {
190                        stats.scenes_created += 1;
191                        stats.messages_assigned += msg_ids.len();
192                    }
193                    Err(e) => {
194                        tracing::warn!(
195                            error = %e,
196                            cluster_size = msg_ids.len(),
197                            "scene consolidation: failed to insert scene"
198                        );
199                    }
200                }
201            }
202            Err(e) => {
203                tracing::warn!(
204                    error = %e,
205                    cluster_size = msg_ids.len(),
206                    "scene consolidation: LLM label generation failed, skipping cluster"
207                );
208            }
209        }
210    }
211
212    Ok(stats)
213}
214
215fn cluster_messages(
216    candidates: Vec<(MessageId, String, Vec<f32>)>,
217    threshold: f32,
218) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
219    let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
220
221    'outer: for candidate in candidates {
222        for cluster in &mut clusters {
223            let rep = &cluster[0].2;
224            if cosine_similarity(&candidate.2, rep) >= threshold {
225                cluster.push(candidate);
226                continue 'outer;
227            }
228        }
229        clusters.push(vec![candidate]);
230    }
231
232    clusters
233}
234
235async fn generate_scene_label_and_profile(
236    provider: &AnyProvider,
237    contents: &[&str],
238) -> Result<(String, String), MemoryError> {
239    use zeph_llm::provider::{Message, MessageMetadata, Role};
240
241    let bullet_list: String = contents
242        .iter()
243        .enumerate()
244        .map(|(i, c)| format!("{}. {c}", i + 1))
245        .collect::<Vec<_>>()
246        .join("\n");
247
248    let system_content = "You are a memory scene architect. \
249        Given a set of related semantic facts, generate:\n\
250        1. A short label (5 words max) identifying the core entity or topic.\n\
251        2. A 2-3 sentence entity profile summarizing the key facts.\n\
252        Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
253
254    let user_content =
255        format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
256
257    let messages = vec![
258        Message {
259            role: Role::System,
260            content: system_content.to_owned(),
261            parts: vec![],
262            metadata: MessageMetadata::default(),
263        },
264        Message {
265            role: Role::User,
266            content: user_content,
267            parts: vec![],
268            metadata: MessageMetadata::default(),
269        },
270    ];
271
272    let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
273        .await
274        .map_err(|_| MemoryError::Timeout("scene LLM call timed out after 15s".into()))?
275        .map_err(MemoryError::Llm)?;
276
277    parse_label_profile(&result)
278}
279
280fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
281    // Try JSON parsing first.
282    if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
283        let label = val
284            .get("label")
285            .and_then(|v| v.as_str())
286            .unwrap_or("")
287            .trim()
288            .to_owned();
289        let profile = val
290            .get("profile")
291            .and_then(|v| v.as_str())
292            .unwrap_or("")
293            .trim()
294            .to_owned();
295        if !label.is_empty() && !profile.is_empty() {
296            return Ok((label, profile));
297        }
298    }
299    // Fallback: treat first line as label, rest as profile.
300    let trimmed = response.trim();
301    let mut lines = trimmed.splitn(2, '\n');
302    let label = lines.next().unwrap_or("").trim().to_owned();
303    let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
304    if label.is_empty() {
305        return Err(MemoryError::InvalidInput(
306            "scene LLM returned empty label".into(),
307        ));
308    }
309    let profile = if profile.is_empty() {
310        label.clone()
311    } else {
312        profile
313    };
314    Ok((label, profile))
315}
316
317/// List all `MemScenes` from the store.
318///
319/// # Errors
320///
321/// Returns an error if the `SQLite` query fails.
322pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
323    store.list_mem_scenes().await
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical vectors share a cluster; an orthogonal one gets its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];

        let input = vec![
            (MessageId(1), String::from("a"), x_axis.to_vec()),
            (MessageId(2), String::from("b"), x_axis.to_vec()),
            (MessageId(3), String::from("c"), y_axis.to_vec()),
        ];

        let sizes: Vec<usize> = cluster_messages(input, 0.80)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, [2, 1]);
    }

    /// A well-formed JSON reply yields its `label` and `profile` fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).unwrap();
        assert_eq!(parsed.0, "Rust Auth JWT");
        assert_eq!(parsed.1, "The project uses JWT for auth.");
    }

    /// A non-JSON reply falls back to "first line is the label, rest is the profile".
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let parsed = parse_label_profile(reply).unwrap();
        assert_eq!(parsed.0, "Rust Auth");
        assert!(parsed.1.contains("JWT"));
    }

    /// An empty reply cannot produce a label and must be rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        let outcome = parse_label_profile("");
        assert!(outcome.is_err());
    }
}