Skip to main content

zeph_memory/
snapshot.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::MemoryError;
4use crate::sqlite::SqliteStore;
5use crate::types::ConversationId;
6
7#[derive(Debug, Serialize, Deserialize)]
8pub struct MemorySnapshot {
9    pub version: u32,
10    pub exported_at: String,
11    pub conversations: Vec<ConversationSnapshot>,
12}
13
14#[derive(Debug, Serialize, Deserialize)]
15pub struct ConversationSnapshot {
16    pub id: i64,
17    pub messages: Vec<MessageSnapshot>,
18    pub summaries: Vec<SummarySnapshot>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct MessageSnapshot {
23    pub id: i64,
24    pub conversation_id: i64,
25    pub role: String,
26    pub content: String,
27    pub parts_json: String,
28    pub created_at: i64,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32pub struct SummarySnapshot {
33    pub id: i64,
34    pub conversation_id: i64,
35    pub content: String,
36    pub first_message_id: i64,
37    pub last_message_id: i64,
38    pub token_estimate: i64,
39}
40
/// Counters describing the outcome of [`import_snapshot`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that were newly created in the destination database.
    pub conversations_imported: usize,
    /// Messages that were newly inserted.
    pub messages_imported: usize,
    /// Summaries that were newly inserted.
    pub summaries_imported: usize,
    /// Conversations, messages and summaries (combined) that were not inserted,
    /// typically because a row with the same id already existed.
    pub skipped: usize,
}
48
49/// Export all conversations, messages and summaries from `SQLite` into a snapshot.
50///
51/// # Errors
52///
53/// Returns an error if any database query fails.
54pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
55    let conv_ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM conversations ORDER BY id ASC")
56        .fetch_all(sqlite.pool())
57        .await?;
58
59    let exported_at = chrono_now();
60    let mut conversations = Vec::with_capacity(conv_ids.len());
61
62    for (cid_raw,) in conv_ids {
63        let cid = ConversationId(cid_raw);
64
65        let msg_rows: Vec<(i64, String, String, String, i64)> = sqlx::query_as(
66            "SELECT id, role, content, parts, \
67             COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
68             FROM messages WHERE conversation_id = ? ORDER BY id ASC",
69        )
70        .bind(cid)
71        .fetch_all(sqlite.pool())
72        .await?;
73
74        let messages = msg_rows
75            .into_iter()
76            .map(
77                |(id, role, content, parts_json, created_at)| MessageSnapshot {
78                    id,
79                    conversation_id: cid_raw,
80                    role,
81                    content,
82                    parts_json,
83                    created_at,
84                },
85            )
86            .collect();
87
88        let sum_rows = sqlite.load_summaries(cid).await?;
89        let summaries = sum_rows
90            .into_iter()
91            .map(
92                |(
93                    id,
94                    conversation_id,
95                    content,
96                    first_message_id,
97                    last_message_id,
98                    token_estimate,
99                )| {
100                    SummarySnapshot {
101                        id,
102                        conversation_id: conversation_id.0,
103                        content,
104                        first_message_id: first_message_id.0,
105                        last_message_id: last_message_id.0,
106                        token_estimate,
107                    }
108                },
109            )
110            .collect();
111
112        conversations.push(ConversationSnapshot {
113            id: cid_raw,
114            messages,
115            summaries,
116        });
117    }
118
119    Ok(MemorySnapshot {
120        version: 1,
121        exported_at,
122        conversations,
123    })
124}
125
126/// Import a snapshot into `SQLite`, skipping duplicate entries.
127///
128/// Returns stats about what was imported.
129///
130/// # Errors
131///
132/// Returns an error if any database operation fails.
133pub async fn import_snapshot(
134    sqlite: &SqliteStore,
135    snapshot: MemorySnapshot,
136) -> Result<ImportStats, MemoryError> {
137    if snapshot.version != 1 {
138        return Err(MemoryError::Snapshot(format!(
139            "unsupported snapshot version {}: only version 1 is supported",
140            snapshot.version
141        )));
142    }
143    let mut stats = ImportStats::default();
144
145    for conv in snapshot.conversations {
146        let exists: Option<(i64,)> = sqlx::query_as("SELECT id FROM conversations WHERE id = ?")
147            .bind(conv.id)
148            .fetch_optional(sqlite.pool())
149            .await?;
150
151        if exists.is_none() {
152            sqlx::query("INSERT INTO conversations (id) VALUES (?)")
153                .bind(conv.id)
154                .execute(sqlite.pool())
155                .await?;
156            stats.conversations_imported += 1;
157        } else {
158            stats.skipped += 1;
159        }
160
161        for msg in conv.messages {
162            let result = sqlx::query(
163                "INSERT OR IGNORE INTO messages (id, conversation_id, role, content, parts) \
164                 VALUES (?, ?, ?, ?, ?)",
165            )
166            .bind(msg.id)
167            .bind(msg.conversation_id)
168            .bind(&msg.role)
169            .bind(&msg.content)
170            .bind(&msg.parts_json)
171            .execute(sqlite.pool())
172            .await?;
173
174            if result.rows_affected() > 0 {
175                stats.messages_imported += 1;
176            } else {
177                stats.skipped += 1;
178            }
179        }
180
181        for sum in conv.summaries {
182            let result = sqlx::query(
183                "INSERT OR IGNORE INTO summaries \
184                 (id, conversation_id, content, first_message_id, last_message_id, token_estimate) \
185                 VALUES (?, ?, ?, ?, ?, ?)",
186            )
187            .bind(sum.id)
188            .bind(sum.conversation_id)
189            .bind(&sum.content)
190            .bind(sum.first_message_id)
191            .bind(sum.last_message_id)
192            .bind(sum.token_estimate)
193            .execute(sqlite.pool())
194            .await?;
195
196            if result.rows_affected() > 0 {
197                stats.summaries_imported += 1;
198            } else {
199                stats.skipped += 1;
200            }
201        }
202    }
203
204    Ok(stats)
205}
206
207fn chrono_now() -> String {
208    use std::time::{SystemTime, UNIX_EPOCH};
209    let secs = SystemTime::now()
210        .duration_since(UNIX_EPOCH)
211        .unwrap_or_default()
212        .as_secs();
213    // Format as ISO-8601 approximation without chrono dependency
214    let (year, month, day, hour, min, sec) = unix_to_parts(secs);
215    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
216}
217
/// Split a Unix timestamp (seconds since 1970-01-01T00:00:00Z) into
/// `(year, month, day, hour, minute, second)` in the proleptic Gregorian
/// calendar, UTC. Months and days are 1-based.
///
/// The date part is Howard Hinnant's `civil_from_days` algorithm. Days are
/// counted from 0000-03-01 so that each 400-year era (146 097 days) ends with
/// February, which puts every leap day at the end of a computational year.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_DAY: u64 = 86_400;
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: u64 = 719_468;

    let secs_of_day = secs % SECS_PER_DAY;
    let hour = secs_of_day / 3_600;
    let min = secs_of_day % 3_600 / 60;
    let sec = secs_of_day % 60;

    let shifted_days = secs / SECS_PER_DAY + EPOCH_SHIFT;
    let era = shifted_days / DAYS_PER_ERA;
    let day_of_era = shifted_days % DAYS_PER_ERA;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March, 1 = April, ..., 10 = January, 11 = February.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;

    // January and February belong to the next civil year.
    let (month, year_carry) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = era * 400 + year_of_era + year_carry;

    (year, month, day, hour, min, sec)
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn export_import_roundtrip_includes_summaries() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        let m1 = src.save_message(cid, "user", "a").await.unwrap();
        let m2 = src.save_message(cid, "assistant", "b").await.unwrap();
        src.save_summary(cid, "summary", m1, m2, 5).await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.summaries_imported, 1);
        assert_eq!(stats.skipped, 0);

        let summaries = dst.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].2, "summary");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        // One already-present conversation plus one already-present message.
        assert_eq!(stats2.skipped, 2);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        // Import once — conversation is new.
        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        // Import again with no new messages — conversation already exists, must be counted as skipped.
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        // The conversation itself contributes one skipped, plus the duplicate message.
        assert_eq!(
            stats2.skipped, 2,
            "re-importing an existing conversation must increment skipped"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store.save_summary(cid, "summary", m1, m2, 5).await.unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        // Expected shape: `YYYY-MM-DDTHH:MM:SSZ`.
        assert_eq!(ts.len(), 20);
        let bytes = ts.as_bytes();
        assert_eq!(bytes[4], b'-');
        assert_eq!(bytes[7], b'-');
        assert_eq!(bytes[10], b'T');
        assert_eq!(bytes[13], b':');
        assert_eq!(bytes[16], b':');
        assert!(ts.ends_with('Z'));
    }

    #[test]
    fn unix_to_parts_known_dates() {
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(86_399), (1970, 1, 1, 23, 59, 59));
        assert_eq!(unix_to_parts(946_684_799), (1999, 12, 31, 23, 59, 59));
        // 2000 is divisible by 400, so it is a leap year.
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(951_868_800), (2000, 3, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(1_700_000_000), (2023, 11, 14, 22, 13, 20));
    }

    /// Malformed input is rejected at deserialization, before `import_snapshot` is reached.
    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // skipped includes the existing conversation (1) plus the duplicate message (1).
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}