Skip to main content

zeph_memory/
scenes.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `MemScene` consolidation (#2332).
5//!
6//! Groups semantically related semantic-tier messages into stable entity profiles (scenes).
7//! Runs as a separate background loop, decoupled from tier promotion timing.
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
22/// A `MemScene` groups semantically related semantic-tier messages with a stable entity profile.
23#[derive(Debug, Clone)]
24pub struct MemScene {
25    pub id: MemSceneId,
26    pub label: String,
27    pub profile: String,
28    pub member_count: u32,
29    pub created_at: i64,
30    pub updated_at: i64,
31}
32
/// Tunables for the scene consolidation sweep and its background loop.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// When `false`, the background loop logs that it is disabled and exits at once.
    pub enabled: bool,
    /// Cosine-similarity cut-off: a message joins a cluster when its similarity to
    /// that cluster's first member is greater than or equal to this value.
    pub similarity_threshold: f32,
    /// Value handed to the store query that fetches unassigned semantic-tier
    /// messages for one sweep.
    pub batch_size: usize,
    /// Period of the background sweep timer, in seconds.
    pub sweep_interval_secs: u64,
}
41
42/// Start the background scene consolidation loop.
43///
44/// Each sweep clusters unassigned semantic-tier messages into `MemScenes`.
45/// Runs independently from the tier promotion loop.
46pub fn start_scene_consolidation_loop(
47    store: Arc<SqliteStore>,
48    provider: AnyProvider,
49    config: SceneConfig,
50    cancel: CancellationToken,
51) -> JoinHandle<()> {
52    tokio::spawn(async move {
53        if !config.enabled {
54            tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
55            return;
56        }
57
58        let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
59        // Skip first tick to avoid running immediately at startup.
60        ticker.tick().await;
61
62        loop {
63            tokio::select! {
64                () = cancel.cancelled() => {
65                    tracing::debug!("scene consolidation loop shutting down");
66                    return;
67                }
68                _ = ticker.tick() => {}
69            }
70
71            tracing::debug!("scene consolidation: starting sweep");
72            let start = std::time::Instant::now();
73
74            match consolidate_scenes(&store, &provider, &config).await {
75                Ok(stats) => {
76                    tracing::info!(
77                        candidates = stats.candidates,
78                        scenes_created = stats.scenes_created,
79                        messages_assigned = stats.messages_assigned,
80                        elapsed_ms = start.elapsed().as_millis(),
81                        "scene consolidation: sweep complete"
82                    );
83                }
84                Err(e) => {
85                    tracing::warn!(
86                        error = %e,
87                        elapsed_ms = start.elapsed().as_millis(),
88                        "scene consolidation: sweep failed, will retry"
89                    );
90                }
91            }
92        }
93    })
94}
95
/// Counters describing the outcome of one consolidation sweep.
#[derive(Default, Debug)]
pub struct SceneStats {
    /// Number of unassigned messages considered by the sweep.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Total number of messages attached to the scenes created in this sweep.
    pub messages_assigned: usize,
}
103
104/// Execute one full scene consolidation sweep.
105///
106/// # Errors
107///
108/// Returns an error if the `SQLite` query fails. LLM and embedding errors are logged but skipped.
109pub async fn consolidate_scenes(
110    store: &SqliteStore,
111    provider: &AnyProvider,
112    config: &SceneConfig,
113) -> Result<SceneStats, MemoryError> {
114    let candidates = store
115        .find_unscened_semantic_messages(config.batch_size)
116        .await?;
117
118    if candidates.len() < 2 {
119        return Ok(SceneStats::default());
120    }
121
122    let mut stats = SceneStats {
123        candidates: candidates.len(),
124        ..SceneStats::default()
125    };
126
127    // Embed all candidates.
128    let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
129    if provider.supports_embeddings() {
130        for (msg_id, content) in candidates {
131            match provider.embed(&content).await {
132                Ok(vec) => embedded.push((msg_id, content, vec)),
133                Err(e) => {
134                    tracing::warn!(
135                        message_id = msg_id.0,
136                        error = %e,
137                        "scene consolidation: failed to embed candidate, skipping"
138                    );
139                }
140            }
141        }
142    } else {
143        return Ok(stats);
144    }
145
146    if embedded.len() < 2 {
147        return Ok(stats);
148    }
149
150    // Cluster by cosine similarity.
151    let clusters = cluster_messages(embedded, config.similarity_threshold);
152
153    for cluster in clusters {
154        if cluster.len() < 2 {
155            continue;
156        }
157
158        let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
159        let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
160
161        match generate_scene_label_and_profile(provider, &contents).await {
162            Ok((label, profile)) => {
163                let label = label.chars().take(100).collect::<String>();
164                let profile = profile.chars().take(2000).collect::<String>();
165                match store.insert_mem_scene(&label, &profile, &msg_ids).await {
166                    Ok(_scene_id) => {
167                        stats.scenes_created += 1;
168                        stats.messages_assigned += msg_ids.len();
169                    }
170                    Err(e) => {
171                        tracing::warn!(
172                            error = %e,
173                            cluster_size = msg_ids.len(),
174                            "scene consolidation: failed to insert scene"
175                        );
176                    }
177                }
178            }
179            Err(e) => {
180                tracing::warn!(
181                    error = %e,
182                    cluster_size = msg_ids.len(),
183                    "scene consolidation: LLM label generation failed, skipping cluster"
184                );
185            }
186        }
187    }
188
189    Ok(stats)
190}
191
192fn cluster_messages(
193    candidates: Vec<(MessageId, String, Vec<f32>)>,
194    threshold: f32,
195) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
196    let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
197
198    'outer: for candidate in candidates {
199        for cluster in &mut clusters {
200            let rep = &cluster[0].2;
201            if cosine_similarity(&candidate.2, rep) >= threshold {
202                cluster.push(candidate);
203                continue 'outer;
204            }
205        }
206        clusters.push(vec![candidate]);
207    }
208
209    clusters
210}
211
212async fn generate_scene_label_and_profile(
213    provider: &AnyProvider,
214    contents: &[&str],
215) -> Result<(String, String), MemoryError> {
216    use zeph_llm::provider::{Message, MessageMetadata, Role};
217
218    let bullet_list: String = contents
219        .iter()
220        .enumerate()
221        .map(|(i, c)| format!("{}. {c}", i + 1))
222        .collect::<Vec<_>>()
223        .join("\n");
224
225    let system_content = "You are a memory scene architect. \
226        Given a set of related semantic facts, generate:\n\
227        1. A short label (5 words max) identifying the core entity or topic.\n\
228        2. A 2-3 sentence entity profile summarizing the key facts.\n\
229        Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
230
231    let user_content =
232        format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
233
234    let messages = vec![
235        Message {
236            role: Role::System,
237            content: system_content.to_owned(),
238            parts: vec![],
239            metadata: MessageMetadata::default(),
240        },
241        Message {
242            role: Role::User,
243            content: user_content,
244            parts: vec![],
245            metadata: MessageMetadata::default(),
246        },
247    ];
248
249    let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
250        .await
251        .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
252        .map_err(MemoryError::Llm)?;
253
254    parse_label_profile(&result)
255}
256
257fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
258    // Try JSON parsing first.
259    if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
260        let label = val
261            .get("label")
262            .and_then(|v| v.as_str())
263            .unwrap_or("")
264            .trim()
265            .to_owned();
266        let profile = val
267            .get("profile")
268            .and_then(|v| v.as_str())
269            .unwrap_or("")
270            .trim()
271            .to_owned();
272        if !label.is_empty() && !profile.is_empty() {
273            return Ok((label, profile));
274        }
275    }
276    // Fallback: treat first line as label, rest as profile.
277    let trimmed = response.trim();
278    let mut lines = trimmed.splitn(2, '\n');
279    let label = lines.next().unwrap_or("").trim().to_owned();
280    let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
281    if label.is_empty() {
282        return Err(MemoryError::Other("scene LLM returned empty label".into()));
283    }
284    let profile = if profile.is_empty() {
285        label.clone()
286    } else {
287        profile
288    };
289    Ok((label, profile))
290}
291
292/// List all `MemScenes` from the store.
293///
294/// # Errors
295///
296/// Returns an error if the `SQLite` query fails.
297pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
298    store.list_mem_scenes().await
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical embeddings share a cluster; an orthogonal one gets its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = vec![1.0f32, 0.0, 0.0];
        let y_axis = vec![0.0f32, 1.0, 0.0];

        let candidates = vec![
            (MessageId(1), String::from("a"), x_axis.clone()),
            (MessageId(2), String::from("b"), x_axis),
            (MessageId(3), String::from("c"), y_axis),
        ];

        let sizes: Vec<usize> = cluster_messages(candidates, 0.80)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, [2, 1]);
    }

    /// A well-formed JSON reply yields its `label` and `profile` fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).expect("valid JSON reply must parse");
        assert_eq!(parsed.0, "Rust Auth JWT");
        assert_eq!(parsed.1, "The project uses JWT for auth.");
    }

    /// A plain-text reply uses its first line as label and the rest as profile.
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let parsed = parse_label_profile(reply).expect("plain-text reply must parse");
        assert_eq!(parsed.0, "Rust Auth");
        assert!(parsed.1.contains("JWT"));
    }

    /// An empty reply has no label and must be rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        let outcome = parse_label_profile("");
        assert!(outcome.is_err());
    }
}