Skip to main content

zeph_memory/
snapshot.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use zeph_db::ActiveDialect;
9
10use crate::error::MemoryError;
11use crate::store::SqliteStore;
12use crate::types::ConversationId;
13
14#[derive(Debug, Serialize, Deserialize)]
15pub struct MemorySnapshot {
16    pub version: u32,
17    pub exported_at: String,
18    pub conversations: Vec<ConversationSnapshot>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct ConversationSnapshot {
23    pub id: i64,
24    pub messages: Vec<MessageSnapshot>,
25    pub summaries: Vec<SummarySnapshot>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29pub struct MessageSnapshot {
30    pub id: i64,
31    pub conversation_id: i64,
32    pub role: String,
33    pub content: String,
34    pub parts_json: String,
35    pub created_at: i64,
36}
37
38#[derive(Debug, Serialize, Deserialize)]
39pub struct SummarySnapshot {
40    pub id: i64,
41    pub conversation_id: i64,
42    pub content: String,
43    pub first_message_id: Option<i64>,
44    pub last_message_id: Option<i64>,
45    pub token_estimate: i64,
46}
47
/// Counters describing the outcome of [`import_snapshot`].
#[derive(Debug, Default)]
pub struct ImportStats {
    /// Conversations that did not exist yet and were created.
    pub conversations_imported: usize,
    /// Messages whose insert actually added a row.
    pub messages_imported: usize,
    /// Summaries whose insert actually added a row.
    pub summaries_imported: usize,
    /// Combined count of conversations, messages and summaries that already
    /// existed and were left untouched.
    pub skipped: usize,
}
55
56/// Export all conversations, messages and summaries from `SQLite` into a snapshot.
57///
58/// # Errors
59///
60/// Returns an error if any database query fails.
61pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
62    let conv_ids: Vec<(i64,)> =
63        zeph_db::query_as(sql!("SELECT id FROM conversations ORDER BY id ASC"))
64            .fetch_all(sqlite.pool())
65            .await?;
66
67    let exported_at = chrono_now();
68    let mut conversations = Vec::with_capacity(conv_ids.len());
69
70    for (cid_raw,) in conv_ids {
71        let cid = ConversationId(cid_raw);
72
73        let epoch_expr = <ActiveDialect as zeph_db::dialect::Dialect>::epoch_from_col("created_at");
74        let msg_sql = zeph_db::rewrite_placeholders(&format!(
75            "SELECT id, role, content, parts, {epoch_expr} \
76            FROM messages WHERE conversation_id = ? ORDER BY id ASC"
77        ));
78        let msg_rows: Vec<(i64, String, String, String, i64)> = zeph_db::query_as(&msg_sql)
79            .bind(cid)
80            .fetch_all(sqlite.pool())
81            .await?;
82
83        let messages = msg_rows
84            .into_iter()
85            .map(
86                |(id, role, content, parts_json, created_at)| MessageSnapshot {
87                    id,
88                    conversation_id: cid_raw,
89                    role,
90                    content,
91                    parts_json,
92                    created_at,
93                },
94            )
95            .collect();
96
97        let sum_rows = sqlite.load_summaries(cid).await?;
98        let summaries = sum_rows
99            .into_iter()
100            .map(
101                |(
102                    id,
103                    conversation_id,
104                    content,
105                    first_message_id,
106                    last_message_id,
107                    token_estimate,
108                )| {
109                    SummarySnapshot {
110                        id,
111                        conversation_id: conversation_id.0,
112                        content,
113                        first_message_id: first_message_id.map(|m| m.0),
114                        last_message_id: last_message_id.map(|m| m.0),
115                        token_estimate,
116                    }
117                },
118            )
119            .collect();
120
121        conversations.push(ConversationSnapshot {
122            id: cid_raw,
123            messages,
124            summaries,
125        });
126    }
127
128    Ok(MemorySnapshot {
129        version: 1,
130        exported_at,
131        conversations,
132    })
133}
134
135/// Import a snapshot into `SQLite`, skipping duplicate entries.
136///
137/// Returns stats about what was imported.
138///
139/// # Errors
140///
141/// Returns an error if any database operation fails.
142pub async fn import_snapshot(
143    sqlite: &SqliteStore,
144    snapshot: MemorySnapshot,
145) -> Result<ImportStats, MemoryError> {
146    if snapshot.version != 1 {
147        return Err(MemoryError::Snapshot(format!(
148            "unsupported snapshot version {}: only version 1 is supported",
149            snapshot.version
150        )));
151    }
152    let mut stats = ImportStats::default();
153
154    for conv in snapshot.conversations {
155        let exists: Option<(i64,)> =
156            zeph_db::query_as(sql!("SELECT id FROM conversations WHERE id = ?"))
157                .bind(conv.id)
158                .fetch_optional(sqlite.pool())
159                .await?;
160
161        if exists.is_none() {
162            zeph_db::query(sql!("INSERT INTO conversations (id) VALUES (?)"))
163                .bind(conv.id)
164                .execute(sqlite.pool())
165                .await?;
166            stats.conversations_imported += 1;
167        } else {
168            stats.skipped += 1;
169        }
170
171        for msg in conv.messages {
172            let msg_sql = format!(
173                "{} INTO messages (id, conversation_id, role, content, parts) VALUES (?, ?, ?, ?, ?){}",
174                <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
175                <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
176            );
177            let result = zeph_db::query(&msg_sql)
178                .bind(msg.id)
179                .bind(msg.conversation_id)
180                .bind(&msg.role)
181                .bind(&msg.content)
182                .bind(&msg.parts_json)
183                .execute(sqlite.pool())
184                .await?;
185
186            if result.rows_affected() > 0 {
187                stats.messages_imported += 1;
188            } else {
189                stats.skipped += 1;
190            }
191        }
192
193        for sum in conv.summaries {
194            let sum_sql = format!(
195                "{} INTO summaries (id, conversation_id, content, first_message_id, last_message_id, token_estimate) VALUES (?, ?, ?, ?, ?, ?){}",
196                <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
197                <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
198            );
199            let result = zeph_db::query(&sum_sql)
200                .bind(sum.id)
201                .bind(sum.conversation_id)
202                .bind(&sum.content)
203                .bind(sum.first_message_id)
204                .bind(sum.last_message_id)
205                .bind(sum.token_estimate)
206                .execute(sqlite.pool())
207                .await?;
208
209            if result.rows_affected() > 0 {
210                stats.summaries_imported += 1;
211            } else {
212                stats.skipped += 1;
213            }
214        }
215    }
216
217    Ok(stats)
218}
219
220fn chrono_now() -> String {
221    use std::time::{SystemTime, UNIX_EPOCH};
222    let secs = SystemTime::now()
223        .duration_since(UNIX_EPOCH)
224        .unwrap_or_default()
225        .as_secs();
226    // Format as ISO-8601 approximation without chrono dependency
227    let (year, month, day, hour, min, sec) = unix_to_parts(secs);
228    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
229}
230
/// Split a Unix timestamp into UTC `(year, month, day, hour, minute, second)`.
///
/// The date part uses the proleptic-Gregorian "civil from days" algorithm.
/// It counts in 400-year eras starting on 0000-03-01, so the leap day falls at
/// the end of each shifted year.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_DAY: u64 = 86_400;

    let days = secs / SECS_PER_DAY;
    let secs_of_day = secs % SECS_PER_DAY;
    let hour = secs_of_day / 3_600;
    let min = secs_of_day % 3_600 / 60;
    let sec = secs_of_day % 60;

    // 719_468 = days from 0000-03-01 to 1970-01-01; 146_097 = days per 400 years.
    let shifted = days + 719_468;
    let era = shifted / 146_097;
    let day_of_era = shifted % 146_097;
    let year_of_era = (day_of_era - day_of_era / 1_460 + day_of_era / 36_524
        - day_of_era / 146_096)
        / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index 0 = March ... 11 = February of the following calendar year.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let (month, year_carry) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = era * 400 + year_of_era + year_carry;

    (year, month, day, hour, min, sec)
}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        assert!(stats2.skipped > 0);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        // Import once — conversation is new.
        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        // Import again with no new messages — conversation already exists, must be counted as skipped.
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        // The conversation itself contributes one skipped, plus the duplicate message.
        assert!(
            stats2.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store
            .save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    #[tokio::test]
    async fn import_roundtrip_includes_summaries() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        let m1 = src.save_message(cid, "user", "a").await.unwrap();
        let m2 = src.save_message(cid, "assistant", "b").await.unwrap();
        src.save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, export_snapshot(&src).await.unwrap())
            .await
            .unwrap();
        assert_eq!(stats.summaries_imported, 1);

        let summaries = dst.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].2, "summary");
        assert_eq!(summaries[0].5, 5);

        // Re-importing the same snapshot must not duplicate the summary.
        let stats2 = import_snapshot(&dst, export_snapshot(&src).await.unwrap())
            .await
            .unwrap();
        assert_eq!(stats2.summaries_imported, 0);
        assert_eq!(dst.load_summaries(cid).await.unwrap().len(), 1);
    }

    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
    }

    #[test]
    fn unix_to_parts_known_dates() {
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(86_399), (1970, 1, 1, 23, 59, 59));
        // Leap day and the day right after it.
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(951_868_800), (2000, 3, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(1_700_000_000), (2023, 11, 14, 22, 13, 20));
        assert_eq!(unix_to_parts(253_402_300_799), (9999, 12, 31, 23, 59, 59));
    }

    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // skipped includes the existing conversation (1) plus the duplicate message (1).
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}