1use std::sync::Arc;
10use std::time::Duration;
11
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
/// A stored memory scene: a labelled group of related messages with a short profile.
///
/// Values of this type are what [`list_scenes`] returns.
#[derive(Debug, Clone)]
pub struct MemScene {
    /// Identifier of the scene.
    pub id: MemSceneId,
    /// Short topic label. Scenes created by `consolidate_scenes` cap it at 100 characters.
    pub label: String,
    /// Summary text for the scene. Scenes created by `consolidate_scenes` cap it at 2000
    /// characters.
    pub profile: String,
    /// NOTE(review): presumably the number of messages assigned to the scene; confirm against
    /// the store query that fills this struct.
    pub member_count: u32,
    /// Creation timestamp. NOTE(review): unit and epoch are not visible in this file
    /// (presumably Unix time); confirm against the store.
    pub created_at: i64,
    /// Last-update timestamp. NOTE(review): same unit caveat as `created_at`.
    pub updated_at: i64,
}
41
/// Settings for the background scene-consolidation loop and its sweeps.
#[derive(Debug, Clone)]
pub struct SceneConfig {
    /// When `false`, the task spawned by [`start_scene_consolidation_loop`] logs a debug
    /// message and exits without ever sweeping.
    pub enabled: bool,
    /// Minimum cosine similarity to a cluster's first member that a message needs in order to
    /// join that cluster.
    pub similarity_threshold: f32,
    /// Passed to `SqliteStore::find_unscened_semantic_messages` on every sweep.
    /// NOTE(review): presumably the maximum number of candidates fetched; confirm in the store.
    pub batch_size: usize,
    /// Period of the sweep timer, in seconds.
    pub sweep_interval_secs: u64,
}
54
55pub fn start_scene_consolidation_loop(
60 store: Arc<SqliteStore>,
61 provider: AnyProvider,
62 config: SceneConfig,
63 cancel: CancellationToken,
64) -> JoinHandle<()> {
65 tokio::spawn(async move {
66 if !config.enabled {
67 tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
68 return;
69 }
70
71 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
72 ticker.tick().await;
74
75 loop {
76 tokio::select! {
77 () = cancel.cancelled() => {
78 tracing::debug!("scene consolidation loop shutting down");
79 return;
80 }
81 _ = ticker.tick() => {}
82 }
83
84 tracing::debug!("scene consolidation: starting sweep");
85 let start = std::time::Instant::now();
86
87 match consolidate_scenes(&store, &provider, &config).await {
88 Ok(stats) => {
89 tracing::info!(
90 candidates = stats.candidates,
91 scenes_created = stats.scenes_created,
92 messages_assigned = stats.messages_assigned,
93 elapsed_ms = start.elapsed().as_millis(),
94 "scene consolidation: sweep complete"
95 );
96 }
97 Err(e) => {
98 tracing::warn!(
99 error = %e,
100 elapsed_ms = start.elapsed().as_millis(),
101 "scene consolidation: sweep failed, will retry"
102 );
103 }
104 }
105 }
106 })
107}
108
/// Counters describing the outcome of one [`consolidate_scenes`] sweep.
#[derive(Debug, Default)]
pub struct SceneStats {
    /// Number of candidate messages fetched for the sweep. Left at `0` when fewer than two
    /// candidates were found, because the sweep then returns default stats.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store during the sweep.
    pub scenes_created: usize,
    /// Total number of messages belonging to the scenes counted in `scenes_created`.
    pub messages_assigned: usize,
}
116
117#[cfg_attr(
123 feature = "profiling",
124 tracing::instrument(name = "memory.consolidate_scenes", skip_all)
125)]
126pub async fn consolidate_scenes(
127 store: &SqliteStore,
128 provider: &AnyProvider,
129 config: &SceneConfig,
130) -> Result<SceneStats, MemoryError> {
131 let candidates = store
132 .find_unscened_semantic_messages(config.batch_size)
133 .await?;
134
135 if candidates.len() < 2 {
136 return Ok(SceneStats::default());
137 }
138
139 let mut stats = SceneStats {
140 candidates: candidates.len(),
141 ..SceneStats::default()
142 };
143
144 let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
146 if provider.supports_embeddings() {
147 for (msg_id, content) in candidates {
148 match provider.embed(&content).await {
149 Ok(vec) => embedded.push((msg_id, content, vec)),
150 Err(e) => {
151 tracing::warn!(
152 message_id = msg_id.0,
153 error = %e,
154 "scene consolidation: failed to embed candidate, skipping"
155 );
156 }
157 }
158 }
159 } else {
160 return Ok(stats);
161 }
162
163 if embedded.len() < 2 {
164 return Ok(stats);
165 }
166
167 let clusters = cluster_messages(embedded, config.similarity_threshold);
169
170 for cluster in clusters {
171 if cluster.len() < 2 {
172 continue;
173 }
174
175 let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
176 let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
177
178 match generate_scene_label_and_profile(provider, &contents).await {
179 Ok((label, profile)) => {
180 let label = label.chars().take(100).collect::<String>();
181 let profile = profile.chars().take(2000).collect::<String>();
182 match store.insert_mem_scene(&label, &profile, &msg_ids).await {
183 Ok(_scene_id) => {
184 stats.scenes_created += 1;
185 stats.messages_assigned += msg_ids.len();
186 }
187 Err(e) => {
188 tracing::warn!(
189 error = %e,
190 cluster_size = msg_ids.len(),
191 "scene consolidation: failed to insert scene"
192 );
193 }
194 }
195 }
196 Err(e) => {
197 tracing::warn!(
198 error = %e,
199 cluster_size = msg_ids.len(),
200 "scene consolidation: LLM label generation failed, skipping cluster"
201 );
202 }
203 }
204 }
205
206 Ok(stats)
207}
208
209fn cluster_messages(
210 candidates: Vec<(MessageId, String, Vec<f32>)>,
211 threshold: f32,
212) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
213 let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
214
215 'outer: for candidate in candidates {
216 for cluster in &mut clusters {
217 let rep = &cluster[0].2;
218 if cosine_similarity(&candidate.2, rep) >= threshold {
219 cluster.push(candidate);
220 continue 'outer;
221 }
222 }
223 clusters.push(vec![candidate]);
224 }
225
226 clusters
227}
228
229async fn generate_scene_label_and_profile(
230 provider: &AnyProvider,
231 contents: &[&str],
232) -> Result<(String, String), MemoryError> {
233 use zeph_llm::provider::{Message, MessageMetadata, Role};
234
235 let bullet_list: String = contents
236 .iter()
237 .enumerate()
238 .map(|(i, c)| format!("{}. {c}", i + 1))
239 .collect::<Vec<_>>()
240 .join("\n");
241
242 let system_content = "You are a memory scene architect. \
243 Given a set of related semantic facts, generate:\n\
244 1. A short label (5 words max) identifying the core entity or topic.\n\
245 2. A 2-3 sentence entity profile summarizing the key facts.\n\
246 Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
247
248 let user_content =
249 format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
250
251 let messages = vec![
252 Message {
253 role: Role::System,
254 content: system_content.to_owned(),
255 parts: vec![],
256 metadata: MessageMetadata::default(),
257 },
258 Message {
259 role: Role::User,
260 content: user_content,
261 parts: vec![],
262 metadata: MessageMetadata::default(),
263 },
264 ];
265
266 let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
267 .await
268 .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
269 .map_err(MemoryError::Llm)?;
270
271 parse_label_profile(&result)
272}
273
274fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
275 if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
277 let label = val
278 .get("label")
279 .and_then(|v| v.as_str())
280 .unwrap_or("")
281 .trim()
282 .to_owned();
283 let profile = val
284 .get("profile")
285 .and_then(|v| v.as_str())
286 .unwrap_or("")
287 .trim()
288 .to_owned();
289 if !label.is_empty() && !profile.is_empty() {
290 return Ok((label, profile));
291 }
292 }
293 let trimmed = response.trim();
295 let mut lines = trimmed.splitn(2, '\n');
296 let label = lines.next().unwrap_or("").trim().to_owned();
297 let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
298 if label.is_empty() {
299 return Err(MemoryError::Other("scene LLM returned empty label".into()));
300 }
301 let profile = if profile.is_empty() {
302 label.clone()
303 } else {
304 profile
305 };
306 Ok((label, profile))
307}
308
/// Returns the scenes reported by the store's `list_mem_scenes` query.
///
/// This is a thin wrapper that forwards to `SqliteStore::list_mem_scenes` unchanged.
///
/// # Errors
///
/// Propagates any error returned by `SqliteStore::list_mem_scenes`.
pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
    store.list_mem_scenes().await
}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical vectors share a cluster; an orthogonal vector gets its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = vec![1.0f32, 0.0, 0.0];
        let y_axis = vec![0.0f32, 1.0, 0.0];

        let input = vec![
            (MessageId(1), "a".to_owned(), x_axis.clone()),
            (MessageId(2), "b".to_owned(), x_axis),
            (MessageId(3), "c".to_owned(), y_axis),
        ];

        let sizes: Vec<usize> = cluster_messages(input, 0.80)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(sizes, vec![2, 1]);
    }

    /// A well-formed JSON reply yields both fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;
        let parsed = parse_label_profile(reply).unwrap();
        let expected = (
            "Rust Auth JWT".to_owned(),
            "The project uses JWT for auth.".to_owned(),
        );
        assert_eq!(parsed, expected);
    }

    /// A non-JSON reply is split into first line (label) and remainder (profile).
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";
        let (label, profile) = parse_label_profile(reply).unwrap();
        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    /// An empty reply has no label and must be rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        let outcome = parse_label_profile("");
        assert!(outcome.is_err());
    }
}