1use std::sync::Arc;
10use std::time::Duration;
11
12use tokio_util::sync::CancellationToken;
13use zeph_llm::any::AnyProvider;
14use zeph_llm::provider::LlmProvider as _;
15
16use crate::error::MemoryError;
17use crate::store::SqliteStore;
18use crate::types::{MemSceneId, MessageId};
19use zeph_common::math::cosine_similarity;
20
/// A persisted memory scene: a group of related semantic messages summarized by an
/// LLM-generated label and profile.
// NOTE(review): rows appear to be written by `consolidate_scenes` (via
// `SqliteStore::insert_mem_scene`) and read back by `list_scenes` — confirm there are no other writers.
#[derive(Debug, Clone)]
pub struct MemScene {
    /// Scene identifier.
    pub id: MemSceneId,
    /// Short topic label. `consolidate_scenes` truncates labels to 100 characters before insert.
    pub label: String,
    /// Summary of the member facts. `consolidate_scenes` truncates profiles to 2000 characters
    /// before insert.
    pub profile: String,
    /// Number of messages assigned to this scene (presumably maintained by the store — confirm).
    pub member_count: u32,
    /// Creation time; presumably a Unix timestamp — TODO confirm unit (seconds vs. millis).
    pub created_at: i64,
    /// Last-update time; presumably the same encoding as `created_at` — TODO confirm.
    pub updated_at: i64,
}
40
/// Tuning knobs for background scene consolidation.
#[derive(Debug, Clone)]
pub struct SceneConfig {
    /// Master switch; when `false`, `start_scene_consolidation_loop` returns immediately.
    /// The loop's log message refers to this setting as `tiers.scene_enabled`.
    pub enabled: bool,
    /// Minimum cosine similarity between a message embedding and a cluster's first member
    /// for the message to join that cluster (`>=` comparison in `cluster_messages`).
    pub similarity_threshold: f32,
    /// Number of unscened candidate messages requested per sweep; passed to
    /// `find_unscened_semantic_messages` (presumably a row limit — confirm).
    pub batch_size: usize,
    /// Seconds between consolidation sweeps, used as a `tokio::time::interval` period.
    /// Should be non-zero: `tokio::time::interval` panics on a zero period.
    pub sweep_interval_secs: u64,
}
53
54pub async fn start_scene_consolidation_loop(
59 store: Arc<SqliteStore>,
60 provider: AnyProvider,
61 config: SceneConfig,
62 cancel: CancellationToken,
63) {
64 if !config.enabled {
65 tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
66 return;
67 }
68
69 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
70 ticker.tick().await;
72
73 loop {
74 tokio::select! {
75 () = cancel.cancelled() => {
76 tracing::debug!("scene consolidation loop shutting down");
77 return;
78 }
79 _ = ticker.tick() => {}
80 }
81
82 tracing::debug!("scene consolidation: starting sweep");
83 let start = std::time::Instant::now();
84
85 match consolidate_scenes(&store, &provider, &config).await {
86 Ok(stats) => {
87 tracing::info!(
88 candidates = stats.candidates,
89 scenes_created = stats.scenes_created,
90 messages_assigned = stats.messages_assigned,
91 elapsed_ms = start.elapsed().as_millis(),
92 "scene consolidation: sweep complete"
93 );
94 }
95 Err(e) => {
96 tracing::warn!(
97 error = %e,
98 elapsed_ms = start.elapsed().as_millis(),
99 "scene consolidation: sweep failed, will retry"
100 );
101 }
102 }
103 }
104}
105
/// Counters describing the outcome of one [`consolidate_scenes`] sweep.
#[derive(Debug, Default)]
pub struct SceneStats {
    /// Number of unscened candidates fetched. Left at 0 when fewer than two were found,
    /// because the sweep returns early in that case.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Total number of messages attached to the scenes counted in `scenes_created`.
    pub messages_assigned: usize,
}
113
114#[cfg_attr(
120 feature = "profiling",
121 tracing::instrument(name = "memory.consolidate_scenes", skip_all)
122)]
123pub async fn consolidate_scenes(
124 store: &SqliteStore,
125 provider: &AnyProvider,
126 config: &SceneConfig,
127) -> Result<SceneStats, MemoryError> {
128 let candidates = store
129 .find_unscened_semantic_messages(config.batch_size)
130 .await?;
131
132 if candidates.len() < 2 {
133 return Ok(SceneStats::default());
134 }
135
136 let mut stats = SceneStats {
137 candidates: candidates.len(),
138 ..SceneStats::default()
139 };
140
141 let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
143 if provider.supports_embeddings() {
144 for (msg_id, content) in candidates {
145 match provider.embed(&content).await {
146 Ok(vec) => embedded.push((msg_id, content, vec)),
147 Err(e) => {
148 tracing::warn!(
149 message_id = msg_id.0,
150 error = %e,
151 "scene consolidation: failed to embed candidate, skipping"
152 );
153 }
154 }
155 }
156 } else {
157 return Ok(stats);
158 }
159
160 if embedded.len() < 2 {
161 return Ok(stats);
162 }
163
164 let clusters = cluster_messages(embedded, config.similarity_threshold);
166
167 for cluster in clusters {
168 if cluster.len() < 2 {
169 continue;
170 }
171
172 let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
173 let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
174
175 match generate_scene_label_and_profile(provider, &contents).await {
176 Ok((label, profile)) => {
177 let label = label.chars().take(100).collect::<String>();
178 let profile = profile.chars().take(2000).collect::<String>();
179 match store.insert_mem_scene(&label, &profile, &msg_ids).await {
180 Ok(_scene_id) => {
181 stats.scenes_created += 1;
182 stats.messages_assigned += msg_ids.len();
183 }
184 Err(e) => {
185 tracing::warn!(
186 error = %e,
187 cluster_size = msg_ids.len(),
188 "scene consolidation: failed to insert scene"
189 );
190 }
191 }
192 }
193 Err(e) => {
194 tracing::warn!(
195 error = %e,
196 cluster_size = msg_ids.len(),
197 "scene consolidation: LLM label generation failed, skipping cluster"
198 );
199 }
200 }
201 }
202
203 Ok(stats)
204}
205
206fn cluster_messages(
207 candidates: Vec<(MessageId, String, Vec<f32>)>,
208 threshold: f32,
209) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
210 let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
211
212 'outer: for candidate in candidates {
213 for cluster in &mut clusters {
214 let rep = &cluster[0].2;
215 if cosine_similarity(&candidate.2, rep) >= threshold {
216 cluster.push(candidate);
217 continue 'outer;
218 }
219 }
220 clusters.push(vec![candidate]);
221 }
222
223 clusters
224}
225
226async fn generate_scene_label_and_profile(
227 provider: &AnyProvider,
228 contents: &[&str],
229) -> Result<(String, String), MemoryError> {
230 use zeph_llm::provider::{Message, MessageMetadata, Role};
231
232 let bullet_list: String = contents
233 .iter()
234 .enumerate()
235 .map(|(i, c)| format!("{}. {c}", i + 1))
236 .collect::<Vec<_>>()
237 .join("\n");
238
239 let system_content = "You are a memory scene architect. \
240 Given a set of related semantic facts, generate:\n\
241 1. A short label (5 words max) identifying the core entity or topic.\n\
242 2. A 2-3 sentence entity profile summarizing the key facts.\n\
243 Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
244
245 let user_content =
246 format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
247
248 let messages = vec![
249 Message {
250 role: Role::System,
251 content: system_content.to_owned(),
252 parts: vec![],
253 metadata: MessageMetadata::default(),
254 },
255 Message {
256 role: Role::User,
257 content: user_content,
258 parts: vec![],
259 metadata: MessageMetadata::default(),
260 },
261 ];
262
263 let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
264 .await
265 .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
266 .map_err(MemoryError::Llm)?;
267
268 parse_label_profile(&result)
269}
270
271fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
272 if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
274 let label = val
275 .get("label")
276 .and_then(|v| v.as_str())
277 .unwrap_or("")
278 .trim()
279 .to_owned();
280 let profile = val
281 .get("profile")
282 .and_then(|v| v.as_str())
283 .unwrap_or("")
284 .trim()
285 .to_owned();
286 if !label.is_empty() && !profile.is_empty() {
287 return Ok((label, profile));
288 }
289 }
290 let trimmed = response.trim();
292 let mut lines = trimmed.splitn(2, '\n');
293 let label = lines.next().unwrap_or("").trim().to_owned();
294 let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
295 if label.is_empty() {
296 return Err(MemoryError::Other("scene LLM returned empty label".into()));
297 }
298 let profile = if profile.is_empty() {
299 label.clone()
300 } else {
301 profile
302 };
303 Ok((label, profile))
304}
305
306pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
312 store.list_mem_scenes().await
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cluster_messages_groups_similar() {
        let along_x = vec![1.0f32, 0.0, 0.0];
        let along_y = vec![0.0f32, 1.0, 0.0];

        // Two identical x-axis vectors share a cluster; the orthogonal y-axis vector starts its own.
        let input = vec![
            (MessageId(1), "a".to_owned(), along_x.clone()),
            (MessageId(2), "b".to_owned(), along_x),
            (MessageId(3), "c".to_owned(), along_y),
        ];

        let sizes: Vec<usize> = cluster_messages(input, 0.80).iter().map(Vec::len).collect();
        assert_eq!(sizes, [2, 1]);
    }

    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).expect("well-formed JSON must parse");
        assert_eq!(
            parsed,
            (
                "Rust Auth JWT".to_owned(),
                "The project uses JWT for auth.".to_owned()
            )
        );
    }

    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let (label, profile) = parse_label_profile(reply).expect("line fallback must parse");
        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    #[test]
    fn parse_label_profile_empty_fails() {
        let outcome = parse_label_profile("");
        assert!(outcome.is_err());
    }
}