Skip to main content

zeph_memory/store/
memory_tree.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_db::{query, query_as, query_scalar, sql};
5
6use super::DbStore;
7use crate::error::MemoryError;
8
9/// A single memory tree node row from the `memory_tree` table.
10#[derive(Debug, Clone, sqlx::FromRow)]
11pub struct MemoryTreeRow {
12    pub id: i64,
13    pub level: i64,
14    pub parent_id: Option<i64>,
15    pub content: String,
16    pub source_ids: String,
17    pub token_count: i64,
18    pub consolidated_at: Option<String>,
19    pub created_at: String,
20}
21
22impl DbStore {
23    /// Insert a leaf node (level 0) into the memory tree.
24    ///
25    /// Returns the id of the new row.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the query fails.
30    pub async fn insert_tree_leaf(
31        &self,
32        content: &str,
33        token_count: i64,
34    ) -> Result<i64, MemoryError> {
35        let (id,): (i64,) = query_as(sql!(
36            "INSERT INTO memory_tree (level, content, token_count)
37             VALUES (0, ?, ?)
38             RETURNING id"
39        ))
40        .bind(content)
41        .bind(token_count)
42        .fetch_one(self.pool())
43        .await?;
44
45        Ok(id)
46    }
47
48    /// Insert a consolidated node at a given level.
49    ///
50    /// Returns the id of the new row.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the query fails.
55    pub async fn insert_tree_node(
56        &self,
57        level: i64,
58        parent_id: Option<i64>,
59        content: &str,
60        source_ids: &str,
61        token_count: i64,
62    ) -> Result<i64, MemoryError> {
63        let (id,): (i64,) = query_as(sql!(
64            "INSERT INTO memory_tree
65                (level, parent_id, content, source_ids, token_count, consolidated_at)
66             VALUES (?, ?, ?, ?, ?, datetime('now'))
67             RETURNING id"
68        ))
69        .bind(level)
70        .bind(parent_id)
71        .bind(content)
72        .bind(source_ids)
73        .bind(token_count)
74        .fetch_one(self.pool())
75        .await?;
76
77        Ok(id)
78    }
79
80    /// Load unconsolidated leaf nodes (level 0 without a parent).
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if the query fails.
85    pub async fn load_tree_leaves_unconsolidated(
86        &self,
87        limit: usize,
88    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
89        let rows: Vec<MemoryTreeRow> = query_as(sql!(
90            "SELECT id, level, parent_id, content, source_ids, token_count,
91                    consolidated_at, created_at
92             FROM memory_tree
93             WHERE level = 0 AND parent_id IS NULL
94             ORDER BY created_at ASC
95             LIMIT ?"
96        ))
97        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
98        .fetch_all(self.pool())
99        .await?;
100
101        Ok(rows)
102    }
103
104    /// Load all nodes at a given level (for consolidation of higher levels).
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the query fails.
109    pub async fn load_tree_level(
110        &self,
111        level: i64,
112        limit: usize,
113    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
114        let rows: Vec<MemoryTreeRow> = query_as(sql!(
115            "SELECT id, level, parent_id, content, source_ids, token_count,
116                    consolidated_at, created_at
117             FROM memory_tree
118             WHERE level = ? AND parent_id IS NULL
119             ORDER BY created_at ASC
120             LIMIT ?"
121        ))
122        .bind(level)
123        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
124        .fetch_all(self.pool())
125        .await?;
126
127        Ok(rows)
128    }
129
130    /// Traverse from a leaf up to `max_level`, returning all ancestor nodes.
131    ///
132    /// The result is ordered from leaf (level 0) to root (highest level).
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if the query fails.
137    pub async fn traverse_tree_up(
138        &self,
139        leaf_id: i64,
140        max_level: i64,
141    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
142        // Walk up via parent_id chain, bounded by max_level.
143        let mut result = Vec::new();
144        let mut current_id = leaf_id;
145
146        for _ in 0..=max_level {
147            let row: Option<MemoryTreeRow> = query_as(sql!(
148                "SELECT id, level, parent_id, content, source_ids, token_count,
149                        consolidated_at, created_at
150                 FROM memory_tree
151                 WHERE id = ?"
152            ))
153            .bind(current_id)
154            .fetch_optional(self.pool())
155            .await?;
156
157            match row {
158                None => break,
159                Some(r) => {
160                    let next_id = r.parent_id;
161                    result.push(r);
162                    match next_id {
163                        None => break,
164                        Some(p) => current_id = p,
165                    }
166                }
167            }
168        }
169
170        Ok(result)
171    }
172
173    /// Mark child nodes as consolidated by setting their `parent_id`.
174    ///
175    /// This runs inside a single transaction to prevent partial state.
176    /// Per-cluster transactions (critic S2 fix): call this once per cluster,
177    /// not once per full sweep.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if the query fails.
182    pub async fn mark_nodes_consolidated(
183        &self,
184        child_ids: &[i64],
185        parent_id: i64,
186    ) -> Result<(), MemoryError> {
187        if child_ids.is_empty() {
188            return Ok(());
189        }
190
191        let mut tx = self.pool().begin().await?;
192
193        for &child_id in child_ids {
194            query(sql!(
195                "UPDATE memory_tree
196                 SET parent_id = ?, consolidated_at = datetime('now')
197                 WHERE id = ? AND parent_id IS NULL"
198            ))
199            .bind(parent_id)
200            .bind(child_id)
201            .execute(&mut *tx)
202            .await?;
203        }
204
205        tx.commit().await?;
206        Ok(())
207    }
208
209    /// Insert a parent node and mark its children as consolidated in one transaction.
210    ///
211    /// Both the `INSERT` of the parent and the `UPDATE` of all children happen inside a single
212    /// `BEGIN … COMMIT`. A crash between the two operations therefore leaves no orphaned parent.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if any query inside the transaction fails (the transaction is rolled back).
217    pub async fn consolidate_cluster(
218        &self,
219        level: i64,
220        summary: &str,
221        source_ids_json: &str,
222        token_count: i64,
223        child_ids: &[i64],
224    ) -> Result<i64, MemoryError> {
225        if child_ids.is_empty() {
226            return Err(MemoryError::InvalidInput(
227                "child_ids must not be empty".into(),
228            ));
229        }
230
231        let mut tx = self.pool().begin().await?;
232
233        let (parent_id,): (i64,) = zeph_db::query_as(zeph_db::sql!(
234            "INSERT INTO memory_tree
235                (level, content, source_ids, token_count, consolidated_at)
236             VALUES (?, ?, ?, ?, datetime('now'))
237             RETURNING id"
238        ))
239        .bind(level)
240        .bind(summary)
241        .bind(source_ids_json)
242        .bind(token_count)
243        .fetch_one(&mut *tx)
244        .await?;
245
246        for &child_id in child_ids {
247            zeph_db::query(zeph_db::sql!(
248                "UPDATE memory_tree
249                 SET parent_id = ?, consolidated_at = datetime('now')
250                 WHERE id = ? AND parent_id IS NULL"
251            ))
252            .bind(parent_id)
253            .bind(child_id)
254            .execute(&mut *tx)
255            .await?;
256        }
257
258        tx.commit().await?;
259        Ok(parent_id)
260    }
261
262    /// Increment the total consolidation counter in `memory_tree_meta`.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the query fails.
267    pub async fn increment_tree_consolidation_count(&self) -> Result<(), MemoryError> {
268        query(sql!(
269            "UPDATE memory_tree_meta
270             SET total_consolidations = total_consolidations + 1,
271                 last_consolidation_at = datetime('now'),
272                 updated_at = datetime('now')
273             WHERE id = 1"
274        ))
275        .execute(self.pool())
276        .await?;
277
278        Ok(())
279    }
280
281    /// Count total nodes in the memory tree.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if the query fails.
286    pub async fn count_tree_nodes(&self) -> Result<i64, MemoryError> {
287        let count: i64 = query_scalar(sql!("SELECT COUNT(*) FROM memory_tree"))
288            .fetch_one(self.pool())
289            .await?;
290
291        Ok(count)
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    /// Fetch every level-0 row directly, bypassing the store's filtered query helpers.
    async fn fetch_all_leaves(store: &DbStore) -> Vec<MemoryTreeRow> {
        zeph_db::query_as(zeph_db::sql!(
            "SELECT id, level, parent_id, content, source_ids, token_count,
                    consolidated_at, created_at
             FROM memory_tree WHERE level = 0"
        ))
        .fetch_all(store.pool())
        .await
        .expect("fetch")
    }

    #[tokio::test]
    async fn insert_leaf_and_count() {
        let store = make_store().await;
        let id = store
            .insert_tree_leaf("remember this fact", 10)
            .await
            .expect("insert leaf");
        assert!(id > 0);
        assert_eq!(store.count_tree_nodes().await.expect("count"), 1);
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_excludes_parented() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("leaf one", 5).await.expect("leaf1");
        let leaf2 = store.insert_tree_leaf("leaf two", 5).await.expect("leaf2");

        // Consolidate into a parent node.
        let parent_id = store
            .insert_tree_node(1, None, "summary of leaf1 and leaf2", "[]", 10)
            .await
            .expect("parent");
        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent_id)
            .await
            .expect("mark consolidated");

        // No unconsolidated leaves should remain.
        let leaves = store
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load");
        assert!(
            leaves.is_empty(),
            "consolidated leaves must not appear in unconsolidated query"
        );
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_respects_limit() {
        let store = make_store().await;
        for content in ["a", "b", "c"] {
            store.insert_tree_leaf(content, 1).await.expect("leaf");
        }

        let leaves = store
            .load_tree_leaves_unconsolidated(2)
            .await
            .expect("load");
        assert_eq!(leaves.len(), 2);
        assert!(leaves.iter().all(|r| r.level == 0 && r.parent_id.is_none()));
    }

    #[tokio::test]
    async fn load_tree_level_only_returns_requested_level() {
        let store = make_store().await;
        store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let node = store
            .insert_tree_node(1, None, "level one", "[]", 2)
            .await
            .expect("node");

        let rows = store.load_tree_level(1, 10).await.expect("load level");
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, node);
        assert_eq!(rows[0].level, 1);
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_is_per_cluster_transaction() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");
        let parent = store
            .insert_tree_node(1, None, "ab summary", "[]", 2)
            .await
            .expect("parent");

        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent)
            .await
            .expect("mark");

        // Verify both are now parented.
        let rows = fetch_all_leaves(&store).await;
        assert!(rows.iter().all(|r| r.parent_id == Some(parent)));
    }

    #[tokio::test]
    async fn consolidate_cluster_inserts_parent_and_attaches_children() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");

        let parent = store
            .consolidate_cluster(1, "ab summary", "[]", 2, &[leaf1, leaf2])
            .await
            .expect("consolidate");

        assert_eq!(store.count_tree_nodes().await.expect("count"), 3);
        let rows = fetch_all_leaves(&store).await;
        assert_eq!(rows.len(), 2);
        assert!(rows.iter().all(|r| r.parent_id == Some(parent)));
        assert!(rows.iter().all(|r| r.consolidated_at.is_some()));

        let level_one = store.load_tree_level(1, 10).await.expect("load level");
        assert_eq!(level_one.len(), 1);
        assert_eq!(level_one[0].id, parent);
        assert_eq!(level_one[0].content, "ab summary");
    }

    #[tokio::test]
    async fn consolidate_cluster_rejects_empty_children() {
        let store = make_store().await;
        let err = store
            .consolidate_cluster(1, "nothing", "[]", 0, &[])
            .await
            .expect_err("empty child_ids must be rejected");
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        assert_eq!(store.count_tree_nodes().await.expect("count"), 0);
    }

    #[tokio::test]
    async fn traverse_tree_up_returns_path() {
        let store = make_store().await;
        let leaf = store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = store
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");
        store
            .mark_nodes_consolidated(&[leaf], mid)
            .await
            .expect("mark l→m");

        let path = store.traverse_tree_up(leaf, 3).await.expect("traverse");
        assert_eq!(path.len(), 2, "leaf + mid parent");
        assert_eq!(path[0].id, leaf);
        assert_eq!(path[1].id, mid);
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_empty_slice_is_noop() {
        let store = make_store().await;
        // Should not fail on empty slice.
        store.mark_nodes_consolidated(&[], 999).await.expect("noop");
    }
}