Skip to main content

sqlite_graphrag/storage/
pending_memories.rs

1//! GAP-001 (v1.0.82): DAO para tabela `pending_memories`.
2//!
3//! Persistência por estágios com checkpoint retomável. Permite ao `remember` retomar
4//! do Estágio B (embedding) sem re-validar Estágio A (parse + validate).
5//!
6//! Status transitions:
7//!   validated → embedding_in_progress → embedding_done → committed
8//!                                                    ↘ abandoned (manual cleanup)
9//!                                                    ↘ failed (max attempts reached)
10
11use rusqlite::{params, Connection};
12
13use crate::errors::AppError;
14
/// Lifecycle status of a pending entry. Maps 1:1 to the CHECK constraint
/// of the `pending_memories` table.
///
/// With `rename_all = "snake_case"` the serde representation of each variant
/// is the same label that [`PendingStatus::as_str`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PendingStatus {
    /// Initial status written by `insert_validated`.
    Validated,
    /// Written by `update_to_embedding_in_progress`.
    EmbeddingInProgress,
    /// Written by `update_to_embedding_done`, together with the embedding BLOB.
    EmbeddingDone,
    /// Written by `mark_committed`.
    Committed,
    /// Described by the module docs as a manual-cleanup state.
    // NOTE(review): no function visible in this file writes this status — confirm
    // where it is set.
    Abandoned,
    /// Written by `mark_failed`, together with `last_error`.
    Failed,
}
27
28impl PendingStatus {
29    pub fn as_str(&self) -> &'static str {
30        match self {
31            Self::Validated => "validated",
32            Self::EmbeddingInProgress => "embedding_in_progress",
33            Self::EmbeddingDone => "embedding_done",
34            Self::Committed => "committed",
35            Self::Abandoned => "abandoned",
36            Self::Failed => "failed",
37        }
38    }
39}
40
/// One row of the `pending_memories` table.
///
/// Field order mirrors the column order used by the `SELECT` statements in
/// `list_by_status` and `find_by_id`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PendingMemory {
    /// Row identifier; the value `insert_validated` returns and the key used
    /// by every update/lookup function in this module.
    pub pending_id: i64,
    /// Name of the memory, as passed to `insert_validated`.
    pub name: String,
    /// Namespace of the memory, as passed to `insert_validated`.
    pub namespace: String,
    /// Memory type label, as passed to `insert_validated`.
    pub memory_type: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Raw body bytes.
    pub body: Vec<u8>,
    /// Hash of `body`, supplied by the caller.
    // NOTE(review): the tests use a "blake3-…" placeholder; the actual hash
    // algorithm is not visible here — confirm against the caller.
    pub body_hash: String,
    /// Optional entities payload. Presumably serialized JSON — confirm.
    pub entities_json: Option<String>,
    /// Optional relationships payload. Presumably serialized JSON — confirm.
    pub relationships_json: Option<String>,
    /// Current lifecycle status.
    pub status: PendingStatus,
    /// Embedding BLOB; written by `update_to_embedding_done`.
    pub embedding: Option<Vec<u8>>,
    /// The `dim` value stored alongside the embedding by
    /// `update_to_embedding_done`.
    pub embedding_dim: Option<i32>,
    /// Starts at 0 on insert and is incremented by each call to
    /// `update_to_embedding_in_progress`.
    pub attempt_count: i32,
    /// Error message recorded by `mark_failed`.
    pub last_error: Option<String>,
    /// Creation timestamp.
    // NOTE(review): presumably Unix epoch seconds from a schema default — the
    // table definition is not visible here; confirm.
    pub created_at: i64,
    /// Timestamp of the last status change; the update functions set it to
    /// SQLite's `unixepoch()`.
    pub updated_at: i64,
}
61
62/// Insere uma nova entrada em `pending_memories` com status `validated`.
63///
64/// Retorna o `pending_id` gerado.
65#[allow(clippy::too_many_arguments)]
66pub fn insert_validated(
67    conn: &Connection,
68    name: &str,
69    namespace: &str,
70    memory_type: &str,
71    description: Option<&str>,
72    body: &[u8],
73    body_hash: &str,
74    entities_json: Option<&str>,
75    relationships_json: Option<&str>,
76) -> Result<i64, AppError> {
77    conn.execute(
78        "INSERT INTO pending_memories
79            (name, namespace, memory_type, description, body, body_hash,
80             entities_json, relationships_json, status, attempt_count)
81         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 'validated', 0)",
82        params![
83            name,
84            namespace,
85            memory_type,
86            description,
87            body,
88            body_hash,
89            entities_json,
90            relationships_json,
91        ],
92    )?;
93    Ok(conn.last_insert_rowid())
94}
95
96/// Atualiza status para `embedding_in_progress` e incrementa `attempt_count`.
97pub fn update_to_embedding_in_progress(conn: &Connection, pending_id: i64) -> Result<(), AppError> {
98    conn.execute(
99        "UPDATE pending_memories
100         SET status = 'embedding_in_progress',
101             attempt_count = attempt_count + 1,
102             updated_at = unixepoch()
103         WHERE pending_id = ?1",
104        params![pending_id],
105    )?;
106    Ok(())
107}
108
109/// Atualiza status para `embedding_done` e armazena o embedding BLOB.
110pub fn update_to_embedding_done(
111    conn: &Connection,
112    pending_id: i64,
113    embedding: &[u8],
114    dim: i32,
115) -> Result<(), AppError> {
116    conn.execute(
117        "UPDATE pending_memories
118         SET status = 'embedding_done',
119             embedding = ?1,
120             embedding_dim = ?2,
121             updated_at = unixepoch()
122         WHERE pending_id = ?3",
123        params![embedding, dim, pending_id],
124    )?;
125    Ok(())
126}
127
128/// Marca como `committed` (chamado após Estágio C com sucesso).
129pub fn mark_committed(conn: &Connection, pending_id: i64) -> Result<(), AppError> {
130    conn.execute(
131        "UPDATE pending_memories
132         SET status = 'committed',
133             updated_at = unixepoch()
134         WHERE pending_id = ?1",
135        params![pending_id],
136    )?;
137    Ok(())
138}
139
140/// Marca como `failed` com mensagem de erro.
141pub fn mark_failed(conn: &Connection, pending_id: i64, error: &str) -> Result<(), AppError> {
142    conn.execute(
143        "UPDATE pending_memories
144         SET status = 'failed',
145             last_error = ?1,
146             updated_at = unixepoch()
147         WHERE pending_id = ?2",
148        params![error, pending_id],
149    )?;
150    Ok(())
151}
152
153/// Lista entradas por status, ordenadas por `updated_at` ascendente.
154pub fn list_by_status(
155    conn: &Connection,
156    status: PendingStatus,
157    limit: usize,
158) -> Result<Vec<PendingMemory>, AppError> {
159    let mut stmt = conn.prepare(
160        "SELECT pending_id, name, namespace, memory_type, description, body,
161                body_hash, entities_json, relationships_json, status,
162                embedding, embedding_dim, attempt_count, last_error,
163                created_at, updated_at
164         FROM pending_memories
165         WHERE status = ?1
166         ORDER BY updated_at ASC
167         LIMIT ?2",
168    )?;
169    let rows = stmt.query_map(params![status.as_str(), limit as i64], |row| {
170        Ok(PendingMemory {
171            pending_id: row.get(0)?,
172            name: row.get(1)?,
173            namespace: row.get(2)?,
174            memory_type: row.get(3)?,
175            description: row.get(4)?,
176            body: row.get(5)?,
177            body_hash: row.get(6)?,
178            entities_json: row.get(7)?,
179            relationships_json: row.get(8)?,
180            status: parse_status(&row.get::<_, String>(9)?).map_err(|e| -> rusqlite::Error {
181                rusqlite::Error::FromSqlConversionFailure(
182                    9,
183                    rusqlite::types::Type::Text,
184                    Box::new(std::io::Error::other(e.to_string())),
185                )
186            })?,
187            embedding: row.get(10)?,
188            embedding_dim: row.get(11)?,
189            attempt_count: row.get(12)?,
190            last_error: row.get(13)?,
191            created_at: row.get(14)?,
192            updated_at: row.get(15)?,
193        })
194    })?;
195    let mut pending = Vec::new();
196    for row in rows {
197        pending.push(row?);
198    }
199    Ok(pending)
200}
201
202/// Busca por `pending_id`.
203pub fn find_by_id(conn: &Connection, pending_id: i64) -> Result<Option<PendingMemory>, AppError> {
204    let mut stmt = conn.prepare(
205        "SELECT pending_id, name, namespace, memory_type, description, body,
206                body_hash, entities_json, relationships_json, status,
207                embedding, embedding_dim, attempt_count, last_error,
208                created_at, updated_at
209         FROM pending_memories
210         WHERE pending_id = ?1",
211    )?;
212    let mut rows = stmt.query(params![pending_id])?;
213    if let Some(row) = rows.next()? {
214        Ok(Some(PendingMemory {
215            pending_id: row.get(0)?,
216            name: row.get(1)?,
217            namespace: row.get(2)?,
218            memory_type: row.get(3)?,
219            description: row.get(4)?,
220            body: row.get(5)?,
221            body_hash: row.get(6)?,
222            entities_json: row.get(7)?,
223            relationships_json: row.get(8)?,
224            status: parse_status(row.get::<_, String>(9)?.as_str())?,
225            embedding: row.get(10)?,
226            embedding_dim: row.get(11)?,
227            attempt_count: row.get(12)?,
228            last_error: row.get(13)?,
229            created_at: row.get(14)?,
230            updated_at: row.get(15)?,
231        }))
232    } else {
233        Ok(None)
234    }
235}
236
237/// Remove entradas `embedding_in_progress` mais velhas que `older_than_secs`.
238/// Retorna o número de entradas removidas.
239pub fn cleanup_older_than(conn: &Connection, older_than_secs: i64) -> Result<usize, AppError> {
240    let cutoff = chrono::Utc::now().timestamp() - older_than_secs;
241    let count = conn.execute(
242        "DELETE FROM pending_memories
243         WHERE status IN ('embedding_in_progress', 'validated', 'failed')
244           AND updated_at < ?1",
245        params![cutoff],
246    )?;
247    Ok(count)
248}
249
250fn parse_status(s: &str) -> Result<PendingStatus, AppError> {
251    match s {
252        "validated" => Ok(PendingStatus::Validated),
253        "embedding_in_progress" => Ok(PendingStatus::EmbeddingInProgress),
254        "embedding_done" => Ok(PendingStatus::EmbeddingDone),
255        "committed" => Ok(PendingStatus::Committed),
256        "abandoned" => Ok(PendingStatus::Abandoned),
257        "failed" => Ok(PendingStatus::Failed),
258        other => Err(AppError::Validation(format!(
259            "unknown pending_memories status: {other}"
260        ))),
261    }
262}
263
/// Unit tests for the `pending_memories` DAO, run against an in-memory
/// database with the crate's migrations applied.
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database, enables foreign keys and applies the
    /// crate's migrations so the `pending_memories` table exists.
    fn fresh_db() -> Connection {
        let mut conn = Connection::open_in_memory().expect("in-memory db");
        conn.execute_batch("PRAGMA foreign_keys = ON;")
            .expect("pragma");
        crate::migrations::runner()
            .run(&mut conn)
            .expect("migrations apply");
        conn
    }

    /// A successful insert returns a positive `pending_id`.
    #[test]
    fn insert_validated_returns_pending_id() {
        let conn = fresh_db();
        let id = insert_validated(
            &conn,
            "test-pending",
            "global",
            "note",
            Some("desc"),
            b"body bytes",
            "blake3-hash-here",
            None,
            None,
        )
        .expect("insert");
        assert!(id > 0);
    }

    /// Walks one row through in-progress → done → committed and checks the
    /// status (plus `attempt_count` / `embedding_dim`) after each step.
    #[test]
    fn status_transition_validated_to_committed() {
        let conn = fresh_db();
        let id =
            insert_validated(&conn, "x", "global", "note", None, b"b", "h", None, None).unwrap();
        update_to_embedding_in_progress(&conn, id).unwrap();
        let p = find_by_id(&conn, id).unwrap().unwrap();
        assert_eq!(p.status, PendingStatus::EmbeddingInProgress);
        assert_eq!(p.attempt_count, 1);

        // The embedding BLOB is a little-endian &[u8] — use raw bytes for the test
        let fake_emb: Vec<u8> = vec![0u8; 64 * 4]; // 64 * 4 bytes
        update_to_embedding_done(&conn, id, &fake_emb, 64).unwrap();
        let p = find_by_id(&conn, id).unwrap().unwrap();
        assert_eq!(p.status, PendingStatus::EmbeddingDone);
        assert_eq!(p.embedding_dim, Some(64));

        mark_committed(&conn, id).unwrap();
        let p = find_by_id(&conn, id).unwrap().unwrap();
        assert_eq!(p.status, PendingStatus::Committed);
    }

    /// Of two inserted rows, only the one still in `validated` is listed once
    /// the other has been committed.
    #[test]
    fn list_by_status_filters_correctly() {
        let conn = fresh_db();
        let id1 =
            insert_validated(&conn, "a", "global", "note", None, b"b", "h", None, None).unwrap();
        let _id2 =
            insert_validated(&conn, "b", "global", "note", None, b"b", "h", None, None).unwrap();
        mark_committed(&conn, id1).unwrap();
        let validated = list_by_status(&conn, PendingStatus::Validated, 10).unwrap();
        assert_eq!(validated.len(), 1);
        assert_eq!(validated[0].name, "b");
    }

    /// A freshly inserted `validated` row is deleted when the cutoff lies in
    /// the future.
    #[test]
    fn cleanup_older_than_removes_stale() {
        let conn = fresh_db();
        let _id = insert_validated(
            &conn, "stale", "global", "note", None, b"b", "h", None, None,
        )
        .unwrap();
        // A cutoff in the future removes everything
        let removed = cleanup_older_than(&conn, -3600).unwrap();
        assert_eq!(removed, 1);
    }

    /// `mark_failed` sets the status to `failed` and stores the error message.
    #[test]
    fn mark_failed_records_error() {
        let conn = fresh_db();
        let id =
            insert_validated(&conn, "f", "global", "note", None, b"b", "h", None, None).unwrap();
        mark_failed(&conn, id, "codex exited with OOM").unwrap();
        let p = find_by_id(&conn, id).unwrap().unwrap();
        assert_eq!(p.status, PendingStatus::Failed);
        assert_eq!(p.last_error.as_deref(), Some("codex exited with OOM"));
    }
}
354}