1use zeph_db::{query, query_as, query_scalar, sql};
5
6use super::DbStore;
7use crate::error::MemoryError;
8
/// One row of the `memory_tree` table: a node in the hierarchical memory
/// structure. Level 0 holds raw leaves; higher levels hold consolidated
/// summaries created from their children.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct MemoryTreeRow {
    // Primary key of the node.
    pub id: i64,
    // Tree depth: 0 for leaves, increasing toward summary roots.
    pub level: i64,
    // Id of the summary node this row was consolidated under;
    // `None` while the node is still unconsolidated.
    pub parent_id: Option<i64>,
    // Node text: raw leaf content or a consolidated summary.
    pub content: String,
    // Source ids stored verbatim as a string; callers in this file pass a
    // JSON array literal (e.g. "[]").
    pub source_ids: String,
    // Token count of `content`, as supplied by the caller at insert time.
    pub token_count: i64,
    // Set to SQLite `datetime('now')` when the node is consolidated;
    // `None` for fresh leaves.
    pub consolidated_at: Option<String>,
    // Row creation timestamp — presumably a schema column default, not set
    // by any INSERT in this file; TODO confirm against the migration.
    pub created_at: String,
}
21
22impl DbStore {
23 pub async fn insert_tree_leaf(
31 &self,
32 content: &str,
33 token_count: i64,
34 ) -> Result<i64, MemoryError> {
35 let (id,): (i64,) = query_as(sql!(
36 "INSERT INTO memory_tree (level, content, token_count)
37 VALUES (0, ?, ?)
38 RETURNING id"
39 ))
40 .bind(content)
41 .bind(token_count)
42 .fetch_one(self.pool())
43 .await?;
44
45 Ok(id)
46 }
47
48 pub async fn insert_tree_node(
56 &self,
57 level: i64,
58 parent_id: Option<i64>,
59 content: &str,
60 source_ids: &str,
61 token_count: i64,
62 ) -> Result<i64, MemoryError> {
63 let (id,): (i64,) = query_as(sql!(
64 "INSERT INTO memory_tree
65 (level, parent_id, content, source_ids, token_count, consolidated_at)
66 VALUES (?, ?, ?, ?, ?, datetime('now'))
67 RETURNING id"
68 ))
69 .bind(level)
70 .bind(parent_id)
71 .bind(content)
72 .bind(source_ids)
73 .bind(token_count)
74 .fetch_one(self.pool())
75 .await?;
76
77 Ok(id)
78 }
79
80 pub async fn load_tree_leaves_unconsolidated(
86 &self,
87 limit: usize,
88 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
89 let rows: Vec<MemoryTreeRow> = query_as(sql!(
90 "SELECT id, level, parent_id, content, source_ids, token_count,
91 consolidated_at, created_at
92 FROM memory_tree
93 WHERE level = 0 AND parent_id IS NULL
94 ORDER BY created_at ASC
95 LIMIT ?"
96 ))
97 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
98 .fetch_all(self.pool())
99 .await?;
100
101 Ok(rows)
102 }
103
104 pub async fn load_tree_level(
110 &self,
111 level: i64,
112 limit: usize,
113 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
114 let rows: Vec<MemoryTreeRow> = query_as(sql!(
115 "SELECT id, level, parent_id, content, source_ids, token_count,
116 consolidated_at, created_at
117 FROM memory_tree
118 WHERE level = ? AND parent_id IS NULL
119 ORDER BY created_at ASC
120 LIMIT ?"
121 ))
122 .bind(level)
123 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
124 .fetch_all(self.pool())
125 .await?;
126
127 Ok(rows)
128 }
129
130 pub async fn traverse_tree_up(
138 &self,
139 leaf_id: i64,
140 max_level: i64,
141 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
142 let mut result = Vec::new();
144 let mut current_id = leaf_id;
145
146 for _ in 0..=max_level {
147 let row: Option<MemoryTreeRow> = query_as(sql!(
148 "SELECT id, level, parent_id, content, source_ids, token_count,
149 consolidated_at, created_at
150 FROM memory_tree
151 WHERE id = ?"
152 ))
153 .bind(current_id)
154 .fetch_optional(self.pool())
155 .await?;
156
157 match row {
158 None => break,
159 Some(r) => {
160 let next_id = r.parent_id;
161 result.push(r);
162 match next_id {
163 None => break,
164 Some(p) => current_id = p,
165 }
166 }
167 }
168 }
169
170 Ok(result)
171 }
172
173 pub async fn mark_nodes_consolidated(
183 &self,
184 child_ids: &[i64],
185 parent_id: i64,
186 ) -> Result<(), MemoryError> {
187 if child_ids.is_empty() {
188 return Ok(());
189 }
190
191 let mut tx = self.pool().begin().await?;
192
193 for &child_id in child_ids {
194 query(sql!(
195 "UPDATE memory_tree
196 SET parent_id = ?, consolidated_at = datetime('now')
197 WHERE id = ? AND parent_id IS NULL"
198 ))
199 .bind(parent_id)
200 .bind(child_id)
201 .execute(&mut *tx)
202 .await?;
203 }
204
205 tx.commit().await?;
206 Ok(())
207 }
208
209 #[cfg_attr(
218 feature = "profiling",
219 tracing::instrument(name = "memory.consolidate", skip_all)
220 )]
221 pub async fn consolidate_cluster(
222 &self,
223 level: i64,
224 summary: &str,
225 source_ids_json: &str,
226 token_count: i64,
227 child_ids: &[i64],
228 ) -> Result<i64, MemoryError> {
229 if child_ids.is_empty() {
230 return Err(MemoryError::InvalidInput(
231 "child_ids must not be empty".into(),
232 ));
233 }
234
235 let mut tx = self.pool().begin().await?;
236
237 let (parent_id,): (i64,) = zeph_db::query_as(zeph_db::sql!(
238 "INSERT INTO memory_tree
239 (level, content, source_ids, token_count, consolidated_at)
240 VALUES (?, ?, ?, ?, datetime('now'))
241 RETURNING id"
242 ))
243 .bind(level)
244 .bind(summary)
245 .bind(source_ids_json)
246 .bind(token_count)
247 .fetch_one(&mut *tx)
248 .await?;
249
250 for &child_id in child_ids {
251 zeph_db::query(zeph_db::sql!(
252 "UPDATE memory_tree
253 SET parent_id = ?, consolidated_at = datetime('now')
254 WHERE id = ? AND parent_id IS NULL"
255 ))
256 .bind(parent_id)
257 .bind(child_id)
258 .execute(&mut *tx)
259 .await?;
260 }
261
262 tx.commit().await?;
263 Ok(parent_id)
264 }
265
266 pub async fn increment_tree_consolidation_count(&self) -> Result<(), MemoryError> {
272 query(sql!(
273 "UPDATE memory_tree_meta
274 SET total_consolidations = total_consolidations + 1,
275 last_consolidation_at = datetime('now'),
276 updated_at = datetime('now')
277 WHERE id = 1"
278 ))
279 .execute(self.pool())
280 .await?;
281
282 Ok(())
283 }
284
285 pub async fn count_tree_nodes(&self) -> Result<i64, MemoryError> {
291 let count: i64 = query_scalar(sql!("SELECT COUNT(*) FROM memory_tree"))
292 .fetch_one(self.pool())
293 .await?;
294
295 Ok(count)
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh single-connection in-memory store for each test.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    #[tokio::test]
    async fn insert_leaf_and_count() {
        let db = make_store().await;

        let new_id = db
            .insert_tree_leaf("remember this fact", 10)
            .await
            .expect("insert leaf");

        assert!(new_id > 0);
        assert_eq!(db.count_tree_nodes().await.expect("count"), 1);
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_excludes_parented() {
        let db = make_store().await;

        // Two leaves, then a summary node that adopts both.
        let first = db.insert_tree_leaf("leaf one", 5).await.expect("leaf1");
        let second = db.insert_tree_leaf("leaf two", 5).await.expect("leaf2");
        let summary_id = db
            .insert_tree_node(1, None, "summary of leaf1 and leaf2", "[]", 10)
            .await
            .expect("parent");
        db.mark_nodes_consolidated(&[first, second], summary_id)
            .await
            .expect("mark consolidated");

        let remaining = db
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load");
        assert!(
            remaining.is_empty(),
            "consolidated leaves must not appear in unconsolidated query"
        );
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_is_per_cluster_transaction() {
        let db = make_store().await;

        let child_a = db.insert_tree_leaf("a", 1).await.expect("l1");
        let child_b = db.insert_tree_leaf("b", 1).await.expect("l2");
        let summary = db
            .insert_tree_node(1, None, "ab summary", "[]", 2)
            .await
            .expect("parent");

        db.mark_nodes_consolidated(&[child_a, child_b], summary)
            .await
            .expect("mark");

        // Read every level-0 row back and check each one was reparented.
        let level0: Vec<MemoryTreeRow> = zeph_db::query_as(zeph_db::sql!(
            "SELECT id, level, parent_id, content, source_ids, token_count,
             consolidated_at, created_at
             FROM memory_tree WHERE level = 0"
        ))
        .fetch_all(db.pool())
        .await
        .expect("fetch");

        assert!(level0.iter().all(|row| row.parent_id == Some(summary)));
    }

    #[tokio::test]
    async fn traverse_tree_up_returns_path() {
        let db = make_store().await;

        let leaf = db.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = db
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");
        db.mark_nodes_consolidated(&[leaf], mid)
            .await
            .expect("mark l→m");

        let path = db.traverse_tree_up(leaf, 3).await.expect("traverse");

        assert_eq!(path.len(), 2, "leaf + mid parent");
        assert_eq!(path[0].id, leaf);
        assert_eq!(path[1].id, mid);
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_empty_slice_is_noop() {
        let db = make_store().await;
        db.mark_nodes_consolidated(&[], 999).await.expect("noop");
    }
}