Skip to main content

zeph_memory/
scenes.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `MemScene` consolidation (#2332).
5//!
6//! Groups semantically related semantic-tier messages into stable entity profiles (scenes).
7//! Runs as a separate background loop, decoupled from tier promotion timing.
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
22/// A `MemScene` groups semantically related semantic-tier messages with a stable entity profile.
23///
24/// Scenes are created and updated by [`start_scene_consolidation_loop`] and can be
25/// listed via [`list_scenes`].
26#[derive(Debug, Clone)]
27pub struct MemScene {
28    /// `SQLite` row ID of the scene.
29    pub id: MemSceneId,
30    /// Short human-readable label for the scene (e.g. `"Rust programming"`).
31    pub label: String,
32    /// LLM-generated entity profile summarising the scene's members.
33    pub profile: String,
34    /// Number of messages currently assigned to this scene.
35    pub member_count: u32,
36    /// Unix timestamp when the scene was first created.
37    pub created_at: i64,
38    /// Unix timestamp of the last profile update.
39    pub updated_at: i64,
40}
41
/// Settings that drive the background scene consolidation loop.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// Whether the consolidation loop runs at all.
    pub enabled: bool,
    /// Lowest cosine similarity at which two messages may share a scene.
    pub similarity_threshold: f32,
    /// Upper bound on the number of unassigned messages handled in one sweep.
    pub batch_size: usize,
    /// Delay between sweeps, in seconds.
    pub sweep_interval_secs: u64,
}
54
55/// Start the background scene consolidation loop.
56///
57/// Each sweep clusters unassigned semantic-tier messages into `MemScenes`.
58/// Runs independently from the tier promotion loop.
59pub fn start_scene_consolidation_loop(
60    store: Arc<SqliteStore>,
61    provider: AnyProvider,
62    config: SceneConfig,
63    cancel: CancellationToken,
64) -> JoinHandle<()> {
65    tokio::spawn(async move {
66        if !config.enabled {
67            tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
68            return;
69        }
70
71        let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
72        // Skip first tick to avoid running immediately at startup.
73        ticker.tick().await;
74
75        loop {
76            tokio::select! {
77                () = cancel.cancelled() => {
78                    tracing::debug!("scene consolidation loop shutting down");
79                    return;
80                }
81                _ = ticker.tick() => {}
82            }
83
84            tracing::debug!("scene consolidation: starting sweep");
85            let start = std::time::Instant::now();
86
87            match consolidate_scenes(&store, &provider, &config).await {
88                Ok(stats) => {
89                    tracing::info!(
90                        candidates = stats.candidates,
91                        scenes_created = stats.scenes_created,
92                        messages_assigned = stats.messages_assigned,
93                        elapsed_ms = start.elapsed().as_millis(),
94                        "scene consolidation: sweep complete"
95                    );
96                }
97                Err(e) => {
98                    tracing::warn!(
99                        error = %e,
100                        elapsed_ms = start.elapsed().as_millis(),
101                        "scene consolidation: sweep failed, will retry"
102                    );
103                }
104            }
105        }
106    })
107}
108
/// Counters gathered while running one scene consolidation sweep.
#[derive(Default, Debug)]
pub struct SceneStats {
    /// Number of unassigned messages picked up as clustering candidates.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Number of messages attached to the newly inserted scenes.
    pub messages_assigned: usize,
}
116
117/// Execute one full scene consolidation sweep.
118///
119/// # Errors
120///
121/// Returns an error if the `SQLite` query fails. LLM and embedding errors are logged but skipped.
122#[cfg_attr(
123    feature = "profiling",
124    tracing::instrument(name = "memory.consolidate_scenes", skip_all)
125)]
126pub async fn consolidate_scenes(
127    store: &SqliteStore,
128    provider: &AnyProvider,
129    config: &SceneConfig,
130) -> Result<SceneStats, MemoryError> {
131    let candidates = store
132        .find_unscened_semantic_messages(config.batch_size)
133        .await?;
134
135    if candidates.len() < 2 {
136        return Ok(SceneStats::default());
137    }
138
139    let mut stats = SceneStats {
140        candidates: candidates.len(),
141        ..SceneStats::default()
142    };
143
144    // Embed all candidates.
145    let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
146    if provider.supports_embeddings() {
147        for (msg_id, content) in candidates {
148            match provider.embed(&content).await {
149                Ok(vec) => embedded.push((msg_id, content, vec)),
150                Err(e) => {
151                    tracing::warn!(
152                        message_id = msg_id.0,
153                        error = %e,
154                        "scene consolidation: failed to embed candidate, skipping"
155                    );
156                }
157            }
158        }
159    } else {
160        return Ok(stats);
161    }
162
163    if embedded.len() < 2 {
164        return Ok(stats);
165    }
166
167    // Cluster by cosine similarity.
168    let clusters = cluster_messages(embedded, config.similarity_threshold);
169
170    for cluster in clusters {
171        if cluster.len() < 2 {
172            continue;
173        }
174
175        let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
176        let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
177
178        match generate_scene_label_and_profile(provider, &contents).await {
179            Ok((label, profile)) => {
180                let label = label.chars().take(100).collect::<String>();
181                let profile = profile.chars().take(2000).collect::<String>();
182                match store.insert_mem_scene(&label, &profile, &msg_ids).await {
183                    Ok(_scene_id) => {
184                        stats.scenes_created += 1;
185                        stats.messages_assigned += msg_ids.len();
186                    }
187                    Err(e) => {
188                        tracing::warn!(
189                            error = %e,
190                            cluster_size = msg_ids.len(),
191                            "scene consolidation: failed to insert scene"
192                        );
193                    }
194                }
195            }
196            Err(e) => {
197                tracing::warn!(
198                    error = %e,
199                    cluster_size = msg_ids.len(),
200                    "scene consolidation: LLM label generation failed, skipping cluster"
201                );
202            }
203        }
204    }
205
206    Ok(stats)
207}
208
209fn cluster_messages(
210    candidates: Vec<(MessageId, String, Vec<f32>)>,
211    threshold: f32,
212) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
213    let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
214
215    'outer: for candidate in candidates {
216        for cluster in &mut clusters {
217            let rep = &cluster[0].2;
218            if cosine_similarity(&candidate.2, rep) >= threshold {
219                cluster.push(candidate);
220                continue 'outer;
221            }
222        }
223        clusters.push(vec![candidate]);
224    }
225
226    clusters
227}
228
229async fn generate_scene_label_and_profile(
230    provider: &AnyProvider,
231    contents: &[&str],
232) -> Result<(String, String), MemoryError> {
233    use zeph_llm::provider::{Message, MessageMetadata, Role};
234
235    let bullet_list: String = contents
236        .iter()
237        .enumerate()
238        .map(|(i, c)| format!("{}. {c}", i + 1))
239        .collect::<Vec<_>>()
240        .join("\n");
241
242    let system_content = "You are a memory scene architect. \
243        Given a set of related semantic facts, generate:\n\
244        1. A short label (5 words max) identifying the core entity or topic.\n\
245        2. A 2-3 sentence entity profile summarizing the key facts.\n\
246        Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
247
248    let user_content =
249        format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
250
251    let messages = vec![
252        Message {
253            role: Role::System,
254            content: system_content.to_owned(),
255            parts: vec![],
256            metadata: MessageMetadata::default(),
257        },
258        Message {
259            role: Role::User,
260            content: user_content,
261            parts: vec![],
262            metadata: MessageMetadata::default(),
263        },
264    ];
265
266    let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
267        .await
268        .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
269        .map_err(MemoryError::Llm)?;
270
271    parse_label_profile(&result)
272}
273
274fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
275    // Try JSON parsing first.
276    if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
277        let label = val
278            .get("label")
279            .and_then(|v| v.as_str())
280            .unwrap_or("")
281            .trim()
282            .to_owned();
283        let profile = val
284            .get("profile")
285            .and_then(|v| v.as_str())
286            .unwrap_or("")
287            .trim()
288            .to_owned();
289        if !label.is_empty() && !profile.is_empty() {
290            return Ok((label, profile));
291        }
292    }
293    // Fallback: treat first line as label, rest as profile.
294    let trimmed = response.trim();
295    let mut lines = trimmed.splitn(2, '\n');
296    let label = lines.next().unwrap_or("").trim().to_owned();
297    let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
298    if label.is_empty() {
299        return Err(MemoryError::Other("scene LLM returned empty label".into()));
300    }
301    let profile = if profile.is_empty() {
302        label.clone()
303    } else {
304        profile
305    };
306    Ok((label, profile))
307}
308
309/// List all `MemScenes` from the store.
310///
311/// # Errors
312///
313/// Returns an error if the `SQLite` query fails.
314pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
315    store.list_mem_scenes().await
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cluster_messages_groups_similar() {
        let v1 = vec![1.0f32, 0.0, 0.0];
        let v2 = vec![1.0f32, 0.0, 0.0];
        let v3 = vec![0.0f32, 1.0, 0.0];

        let candidates = vec![
            (MessageId(1), "a".to_owned(), v1),
            (MessageId(2), "b".to_owned(), v2),
            (MessageId(3), "c".to_owned(), v3),
        ];

        let clusters = cluster_messages(candidates, 0.80);
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].len(), 2);
        assert_eq!(clusters[1].len(), 1);
    }

    #[test]
    fn cluster_messages_empty_input_yields_no_clusters() {
        assert!(cluster_messages(Vec::new(), 0.80).is_empty());
    }

    #[test]
    fn cluster_messages_preserves_input_order() {
        let candidates = vec![
            (MessageId(1), "a".to_owned(), vec![1.0f32, 0.0, 0.0]),
            (MessageId(2), "b".to_owned(), vec![0.0f32, 1.0, 0.0]),
            (MessageId(3), "c".to_owned(), vec![1.0f32, 0.0, 0.0]),
        ];

        let clusters = cluster_messages(candidates, 0.80);
        assert_eq!(clusters.len(), 2);
        // The first cluster is seeded by "a"; "c" joins it, after "a".
        assert_eq!(clusters[0][0].1, "a");
        assert_eq!(clusters[0][1].1, "c");
        assert_eq!(clusters[1][0].1, "b");
    }

    #[test]
    fn parse_label_profile_valid_json() {
        let json = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let (label, profile) = parse_label_profile(json).unwrap();
        assert_eq!(label, "Rust Auth JWT");
        assert_eq!(profile, "The project uses JWT for auth.");
    }

    #[test]
    fn parse_label_profile_json_fields_are_trimmed() {
        let json = r#"{"label": "  Rust Auth  ", "profile": "  JWT everywhere.  "}"#;
        let (label, profile) = parse_label_profile(json).unwrap();
        assert_eq!(label, "Rust Auth");
        assert_eq!(profile, "JWT everywhere.");
    }

    #[test]
    fn parse_label_profile_fallback_lines() {
        let text = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let (label, profile) = parse_label_profile(text).unwrap();
        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    #[test]
    fn parse_label_profile_single_line_reuses_label_as_profile() {
        let (label, profile) = parse_label_profile("Rust Auth").unwrap();
        assert_eq!(label, "Rust Auth");
        assert_eq!(profile, "Rust Auth");
    }

    #[test]
    fn parse_label_profile_empty_fails() {
        assert!(parse_label_profile("").is_err());
    }

    #[test]
    fn parse_label_profile_whitespace_only_fails() {
        assert!(parse_label_profile("  \n\t  \n").is_err());
    }
}