1use std::sync::Arc;
10use std::time::Duration;
11
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use zeph_llm::any::AnyProvider;
15use zeph_llm::provider::LlmProvider as _;
16
17use crate::error::MemoryError;
18use crate::store::SqliteStore;
19use crate::types::{MemSceneId, MessageId};
20use zeph_common::math::cosine_similarity;
21
22#[derive(Debug, Clone)]
24pub struct MemScene {
25 pub id: MemSceneId,
26 pub label: String,
27 pub profile: String,
28 pub member_count: u32,
29 pub created_at: i64,
30 pub updated_at: i64,
31}
32
/// Tuning knobs for the scene consolidation loop and for a single sweep.
#[derive(Clone, Debug)]
pub struct SceneConfig {
    /// When `false`, the background task started by `start_scene_consolidation_loop`
    /// logs a debug message and exits without ever sweeping.
    pub enabled: bool,
    /// Minimum cosine similarity to a cluster's first member for a message to join that cluster.
    pub similarity_threshold: f32,
    /// Limit passed to the store when fetching unscened semantic messages for one sweep.
    pub batch_size: usize,
    /// Period between sweeps, in seconds.
    pub sweep_interval_secs: u64,
}
41
42pub fn start_scene_consolidation_loop(
47 store: Arc<SqliteStore>,
48 provider: AnyProvider,
49 config: SceneConfig,
50 cancel: CancellationToken,
51) -> JoinHandle<()> {
52 tokio::spawn(async move {
53 if !config.enabled {
54 tracing::debug!("scene consolidation disabled (tiers.scene_enabled = false)");
55 return;
56 }
57
58 let mut ticker = tokio::time::interval(Duration::from_secs(config.sweep_interval_secs));
59 ticker.tick().await;
61
62 loop {
63 tokio::select! {
64 () = cancel.cancelled() => {
65 tracing::debug!("scene consolidation loop shutting down");
66 return;
67 }
68 _ = ticker.tick() => {}
69 }
70
71 tracing::debug!("scene consolidation: starting sweep");
72 let start = std::time::Instant::now();
73
74 match consolidate_scenes(&store, &provider, &config).await {
75 Ok(stats) => {
76 tracing::info!(
77 candidates = stats.candidates,
78 scenes_created = stats.scenes_created,
79 messages_assigned = stats.messages_assigned,
80 elapsed_ms = start.elapsed().as_millis(),
81 "scene consolidation: sweep complete"
82 );
83 }
84 Err(e) => {
85 tracing::warn!(
86 error = %e,
87 elapsed_ms = start.elapsed().as_millis(),
88 "scene consolidation: sweep failed, will retry"
89 );
90 }
91 }
92 }
93 })
94}
95
/// Counters describing the outcome of one `consolidate_scenes` sweep.
#[derive(Default, Debug)]
pub struct SceneStats {
    /// Number of unscened messages the store returned for this sweep.
    pub candidates: usize,
    /// Number of scenes successfully inserted into the store.
    pub scenes_created: usize,
    /// Total number of messages attached to the scenes created in this sweep.
    pub messages_assigned: usize,
}
103
104pub async fn consolidate_scenes(
110 store: &SqliteStore,
111 provider: &AnyProvider,
112 config: &SceneConfig,
113) -> Result<SceneStats, MemoryError> {
114 let candidates = store
115 .find_unscened_semantic_messages(config.batch_size)
116 .await?;
117
118 if candidates.len() < 2 {
119 return Ok(SceneStats::default());
120 }
121
122 let mut stats = SceneStats {
123 candidates: candidates.len(),
124 ..SceneStats::default()
125 };
126
127 let mut embedded: Vec<(MessageId, String, Vec<f32>)> = Vec::with_capacity(candidates.len());
129 if provider.supports_embeddings() {
130 for (msg_id, content) in candidates {
131 match provider.embed(&content).await {
132 Ok(vec) => embedded.push((msg_id, content, vec)),
133 Err(e) => {
134 tracing::warn!(
135 message_id = msg_id.0,
136 error = %e,
137 "scene consolidation: failed to embed candidate, skipping"
138 );
139 }
140 }
141 }
142 } else {
143 return Ok(stats);
144 }
145
146 if embedded.len() < 2 {
147 return Ok(stats);
148 }
149
150 let clusters = cluster_messages(embedded, config.similarity_threshold);
152
153 for cluster in clusters {
154 if cluster.len() < 2 {
155 continue;
156 }
157
158 let contents: Vec<&str> = cluster.iter().map(|(_, c, _)| c.as_str()).collect();
159 let msg_ids: Vec<MessageId> = cluster.iter().map(|(id, _, _)| *id).collect();
160
161 match generate_scene_label_and_profile(provider, &contents).await {
162 Ok((label, profile)) => {
163 let label = label.chars().take(100).collect::<String>();
164 let profile = profile.chars().take(2000).collect::<String>();
165 match store.insert_mem_scene(&label, &profile, &msg_ids).await {
166 Ok(_scene_id) => {
167 stats.scenes_created += 1;
168 stats.messages_assigned += msg_ids.len();
169 }
170 Err(e) => {
171 tracing::warn!(
172 error = %e,
173 cluster_size = msg_ids.len(),
174 "scene consolidation: failed to insert scene"
175 );
176 }
177 }
178 }
179 Err(e) => {
180 tracing::warn!(
181 error = %e,
182 cluster_size = msg_ids.len(),
183 "scene consolidation: LLM label generation failed, skipping cluster"
184 );
185 }
186 }
187 }
188
189 Ok(stats)
190}
191
192fn cluster_messages(
193 candidates: Vec<(MessageId, String, Vec<f32>)>,
194 threshold: f32,
195) -> Vec<Vec<(MessageId, String, Vec<f32>)>> {
196 let mut clusters: Vec<Vec<(MessageId, String, Vec<f32>)>> = Vec::new();
197
198 'outer: for candidate in candidates {
199 for cluster in &mut clusters {
200 let rep = &cluster[0].2;
201 if cosine_similarity(&candidate.2, rep) >= threshold {
202 cluster.push(candidate);
203 continue 'outer;
204 }
205 }
206 clusters.push(vec![candidate]);
207 }
208
209 clusters
210}
211
212async fn generate_scene_label_and_profile(
213 provider: &AnyProvider,
214 contents: &[&str],
215) -> Result<(String, String), MemoryError> {
216 use zeph_llm::provider::{Message, MessageMetadata, Role};
217
218 let bullet_list: String = contents
219 .iter()
220 .enumerate()
221 .map(|(i, c)| format!("{}. {c}", i + 1))
222 .collect::<Vec<_>>()
223 .join("\n");
224
225 let system_content = "You are a memory scene architect. \
226 Given a set of related semantic facts, generate:\n\
227 1. A short label (5 words max) identifying the core entity or topic.\n\
228 2. A 2-3 sentence entity profile summarizing the key facts.\n\
229 Respond in JSON: {\"label\": \"...\", \"profile\": \"...\"}";
230
231 let user_content =
232 format!("Generate a label and profile for these related facts:\n\n{bullet_list}");
233
234 let messages = vec![
235 Message {
236 role: Role::System,
237 content: system_content.to_owned(),
238 parts: vec![],
239 metadata: MessageMetadata::default(),
240 },
241 Message {
242 role: Role::User,
243 content: user_content,
244 parts: vec![],
245 metadata: MessageMetadata::default(),
246 },
247 ];
248
249 let result = tokio::time::timeout(Duration::from_secs(15), provider.chat(&messages))
250 .await
251 .map_err(|_| MemoryError::Other("scene LLM call timed out after 15s".into()))?
252 .map_err(MemoryError::Llm)?;
253
254 parse_label_profile(&result)
255}
256
257fn parse_label_profile(response: &str) -> Result<(String, String), MemoryError> {
258 if let Ok(val) = serde_json::from_str::<serde_json::Value>(response) {
260 let label = val
261 .get("label")
262 .and_then(|v| v.as_str())
263 .unwrap_or("")
264 .trim()
265 .to_owned();
266 let profile = val
267 .get("profile")
268 .and_then(|v| v.as_str())
269 .unwrap_or("")
270 .trim()
271 .to_owned();
272 if !label.is_empty() && !profile.is_empty() {
273 return Ok((label, profile));
274 }
275 }
276 let trimmed = response.trim();
278 let mut lines = trimmed.splitn(2, '\n');
279 let label = lines.next().unwrap_or("").trim().to_owned();
280 let profile = lines.next().unwrap_or(trimmed).trim().to_owned();
281 if label.is_empty() {
282 return Err(MemoryError::Other("scene LLM returned empty label".into()));
283 }
284 let profile = if profile.is_empty() {
285 label.clone()
286 } else {
287 profile
288 };
289 Ok((label, profile))
290}
291
292pub async fn list_scenes(store: &SqliteStore) -> Result<Vec<MemScene>, MemoryError> {
298 store.list_mem_scenes().await
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical vectors share a cluster; a perpendicular one starts its own.
    #[test]
    fn cluster_messages_groups_similar() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];
        let candidates = vec![
            (MessageId(1), String::from("a"), x_axis.to_vec()),
            (MessageId(2), String::from("b"), x_axis.to_vec()),
            (MessageId(3), String::from("c"), y_axis.to_vec()),
        ];

        let sizes: Vec<usize> = cluster_messages(candidates, 0.80)
            .iter()
            .map(Vec::len)
            .collect();

        assert_eq!(sizes, [2, 1]);
    }

    /// A well-formed JSON reply yields its `label` and `profile` fields verbatim.
    #[test]
    fn parse_label_profile_valid_json() {
        let reply = r#"{"label": "Rust Auth JWT", "profile": "The project uses JWT for auth."}"#;

        let parsed = parse_label_profile(reply).unwrap();

        assert_eq!(
            parsed,
            (
                String::from("Rust Auth JWT"),
                String::from("The project uses JWT for auth."),
            )
        );
    }

    /// A non-JSON reply is split into first line (label) and remainder (profile).
    #[test]
    fn parse_label_profile_fallback_lines() {
        let reply = "Rust Auth\nJWT tokens used for authentication. Rate limited at 100 rps.";

        let (label, profile) = parse_label_profile(reply).unwrap();

        assert_eq!(label, "Rust Auth");
        assert!(profile.contains("JWT"));
    }

    /// An empty reply is rejected.
    #[test]
    fn parse_label_profile_empty_fails() {
        assert!(matches!(parse_label_profile(""), Err(_)));
    }
}