1use std::sync::Arc;
10use std::time::Duration;
11
12use tokio_util::sync::CancellationToken;
13use tracing::Instrument as _;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
22#[derive(Debug, Clone)]
27pub struct MemScene {
28 pub id: MemSceneId,
30 pub label: String,
32 pub profile: String,
34 pub member_count: u32,
36 pub created_at: i64,
38 pub updated_at: i64,
40}
41
/// Settings that drive the scene-consolidation loop and each sweep.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// When `false`, [`start_scene_consolidation_loop`] logs and returns
    /// immediately without ever sweeping.
    pub enabled: bool,
    /// Minimum cosine similarity between a message embedding and a cluster's
    /// first member for the message to join that cluster.
    pub similarity_threshold: f32,
    /// Maximum number of unscened messages requested from the store per sweep.
    pub batch_size: usize,
    /// Period between sweeps, in seconds.
    pub sweep_interval_secs: u64,
}
54
55pub async fn start_scene_consolidation_loop(
60 store: Arc<SqliteStore>,
61 provider: AnyProvider,
62 config: SceneConfig,
63 cancel: CancellationToken,
64) {
65 if !config.enabled {
66 tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
67 return;
68 }
69
70 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
71 ticker.tick().await;
73
74 loop {
75 tokio::select! {
76 () = cancel.cancelled() => {
77 tracing::debug!("scene consolidation loop shutting down");
78 return;
79 }
80 _ = ticker.tick() => {}
81 }
82
83 tracing::debug!("scene consolidation: starting sweep");
84 let start = std::time::Instant::now();
85
86 match consolidate_scenes(&store, &provider, &config).await {
87 Ok(stats) => {
88 tracing::info!(
89 candidates = stats.candidates,
90 scenes_created = stats.scenes_created,
91 messages_assigned = stats.messages_assigned,
92 elapsed_ms = start.elapsed().as_millis(),
93 "scene consolidation: sweep complete"
94 );
95 }
96 Err(e) => {
97 tracing::warn!(
98 error = %e,
99 elapsed_ms = start.elapsed().as_millis(),
100 "scene consolidation: sweep failed, will retry"
101 );
102 }
103 }
104 }
105}
106
/// Counters describing the outcome of a single consolidation sweep.
///
/// All counters start at zero via `Default`.
#[derive(Default, Debug)]
pub struct SceneStats {
    /// Number of unscened messages fetched for this sweep.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Total number of messages attached to the scenes created in this sweep.
    pub messages_assigned: usize,
}
114
115#[tracing::instrument(name = "memory.scenes.consolidation_sweep", skip_all)]
121pub async fn consolidate_scenes(
122 store: &SqliteStore,
123 provider: &AnyProvider,
124 config: &SceneConfig,
125) -> Result<SceneStats, MemoryError> {
126 let candidates = store
127 .find_unscened_semantic_messages(config.batch_size)
128 .await?;
129
130 if candidates.len() < 2 {
131 return Ok(SceneStats::default());
132 }
133
134 let mut stats = SceneStats {
135 candidates: candidates.len(),
136 ..SceneStats::default()
137 };
138
139 if !provider.supports_embeddings() {
140 return Ok(stats);
141 }
142
143 let texts: Vec<&str> = candidates.iter().map(|(_, c)| c.as_str()).collect();
145 let span = tracing::info_span!("memory.scenes.embed_batch", count = texts.len());
146 let vecs = provider.embed_batch(&texts).instrument(span).await;
147 let embedded: Vec<(MessageId, String, Vec<f32>)> = match vecs {
148 Ok(vecs) => {
149 if vecs.len() != texts.len() {
150 tracing::warn!(
151 expected = texts.len(),
152 got = vecs.len(),
153 "scene consolidation: embed_batch length mismatch, skipping sweep"
154 );
155 return Ok(stats);
156 }
157 candidates
158 .into_iter()
159 .zip(vecs)
160 .map(|((msg_id, content), vec)| (msg_id, content, vec))
161 .collect()
162 }
163 Err(e) => {
164 tracing::warn!(error = %e, "scene consolidation: batch embed failed, skipping sweep");
165 return Ok(stats);
166 }
167 };
168
169 if embedded.len() < 2 {
170 return Ok(stats);
171 }
172
173 let clusters = cluster_messages(embedded, config.similarity_threshold);
175
176 for cluster in clusters {
177 if cluster.len() < 2 {
178 continue;
179 }
180
181 let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
182 let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
183
184 match generate_scene_label_and_profile(provider, &contents).await {
185 Ok((label, profile)) => {
186 let label = label.chars().take(100).collect::<String>();
187 let profile = profile.chars().take(2000).collect::<String>();
188 match store.insert_mem_scene(&label, &profile, &msg_ids).await {
189 Ok(_scene_id) => {
190 stats.scenes_created += 1;
191 stats.messages_assigned += msg_ids.len();
192 }
193 Err(e) => {
194 tracing::warn!(
195 error = %e,
196 cluster_size = msg_ids.len(),
197 "scene consolidation: failed to insert scene"
198 );
199 }
200 }
201 }
202 Err(e) => {
203 tracing::warn!(
204 error = %e,
205 cluster_size = msg_ids.len(),
206 "scene consolidation: LLM label generation failed, skipping cluster"
207 );
208 }
209 }
210 }
211
212 Ok(stats)
213}
214
215fn cluster_messages(
216 candidates: Vec<(MessageId, String, Vec<f32>)>,
217 threshold: f32,
218) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
219 let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
220
221 'outer: for candidate in candidates {
222 for cluster in &mut clusters {
223 let rep = &cluster[0].2;
224 if cosine_similarity(&candidate.2, rep) >= threshold {
225 cluster.push(candidate);
226 continue 'outer;
227 }
228 }
229 clusters.push(vec![candidate]);
230 }
231
232 clusters
233}
234
235async fn generate_scene_label_and_profile(
236 provider: &AnyProvider,
237 contents: &[&str],
238) -> Result<(String, String), MemoryError> {
239 use zeph_llm::provider::{Message, MessageMetadata, Role};
240
241 let bullet_list: String = contents
242 .iter()
243 .enumerate()
244 .map(|(i, c)| format!("{}. {c}", i + 1))
245 .collect::<Vec<_>>()
246 .join("\n");
247
248 let system_content = "You are a memory scene architect. \
249 Given a set of related semantic facts, generate:\n\
250 1. A short label (5 words max) identifying the core entity or topic.\n\
251 2. A 2-3 sentence entity profile summarizing the key facts.\n\
252 Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
253
254 let user_content =
255 format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
256
257 let messages = vec![
258 Message {
259 role: Role::System,
260 content: system_content.to_owned(),
261 parts: vec![],
262 metadata: MessageMetadata::default(),
263 },
264 Message {
265 role: Role::User,
266 content: user_content,
267 parts: vec![],
268 metadata: MessageMetadata::default(),
269 },
270 ];
271
272 let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
273 .await
274 .map_err(|_| MemoryError::Timeout("scene LLM call timed out after 15s".into()))?
275 .map_err(MemoryError::Llm)?;
276
277 parse_label_profile(&result)
278}
279
280fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
281 if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
283 let label = val
284 .get("label")
285 .and_then(|v| v.as_str())
286 .unwrap_or("")
287 .trim()
288 .to_owned();
289 let profile = val
290 .get("profile")
291 .and_then(|v| v.as_str())
292 .unwrap_or("")
293 .trim()
294 .to_owned();
295 if !label.is_empty() && !profile.is_empty() {
296 return Ok((label, profile));
297 }
298 }
299 let trimmed = response.trim();
301 let mut lines = trimmed.splitn(2, '\n');
302 let label = lines.next().unwrap_or("").trim().to_owned();
303 let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
304 if label.is_empty() {
305 return Err(MemoryError::InvalidInput(
306 "scene LLM returned empty label".into(),
307 ));
308 }
309 let profile = if profile.is_empty() {
310 label.clone()
311 } else {
312 profile
313 };
314 Ok((label, profile))
315}
316
317pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
323 store.list_mem_scenes().await
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical vectors share a cluster; an orthogonal one gets its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = vec![1.0f32, 0.0, 0.0];
        let y_axis = vec![0.0f32, 1.0, 0.0];

        let candidates = vec![
            (MessageId(1), "a".to_owned(), x_axis.clone()),
            (MessageId(2), "b".to_owned(), x_axis),
            (MessageId(3), "c".to_owned(), y_axis),
        ];

        let sizes: Vec<usize> = cluster_messages(candidates, 0.80)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, vec![2, 1]);
    }

    /// A well-formed JSON reply yields both fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).unwrap();
        assert_eq!(
            parsed,
            (
                "Rust Auth JWT".to_owned(),
                "The project uses JWT for auth.".to_owned()
            )
        );
    }

    /// Non-JSON text falls back to "first line = label, rest = profile".
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let (label, profile) = parse_label_profile(reply).unwrap();
        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    /// An empty reply has no label and must be rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        assert!(matches!(parse_label_profile(""), Err(_)));
    }
}