Skip to main content

zeph_memory/
snapshot.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::MemoryError;
7use crate::sqlite::SqliteStore;
8use crate::types::ConversationId;
9
10#[derive(Debug, Serialize, Deserialize)]
11pub struct MemorySnapshot {
12    pub version: u32,
13    pub exported_at: String,
14    pub conversations: Vec<ConversationSnapshot>,
15}
16
17#[derive(Debug, Serialize, Deserialize)]
18pub struct ConversationSnapshot {
19    pub id: i64,
20    pub messages: Vec<MessageSnapshot>,
21    pub summaries: Vec<SummarySnapshot>,
22}
23
24#[derive(Debug, Serialize, Deserialize)]
25pub struct MessageSnapshot {
26    pub id: i64,
27    pub conversation_id: i64,
28    pub role: String,
29    pub content: String,
30    pub parts_json: String,
31    pub created_at: i64,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35pub struct SummarySnapshot {
36    pub id: i64,
37    pub conversation_id: i64,
38    pub content: String,
39    pub first_message_id: i64,
40    pub last_message_id: i64,
41    pub token_estimate: i64,
42}
43
/// Counters describing the outcome of one [`import_snapshot`] call.
///
/// All counters start at zero (see [`Default`]). The type is `Clone` and
/// comparable so callers and tests can copy it and assert on a whole result at
/// once instead of field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that did not exist yet and were inserted.
    pub conversations_imported: usize,
    /// Messages that were inserted.
    pub messages_imported: usize,
    /// Summaries that were inserted.
    pub summaries_imported: usize,
    /// Conversations, messages and summaries that were already present and
    /// therefore left untouched (one shared counter for all three kinds).
    pub skipped: usize,
}
51
52/// Export all conversations, messages and summaries from `SQLite` into a snapshot.
53///
54/// # Errors
55///
56/// Returns an error if any database query fails.
57pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
58    let conv_ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM conversations ORDER BY id ASC")
59        .fetch_all(sqlite.pool())
60        .await?;
61
62    let exported_at = chrono_now();
63    let mut conversations = Vec::with_capacity(conv_ids.len());
64
65    for (cid_raw,) in conv_ids {
66        let cid = ConversationId(cid_raw);
67
68        let msg_rows: Vec<(i64, String, String, String, i64)> = sqlx::query_as(
69            "SELECT id, role, content, parts, \
70             COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
71             FROM messages WHERE conversation_id = ? ORDER BY id ASC",
72        )
73        .bind(cid)
74        .fetch_all(sqlite.pool())
75        .await?;
76
77        let messages = msg_rows
78            .into_iter()
79            .map(
80                |(id, role, content, parts_json, created_at)| MessageSnapshot {
81                    id,
82                    conversation_id: cid_raw,
83                    role,
84                    content,
85                    parts_json,
86                    created_at,
87                },
88            )
89            .collect();
90
91        let sum_rows = sqlite.load_summaries(cid).await?;
92        let summaries = sum_rows
93            .into_iter()
94            .map(
95                |(
96                    id,
97                    conversation_id,
98                    content,
99                    first_message_id,
100                    last_message_id,
101                    token_estimate,
102                )| {
103                    SummarySnapshot {
104                        id,
105                        conversation_id: conversation_id.0,
106                        content,
107                        first_message_id: first_message_id.0,
108                        last_message_id: last_message_id.0,
109                        token_estimate,
110                    }
111                },
112            )
113            .collect();
114
115        conversations.push(ConversationSnapshot {
116            id: cid_raw,
117            messages,
118            summaries,
119        });
120    }
121
122    Ok(MemorySnapshot {
123        version: 1,
124        exported_at,
125        conversations,
126    })
127}
128
129/// Import a snapshot into `SQLite`, skipping duplicate entries.
130///
131/// Returns stats about what was imported.
132///
133/// # Errors
134///
135/// Returns an error if any database operation fails.
136pub async fn import_snapshot(
137    sqlite: &SqliteStore,
138    snapshot: MemorySnapshot,
139) -> Result<ImportStats, MemoryError> {
140    if snapshot.version != 1 {
141        return Err(MemoryError::Snapshot(format!(
142            "unsupported snapshot version {}: only version 1 is supported",
143            snapshot.version
144        )));
145    }
146    let mut stats = ImportStats::default();
147
148    for conv in snapshot.conversations {
149        let exists: Option<(i64,)> = sqlx::query_as("SELECT id FROM conversations WHERE id = ?")
150            .bind(conv.id)
151            .fetch_optional(sqlite.pool())
152            .await?;
153
154        if exists.is_none() {
155            sqlx::query("INSERT INTO conversations (id) VALUES (?)")
156                .bind(conv.id)
157                .execute(sqlite.pool())
158                .await?;
159            stats.conversations_imported += 1;
160        } else {
161            stats.skipped += 1;
162        }
163
164        for msg in conv.messages {
165            let result = sqlx::query(
166                "INSERT OR IGNORE INTO messages (id, conversation_id, role, content, parts) \
167                 VALUES (?, ?, ?, ?, ?)",
168            )
169            .bind(msg.id)
170            .bind(msg.conversation_id)
171            .bind(&msg.role)
172            .bind(&msg.content)
173            .bind(&msg.parts_json)
174            .execute(sqlite.pool())
175            .await?;
176
177            if result.rows_affected() > 0 {
178                stats.messages_imported += 1;
179            } else {
180                stats.skipped += 1;
181            }
182        }
183
184        for sum in conv.summaries {
185            let result = sqlx::query(
186                "INSERT OR IGNORE INTO summaries \
187                 (id, conversation_id, content, first_message_id, last_message_id, token_estimate) \
188                 VALUES (?, ?, ?, ?, ?, ?)",
189            )
190            .bind(sum.id)
191            .bind(sum.conversation_id)
192            .bind(&sum.content)
193            .bind(sum.first_message_id)
194            .bind(sum.last_message_id)
195            .bind(sum.token_estimate)
196            .execute(sqlite.pool())
197            .await?;
198
199            if result.rows_affected() > 0 {
200                stats.summaries_imported += 1;
201            } else {
202                stats.skipped += 1;
203            }
204        }
205    }
206
207    Ok(stats)
208}
209
210fn chrono_now() -> String {
211    use std::time::{SystemTime, UNIX_EPOCH};
212    let secs = SystemTime::now()
213        .duration_since(UNIX_EPOCH)
214        .unwrap_or_default()
215        .as_secs();
216    // Format as ISO-8601 approximation without chrono dependency
217    let (year, month, day, hour, min, sec) = unix_to_parts(secs);
218    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
219}
220
/// Split a Unix timestamp (seconds since 1970-01-01T00:00:00Z) into
/// `(year, month, day, hour, minute, second)` in UTC.
///
/// `month` is 1-12 and `day` is 1-31 on the proleptic Gregorian calendar.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_MINUTE: u64 = 60;
    const SECS_PER_HOUR: u64 = 3_600;
    const SECS_PER_DAY: u64 = 86_400;
    // Days from 0000-03-01 to 1970-01-01; moves the epoch onto a March-based year.
    const EPOCH_SHIFT_DAYS: u64 = 719_468;
    // Length of one 400-year Gregorian cycle ("era") in days.
    const DAYS_PER_ERA: u64 = 146_097;

    // Time of day.
    let secs_of_day = secs % SECS_PER_DAY;
    let hour = secs_of_day / SECS_PER_HOUR;
    let min = secs_of_day % SECS_PER_HOUR / SECS_PER_MINUTE;
    let sec = secs_of_day % SECS_PER_MINUTE;

    // Civil date from the day count.
    let shifted_days = secs / SECS_PER_DAY + EPOCH_SHIFT_DAYS;
    let era = shifted_days / DAYS_PER_ERA;
    let day_of_era = shifted_days % DAYS_PER_ERA;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146_096) / 365;
    // Day within the March-based year (0 = 1 March).
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March, ..., 11 = February).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;

    // January and February belong to the next calendar year of the
    // March-based year they were computed in.
    let march_year = year_of_era + era * 400;
    let (year, month) = if march_month < 10 {
        (march_year, march_month + 3)
    } else {
        (march_year + 1, march_month - 9)
    };

    (year, month, day, hour, min, sec)
}
242
/// Unit tests for snapshot export/import and the timestamp helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        assert!(stats2.skipped > 0);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        // Import once — conversation is new.
        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        // Import again with no new messages — conversation already exists, must be counted as skipped.
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        // The conversation itself contributes one skipped, plus the duplicate message.
        assert_eq!(
            stats2.skipped, 2,
            "re-importing must skip exactly the existing conversation and its duplicate message"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store.save_summary(cid, "summary", m1, m2, 5).await.unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    #[tokio::test]
    async fn import_restores_summaries() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        let m1 = src.save_message(cid, "user", "a").await.unwrap();
        let m2 = src.save_message(cid, "assistant", "b").await.unwrap();
        src.save_summary(cid, "summary", m1, m2, 5).await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.summaries_imported, 1);
        assert_eq!(stats.skipped, 0);

        // Re-export from the destination to confirm the summary really landed there.
        let reexported = export_snapshot(&dst).await.unwrap();
        let summaries = &reexported.conversations[0].summaries;
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].content, "summary");
        assert_eq!(summaries[0].token_estimate, 5);
    }

    #[test]
    fn chrono_now_is_iso8601_utc() {
        let ts = chrono_now();
        let bytes = ts.as_bytes();
        // Expected shape: YYYY-MM-DDTHH:MM:SSZ
        assert_eq!(bytes.len(), 20, "unexpected timestamp: {ts}");
        for (idx, &byte) in bytes.iter().enumerate() {
            let ok = match idx {
                4 | 7 => byte == b'-',
                10 => byte == b'T',
                13 | 16 => byte == b':',
                19 => byte == b'Z',
                _ => byte.is_ascii_digit(),
            };
            assert!(ok, "unexpected character at index {idx} in {ts}");
        }
    }

    #[test]
    fn unix_to_parts_known_dates() {
        // Unix epoch and the last second of its first day.
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(86_399), (1970, 1, 1, 23, 59, 59));
        // Leap day in a year divisible by 400.
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(1_000_000_000), (2001, 9, 9, 1, 46, 40));
        // Year rollover.
        assert_eq!(unix_to_parts(1_704_067_199), (2023, 12, 31, 23, 59, 59));
        assert_eq!(unix_to_parts(1_704_067_200), (2024, 1, 1, 0, 0, 0));
        // Century year that is not a leap year.
        assert_eq!(unix_to_parts(4_102_444_800), (2100, 1, 1, 0, 0, 0));
    }

    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // skipped includes the existing conversation (1) plus the duplicate message (1).
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}