1use zeph_db::{query, query_as, query_scalar, sql};
5
6use super::DbStore;
7use crate::error::MemoryError;
8
9#[derive(Debug, Clone, sqlx::FromRow)]
11pub struct MemoryTreeRow {
12 pub id: i64,
13 pub level: i64,
14 pub parent_id: Option<i64>,
15 pub content: String,
16 pub source_ids: String,
17 pub token_count: i64,
18 pub consolidated_at: Option<String>,
19 pub created_at: String,
20}
21
22impl DbStore {
23 pub async fn insert_tree_leaf(
31 &self,
32 content: &str,
33 token_count: i64,
34 ) -> Result<i64, MemoryError> {
35 let (id,): (i64,) = query_as(sql!(
36 "INSERT INTO memory_tree (level, content, token_count)
37 VALUES (0, ?, ?)
38 RETURNING id"
39 ))
40 .bind(content)
41 .bind(token_count)
42 .fetch_one(self.pool())
43 .await?;
44
45 Ok(id)
46 }
47
48 pub async fn insert_tree_node(
56 &self,
57 level: i64,
58 parent_id: Option<i64>,
59 content: &str,
60 source_ids: &str,
61 token_count: i64,
62 ) -> Result<i64, MemoryError> {
63 let (id,): (i64,) = query_as(sql!(
64 "INSERT INTO memory_tree
65 (level, parent_id, content, source_ids, token_count, consolidated_at)
66 VALUES (?, ?, ?, ?, ?, datetime('now'))
67 RETURNING id"
68 ))
69 .bind(level)
70 .bind(parent_id)
71 .bind(content)
72 .bind(source_ids)
73 .bind(token_count)
74 .fetch_one(self.pool())
75 .await?;
76
77 Ok(id)
78 }
79
80 pub async fn load_tree_leaves_unconsolidated(
86 &self,
87 limit: usize,
88 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
89 let rows: Vec<MemoryTreeRow> = query_as(sql!(
90 "SELECT id, level, parent_id, content, source_ids, token_count,
91 consolidated_at, created_at
92 FROM memory_tree
93 WHERE level = 0 AND parent_id IS NULL
94 ORDER BY created_at ASC
95 LIMIT ?"
96 ))
97 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
98 .fetch_all(self.pool())
99 .await?;
100
101 Ok(rows)
102 }
103
104 pub async fn load_tree_level(
110 &self,
111 level: i64,
112 limit: usize,
113 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
114 let rows: Vec<MemoryTreeRow> = query_as(sql!(
115 "SELECT id, level, parent_id, content, source_ids, token_count,
116 consolidated_at, created_at
117 FROM memory_tree
118 WHERE level = ? AND parent_id IS NULL
119 ORDER BY created_at ASC
120 LIMIT ?"
121 ))
122 .bind(level)
123 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
124 .fetch_all(self.pool())
125 .await?;
126
127 Ok(rows)
128 }
129
130 pub async fn traverse_tree_up(
138 &self,
139 leaf_id: i64,
140 max_level: i64,
141 ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
142 let mut result = Vec::new();
144 let mut current_id = leaf_id;
145
146 for _ in 0..=max_level {
147 let row: Option<MemoryTreeRow> = query_as(sql!(
148 "SELECT id, level, parent_id, content, source_ids, token_count,
149 consolidated_at, created_at
150 FROM memory_tree
151 WHERE id = ?"
152 ))
153 .bind(current_id)
154 .fetch_optional(self.pool())
155 .await?;
156
157 match row {
158 None => break,
159 Some(r) => {
160 let next_id = r.parent_id;
161 result.push(r);
162 match next_id {
163 None => break,
164 Some(p) => current_id = p,
165 }
166 }
167 }
168 }
169
170 Ok(result)
171 }
172
173 pub async fn mark_nodes_consolidated(
183 &self,
184 child_ids: &[i64],
185 parent_id: i64,
186 ) -> Result<(), MemoryError> {
187 if child_ids.is_empty() {
188 return Ok(());
189 }
190
191 let mut tx = self.pool().begin().await?;
192
193 for &child_id in child_ids {
194 query(sql!(
195 "UPDATE memory_tree
196 SET parent_id = ?, consolidated_at = datetime('now')
197 WHERE id = ? AND parent_id IS NULL"
198 ))
199 .bind(parent_id)
200 .bind(child_id)
201 .execute(&mut *tx)
202 .await?;
203 }
204
205 tx.commit().await?;
206 Ok(())
207 }
208
209 pub async fn consolidate_cluster(
218 &self,
219 level: i64,
220 summary: &str,
221 source_ids_json: &str,
222 token_count: i64,
223 child_ids: &[i64],
224 ) -> Result<i64, MemoryError> {
225 if child_ids.is_empty() {
226 return Err(MemoryError::InvalidInput(
227 "child_ids must not be empty".into(),
228 ));
229 }
230
231 let mut tx = self.pool().begin().await?;
232
233 let (parent_id,): (i64,) = zeph_db::query_as(zeph_db::sql!(
234 "INSERT INTO memory_tree
235 (level, content, source_ids, token_count, consolidated_at)
236 VALUES (?, ?, ?, ?, datetime('now'))
237 RETURNING id"
238 ))
239 .bind(level)
240 .bind(summary)
241 .bind(source_ids_json)
242 .bind(token_count)
243 .fetch_one(&mut *tx)
244 .await?;
245
246 for &child_id in child_ids {
247 zeph_db::query(zeph_db::sql!(
248 "UPDATE memory_tree
249 SET parent_id = ?, consolidated_at = datetime('now')
250 WHERE id = ? AND parent_id IS NULL"
251 ))
252 .bind(parent_id)
253 .bind(child_id)
254 .execute(&mut *tx)
255 .await?;
256 }
257
258 tx.commit().await?;
259 Ok(parent_id)
260 }
261
262 pub async fn increment_tree_consolidation_count(&self) -> Result<(), MemoryError> {
268 query(sql!(
269 "UPDATE memory_tree_meta
270 SET total_consolidations = total_consolidations + 1,
271 last_consolidation_at = datetime('now'),
272 updated_at = datetime('now')
273 WHERE id = 1"
274 ))
275 .execute(self.pool())
276 .await?;
277
278 Ok(())
279 }
280
281 pub async fn count_tree_nodes(&self) -> Result<i64, MemoryError> {
287 let count: i64 = query_scalar(sql!("SELECT COUNT(*) FROM memory_tree"))
288 .fetch_one(self.pool())
289 .await?;
290
291 Ok(count)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    /// Reads every level-0 row directly, bypassing the filtered loaders under test.
    async fn fetch_all_leaves(store: &DbStore) -> Vec<MemoryTreeRow> {
        zeph_db::query_as(zeph_db::sql!(
            "SELECT id, level, parent_id, content, source_ids, token_count,
                    consolidated_at, created_at
             FROM memory_tree WHERE level = 0"
        ))
        .fetch_all(store.pool())
        .await
        .expect("fetch")
    }

    #[tokio::test]
    async fn insert_leaf_and_count() {
        let store = make_store().await;
        let id = store
            .insert_tree_leaf("remember this fact", 10)
            .await
            .expect("insert leaf");
        assert!(id > 0);
        assert_eq!(store.count_tree_nodes().await.expect("count"), 1);
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_excludes_parented() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("leaf one", 5).await.expect("leaf1");
        let leaf2 = store.insert_tree_leaf("leaf two", 5).await.expect("leaf2");
        let leaf3 = store.insert_tree_leaf("leaf three", 5).await.expect("leaf3");

        let parent_id = store
            .insert_tree_node(1, None, "summary of leaf1 and leaf2", "[]", 10)
            .await
            .expect("parent");
        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent_id)
            .await
            .expect("mark consolidated");

        let leaves = store
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load");
        // leaf3 must still be returned, otherwise an always-empty query would also pass.
        let ids: Vec<i64> = leaves.iter().map(|r| r.id).collect();
        assert_eq!(
            ids,
            vec![leaf3],
            "consolidated leaves must not appear in unconsolidated query"
        );
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_respects_limit() {
        let store = make_store().await;
        for content in ["a", "b", "c"] {
            store.insert_tree_leaf(content, 1).await.expect("leaf");
        }

        let leaves = store
            .load_tree_leaves_unconsolidated(2)
            .await
            .expect("load");
        assert_eq!(leaves.len(), 2);
    }

    #[tokio::test]
    async fn load_tree_level_filters_by_level_and_parent() {
        let store = make_store().await;
        store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let n1 = store
            .insert_tree_node(1, None, "n1", "[]", 1)
            .await
            .expect("n1");
        let n2 = store
            .insert_tree_node(1, None, "n2", "[]", 1)
            .await
            .expect("n2");
        let top = store
            .consolidate_cluster(2, "top", "[]", 1, &[n1])
            .await
            .expect("consolidate");

        let level1: Vec<i64> = store
            .load_tree_level(1, 10)
            .await
            .expect("level 1")
            .iter()
            .map(|r| r.id)
            .collect();
        assert_eq!(level1, vec![n2], "parented level-1 node must be excluded");

        let level2: Vec<i64> = store
            .load_tree_level(2, 10)
            .await
            .expect("level 2")
            .iter()
            .map(|r| r.id)
            .collect();
        assert_eq!(level2, vec![top]);
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_is_per_cluster_transaction() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");
        let parent = store
            .insert_tree_node(1, None, "ab summary", "[]", 2)
            .await
            .expect("parent");

        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent)
            .await
            .expect("mark");

        let rows = fetch_all_leaves(&store).await;
        // Guard against `all()` passing vacuously on an empty result.
        assert_eq!(rows.len(), 2);
        assert!(rows.iter().all(|r| r.parent_id == Some(parent)));
        assert!(rows.iter().all(|r| r.consolidated_at.is_some()));
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_does_not_reparent() {
        let store = make_store().await;
        let leaf = store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let first = store
            .insert_tree_node(1, None, "first", "[]", 1)
            .await
            .expect("first");
        let second = store
            .insert_tree_node(1, None, "second", "[]", 1)
            .await
            .expect("second");

        store
            .mark_nodes_consolidated(&[leaf], first)
            .await
            .expect("mark first");
        store
            .mark_nodes_consolidated(&[leaf], second)
            .await
            .expect("mark second");

        let rows = fetch_all_leaves(&store).await;
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].parent_id, Some(first));
    }

    #[tokio::test]
    async fn traverse_tree_up_returns_path() {
        let store = make_store().await;
        let leaf = store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = store
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");
        store
            .mark_nodes_consolidated(&[leaf], mid)
            .await
            .expect("mark l→m");

        let path = store.traverse_tree_up(leaf, 3).await.expect("traverse");
        assert_eq!(path.len(), 2, "leaf + mid parent");
        assert_eq!(path[0].id, leaf);
        assert_eq!(path[1].id, mid);
    }

    #[tokio::test]
    async fn traverse_tree_up_stops_at_max_level() {
        let store = make_store().await;
        let leaf = store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = store
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");
        let top = store
            .insert_tree_node(2, None, "top", "[]", 3)
            .await
            .expect("top");
        store
            .mark_nodes_consolidated(&[leaf], mid)
            .await
            .expect("mark l→m");
        store
            .mark_nodes_consolidated(&[mid], top)
            .await
            .expect("mark m→t");

        let ids = |rows: Vec<MemoryTreeRow>| rows.iter().map(|r| r.id).collect::<Vec<_>>();

        let path = store.traverse_tree_up(leaf, 0).await.expect("max 0");
        assert_eq!(ids(path), vec![leaf]);
        let path = store.traverse_tree_up(leaf, 1).await.expect("max 1");
        assert_eq!(ids(path), vec![leaf, mid]);
        let path = store.traverse_tree_up(leaf, 5).await.expect("max 5");
        assert_eq!(ids(path), vec![leaf, mid, top]);
    }

    #[tokio::test]
    async fn traverse_tree_up_missing_leaf_is_empty() {
        let store = make_store().await;
        let path = store.traverse_tree_up(12_345, 3).await.expect("traverse");
        assert!(path.is_empty());
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_empty_slice_is_noop() {
        let store = make_store().await;
        store.mark_nodes_consolidated(&[], 999).await.expect("noop");
        assert_eq!(store.count_tree_nodes().await.expect("count"), 0);
    }

    #[tokio::test]
    async fn consolidate_cluster_links_children() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");

        let parent = store
            .consolidate_cluster(1, "summary", "[1,2]", 4, &[leaf1, leaf2])
            .await
            .expect("consolidate");

        let rows = fetch_all_leaves(&store).await;
        assert_eq!(rows.len(), 2);
        assert!(rows.iter().all(|r| r.parent_id == Some(parent)));
        assert!(store
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load leaves")
            .is_empty());

        let level1 = store.load_tree_level(1, 10).await.expect("level 1");
        assert_eq!(level1.len(), 1);
        assert_eq!(level1[0].id, parent);
        assert_eq!(level1[0].content, "summary");
        assert_eq!(level1[0].source_ids, "[1,2]");
        assert_eq!(level1[0].token_count, 4);
        assert!(level1[0].consolidated_at.is_some());
    }

    #[tokio::test]
    async fn consolidate_cluster_rejects_empty_children() {
        let store = make_store().await;
        let err = store
            .consolidate_cluster(1, "summary", "[]", 0, &[])
            .await
            .expect_err("empty child_ids must be rejected");
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        assert_eq!(store.count_tree_nodes().await.expect("count"), 0);
    }
}