Skip to main content

zeph_memory/store/
memory_tree.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_db::{query, query_as, query_scalar, sql};
5
6use super::DbStore;
7use crate::error::MemoryError;
8
9/// A single memory tree node row from the `memory_tree` table.
10#[derive(Debug, Clone, sqlx::FromRow)]
11pub struct MemoryTreeRow {
12    pub id: i64,
13    pub level: i64,
14    pub parent_id: Option<i64>,
15    pub content: String,
16    pub source_ids: String,
17    pub token_count: i64,
18    pub consolidated_at: Option<String>,
19    pub created_at: String,
20}
21
22impl DbStore {
23    /// Insert a leaf node (level 0) into the memory tree.
24    ///
25    /// Returns the id of the new row.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the query fails.
30    pub async fn insert_tree_leaf(
31        &self,
32        content: &str,
33        token_count: i64,
34    ) -> Result<i64, MemoryError> {
35        let (id,): (i64,) = query_as(sql!(
36            "INSERT INTO memory_tree (level, content, token_count)
37             VALUES (0, ?, ?)
38             RETURNING id"
39        ))
40        .bind(content)
41        .bind(token_count)
42        .fetch_one(self.pool())
43        .await?;
44
45        Ok(id)
46    }
47
48    /// Insert a consolidated node at a given level.
49    ///
50    /// Returns the id of the new row.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the query fails.
55    pub async fn insert_tree_node(
56        &self,
57        level: i64,
58        parent_id: Option<i64>,
59        content: &str,
60        source_ids: &str,
61        token_count: i64,
62    ) -> Result<i64, MemoryError> {
63        let (id,): (i64,) = query_as(sql!(
64            "INSERT INTO memory_tree
65                (level, parent_id, content, source_ids, token_count, consolidated_at)
66             VALUES (?, ?, ?, ?, ?, datetime('now'))
67             RETURNING id"
68        ))
69        .bind(level)
70        .bind(parent_id)
71        .bind(content)
72        .bind(source_ids)
73        .bind(token_count)
74        .fetch_one(self.pool())
75        .await?;
76
77        Ok(id)
78    }
79
80    /// Load unconsolidated leaf nodes (level 0 without a parent).
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if the query fails.
85    pub async fn load_tree_leaves_unconsolidated(
86        &self,
87        limit: usize,
88    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
89        let rows: Vec<MemoryTreeRow> = query_as(sql!(
90            "SELECT id, level, parent_id, content, source_ids, token_count,
91                    consolidated_at, created_at
92             FROM memory_tree
93             WHERE level = 0 AND parent_id IS NULL
94             ORDER BY created_at ASC
95             LIMIT ?"
96        ))
97        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
98        .fetch_all(self.pool())
99        .await?;
100
101        Ok(rows)
102    }
103
104    /// Load all nodes at a given level (for consolidation of higher levels).
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if the query fails.
109    pub async fn load_tree_level(
110        &self,
111        level: i64,
112        limit: usize,
113    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
114        let rows: Vec<MemoryTreeRow> = query_as(sql!(
115            "SELECT id, level, parent_id, content, source_ids, token_count,
116                    consolidated_at, created_at
117             FROM memory_tree
118             WHERE level = ? AND parent_id IS NULL
119             ORDER BY created_at ASC
120             LIMIT ?"
121        ))
122        .bind(level)
123        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
124        .fetch_all(self.pool())
125        .await?;
126
127        Ok(rows)
128    }
129
130    /// Traverse from a leaf up to `max_level`, returning all ancestor nodes.
131    ///
132    /// The result is ordered from leaf (level 0) to root (highest level).
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if the query fails.
137    pub async fn traverse_tree_up(
138        &self,
139        leaf_id: i64,
140        max_level: i64,
141    ) -> Result<Vec<MemoryTreeRow>, MemoryError> {
142        // Walk up via parent_id chain, bounded by max_level.
143        let mut result = Vec::new();
144        let mut current_id = leaf_id;
145
146        for _ in 0..=max_level {
147            let row: Option<MemoryTreeRow> = query_as(sql!(
148                "SELECT id, level, parent_id, content, source_ids, token_count,
149                        consolidated_at, created_at
150                 FROM memory_tree
151                 WHERE id = ?"
152            ))
153            .bind(current_id)
154            .fetch_optional(self.pool())
155            .await?;
156
157            match row {
158                None => break,
159                Some(r) => {
160                    let next_id = r.parent_id;
161                    result.push(r);
162                    match next_id {
163                        None => break,
164                        Some(p) => current_id = p,
165                    }
166                }
167            }
168        }
169
170        Ok(result)
171    }
172
173    /// Mark child nodes as consolidated by setting their `parent_id`.
174    ///
175    /// This runs inside a single transaction to prevent partial state.
176    /// Per-cluster transactions (critic S2 fix): call this once per cluster,
177    /// not once per full sweep.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if the query fails.
182    pub async fn mark_nodes_consolidated(
183        &self,
184        child_ids: &[i64],
185        parent_id: i64,
186    ) -> Result<(), MemoryError> {
187        if child_ids.is_empty() {
188            return Ok(());
189        }
190
191        let mut tx = self.pool().begin().await?;
192
193        for &child_id in child_ids {
194            query(sql!(
195                "UPDATE memory_tree
196                 SET parent_id = ?, consolidated_at = datetime('now')
197                 WHERE id = ? AND parent_id IS NULL"
198            ))
199            .bind(parent_id)
200            .bind(child_id)
201            .execute(&mut *tx)
202            .await?;
203        }
204
205        tx.commit().await?;
206        Ok(())
207    }
208
209    /// Insert a parent node and mark its children as consolidated in one transaction.
210    ///
211    /// Both the `INSERT` of the parent and the `UPDATE` of all children happen inside a single
212    /// `BEGIN … COMMIT`. A crash between the two operations therefore leaves no orphaned parent.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if any query inside the transaction fails (the transaction is rolled back).
217    #[cfg_attr(
218        feature = "profiling",
219        tracing::instrument(name = "memory.consolidate", skip_all)
220    )]
221    pub async fn consolidate_cluster(
222        &self,
223        level: i64,
224        summary: &str,
225        source_ids_json: &str,
226        token_count: i64,
227        child_ids: &[i64],
228    ) -> Result<i64, MemoryError> {
229        if child_ids.is_empty() {
230            return Err(MemoryError::InvalidInput(
231                "child_ids must not be empty".into(),
232            ));
233        }
234
235        let mut tx = self.pool().begin().await?;
236
237        let (parent_id,): (i64,) = zeph_db::query_as(zeph_db::sql!(
238            "INSERT INTO memory_tree
239                (level, content, source_ids, token_count, consolidated_at)
240             VALUES (?, ?, ?, ?, datetime('now'))
241             RETURNING id"
242        ))
243        .bind(level)
244        .bind(summary)
245        .bind(source_ids_json)
246        .bind(token_count)
247        .fetch_one(&mut *tx)
248        .await?;
249
250        for &child_id in child_ids {
251            zeph_db::query(zeph_db::sql!(
252                "UPDATE memory_tree
253                 SET parent_id = ?, consolidated_at = datetime('now')
254                 WHERE id = ? AND parent_id IS NULL"
255            ))
256            .bind(parent_id)
257            .bind(child_id)
258            .execute(&mut *tx)
259            .await?;
260        }
261
262        tx.commit().await?;
263        Ok(parent_id)
264    }
265
266    /// Increment the total consolidation counter in `memory_tree_meta`.
267    ///
268    /// # Errors
269    ///
270    /// Returns an error if the query fails.
271    pub async fn increment_tree_consolidation_count(&self) -> Result<(), MemoryError> {
272        query(sql!(
273            "UPDATE memory_tree_meta
274             SET total_consolidations = total_consolidations + 1,
275                 last_consolidation_at = datetime('now'),
276                 updated_at = datetime('now')
277             WHERE id = 1"
278        ))
279        .execute(self.pool())
280        .await?;
281
282        Ok(())
283    }
284
285    /// Count total nodes in the memory tree.
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if the query fails.
290    pub async fn count_tree_nodes(&self) -> Result<i64, MemoryError> {
291        let count: i64 = query_scalar(sql!("SELECT COUNT(*) FROM memory_tree"))
292            .fetch_one(self.pool())
293            .await?;
294
295        Ok(count)
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fresh in-memory store backed by a single pooled connection.
    ///
    /// NOTE(review): the pool size of 1 is presumably there so every query hits the same
    /// in-memory database — confirm against `DbStore::with_pool_size`.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    #[tokio::test]
    async fn insert_leaf_and_count() {
        let store = make_store().await;
        let id = store
            .insert_tree_leaf("remember this fact", 10)
            .await
            .expect("insert leaf");
        assert!(id > 0);
        assert_eq!(store.count_tree_nodes().await.expect("count"), 1);
    }

    #[tokio::test]
    async fn load_unconsolidated_leaves_excludes_parented() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("leaf one", 5).await.expect("leaf1");
        let leaf2 = store.insert_tree_leaf("leaf two", 5).await.expect("leaf2");

        // Consolidate into a parent node.
        let parent_id = store
            .insert_tree_node(1, None, "summary of leaf1 and leaf2", "[]", 10)
            .await
            .expect("parent");
        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent_id)
            .await
            .expect("mark consolidated");

        // No unconsolidated leaves should remain.
        let leaves = store
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load");
        assert!(
            leaves.is_empty(),
            "consolidated leaves must not appear in unconsolidated query"
        );
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_is_per_cluster_transaction() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");
        let parent = store
            .insert_tree_node(1, None, "ab summary", "[]", 2)
            .await
            .expect("parent");

        store
            .mark_nodes_consolidated(&[leaf1, leaf2], parent)
            .await
            .expect("mark");

        // Verify both are now parented.
        let rows: Vec<MemoryTreeRow> = zeph_db::query_as(zeph_db::sql!(
            "SELECT id, level, parent_id, content, source_ids, token_count,
                    consolidated_at, created_at
             FROM memory_tree WHERE level = 0"
        ))
        .fetch_all(store.pool())
        .await
        .expect("fetch");

        assert!(rows.iter().all(|r| r.parent_id == Some(parent)));
    }

    #[tokio::test]
    async fn traverse_tree_up_returns_path() {
        let store = make_store().await;
        let leaf = store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = store
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");
        store
            .mark_nodes_consolidated(&[leaf], mid)
            .await
            .expect("mark l→m");

        let path = store.traverse_tree_up(leaf, 3).await.expect("traverse");
        assert_eq!(path.len(), 2, "leaf + mid parent");
        assert_eq!(path[0].id, leaf);
        assert_eq!(path[1].id, mid);
    }

    #[tokio::test]
    async fn mark_nodes_consolidated_empty_slice_is_noop() {
        let store = make_store().await;
        // Should not fail on empty slice.
        store.mark_nodes_consolidated(&[], 999).await.expect("noop");
    }

    #[tokio::test]
    async fn consolidate_cluster_inserts_parent_and_links_children() {
        let store = make_store().await;
        let leaf1 = store.insert_tree_leaf("a", 1).await.expect("l1");
        let leaf2 = store.insert_tree_leaf("b", 1).await.expect("l2");

        let parent = store
            .consolidate_cluster(1, "ab summary", "[]", 2, &[leaf1, leaf2])
            .await
            .expect("consolidate");

        // Both leaves are now parented, so none are left to consolidate.
        let leaves = store
            .load_tree_leaves_unconsolidated(10)
            .await
            .expect("load");
        assert!(leaves.is_empty(), "children must be linked to the parent");

        // The new parent is reachable from a child and sits at the requested level.
        let path = store.traverse_tree_up(leaf1, 3).await.expect("traverse");
        assert_eq!(path.len(), 2, "leaf + new parent");
        assert_eq!(path[1].id, parent);
        assert_eq!(path[1].level, 1);

        assert_eq!(store.count_tree_nodes().await.expect("count"), 3);
    }

    #[tokio::test]
    async fn consolidate_cluster_rejects_empty_children() {
        let store = make_store().await;

        let result = store.consolidate_cluster(1, "summary", "[]", 0, &[]).await;

        assert!(result.is_err(), "empty child_ids must be rejected");
        // The rejection happens before any row is written.
        assert_eq!(store.count_tree_nodes().await.expect("count"), 0);
    }

    #[tokio::test]
    async fn load_tree_level_returns_only_unparented_nodes_of_that_level() {
        let store = make_store().await;
        store.insert_tree_leaf("leaf", 1).await.expect("leaf");
        let mid = store
            .insert_tree_node(1, None, "mid", "[]", 2)
            .await
            .expect("mid");

        // Only the level-1 node is returned; the level-0 leaf is filtered out.
        let level_one = store.load_tree_level(1, 10).await.expect("load level 1");
        assert_eq!(level_one.len(), 1);
        assert_eq!(level_one[0].id, mid);

        // Once the level-1 node gets a parent it no longer qualifies.
        let top = store
            .insert_tree_node(2, None, "top", "[]", 3)
            .await
            .expect("top");
        store
            .mark_nodes_consolidated(&[mid], top)
            .await
            .expect("mark m→t");

        let level_one = store.load_tree_level(1, 10).await.expect("reload level 1");
        assert!(level_one.is_empty());
    }
}
400}