Skip to main content

zeph_memory/
snapshot.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Memory snapshot export and import.
5//!
6//! A [`MemorySnapshot`] serializes the full set of conversations, messages, and
7//! summaries from `SQLite` into a JSON document suitable for backup or migration.
8//!
//! Import is idempotent: conversations, messages, and summaries already present in the
//! target database are skipped and counted in [`ImportStats::skipped`].
11
12use serde::{Deserialize, Serialize};
13#[allow(unused_imports)]
14use zeph_db::sql;
15
16use zeph_db::ActiveDialect;
17
18use crate::error::MemoryError;
19use crate::store::SqliteStore;
20use crate::types::ConversationId;
21
22/// Complete serializable snapshot of the `SQLite` memory store.
23#[derive(Debug, Serialize, Deserialize)]
24pub struct MemorySnapshot {
25    /// Schema version for forward-compatibility checks.
26    pub version: u32,
27    /// RFC 3339 timestamp when the snapshot was created.
28    pub exported_at: String,
29    /// All conversations with their messages and summaries.
30    pub conversations: Vec<ConversationSnapshot>,
31}
32
33/// One conversation and all its associated data in a [`MemorySnapshot`].
34#[derive(Debug, Serialize, Deserialize)]
35pub struct ConversationSnapshot {
36    /// `SQLite` row ID of the conversation.
37    pub id: i64,
38    /// All messages in this conversation.
39    pub messages: Vec<MessageSnapshot>,
40    /// Compression summaries for this conversation.
41    pub summaries: Vec<SummarySnapshot>,
42}
43
44/// A single message as stored in a [`MemorySnapshot`].
45#[derive(Debug, Serialize, Deserialize)]
46pub struct MessageSnapshot {
47    /// `SQLite` row ID.
48    pub id: i64,
49    /// Parent conversation ID.
50    pub conversation_id: i64,
51    /// Message role (`"user"`, `"assistant"`, `"system"`).
52    pub role: String,
53    /// Flattened text content.
54    pub content: String,
55    /// JSON-encoded `Vec<MessagePart>` payload.
56    pub parts_json: String,
57    /// Unix timestamp (seconds since epoch).
58    pub created_at: i64,
59}
60
61/// A compression summary record in a [`MemorySnapshot`].
62#[derive(Debug, Serialize, Deserialize)]
63pub struct SummarySnapshot {
64    /// `SQLite` row ID.
65    pub id: i64,
66    /// Parent conversation ID.
67    pub conversation_id: i64,
68    /// Summary text.
69    pub content: String,
70    /// Inclusive lower bound of the summarised message range.
71    pub first_message_id: Option<i64>,
72    /// Inclusive upper bound of the summarised message range.
73    pub last_message_id: Option<i64>,
74    /// Estimated token count of the summary content.
75    pub token_estimate: i64,
76}
77
/// Counters returned by [`import_snapshot`].
///
/// All fields are plain counters, so the type is `Copy` and can be compared directly
/// (handy for asserting on import results).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ImportStats {
    /// Number of conversations inserted.
    pub conversations_imported: usize,
    /// Number of messages inserted.
    pub messages_imported: usize,
    /// Number of summaries inserted.
    pub summaries_imported: usize,
    /// Number of conversation, message and summary rows that were not inserted because a
    /// matching row already existed in the target store.
    pub skipped: usize,
}
90
91/// Export all conversations, messages and summaries from `SQLite` into a snapshot.
92///
93/// # Errors
94///
95/// Returns an error if any database query fails.
96pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
97    let conv_ids: Vec<(i64,)> =
98        zeph_db::query_as(sql!("SELECT id FROM conversations ORDER BY id ASC"))
99            .fetch_all(sqlite.pool())
100            .await?;
101
102    let exported_at = chrono_now();
103    let mut conversations = Vec::with_capacity(conv_ids.len());
104
105    for (cid_raw,) in conv_ids {
106        let cid = ConversationId(cid_raw);
107
108        let epoch_expr = <ActiveDialect as zeph_db::dialect::Dialect>::epoch_from_col("created_at");
109        let msg_sql = zeph_db::rewrite_placeholders(&format!(
110            "SELECT id, role, content, parts, {epoch_expr} \
111            FROM messages WHERE conversation_id = ? ORDER BY id ASC"
112        ));
113        let msg_rows: Vec<(i64, String, String, String, i64)> = zeph_db::query_as(&msg_sql)
114            .bind(cid)
115            .fetch_all(sqlite.pool())
116            .await?;
117
118        let messages = msg_rows
119            .into_iter()
120            .map(
121                |(id, role, content, parts_json, created_at)| MessageSnapshot {
122                    id,
123                    conversation_id: cid_raw,
124                    role,
125                    content,
126                    parts_json,
127                    created_at,
128                },
129            )
130            .collect();
131
132        let sum_rows = sqlite.load_summaries(cid).await?;
133        let summaries = sum_rows
134            .into_iter()
135            .map(
136                |(
137                    id,
138                    conversation_id,
139                    content,
140                    first_message_id,
141                    last_message_id,
142                    token_estimate,
143                )| {
144                    SummarySnapshot {
145                        id,
146                        conversation_id: conversation_id.0,
147                        content,
148                        first_message_id: first_message_id.map(|m| m.0),
149                        last_message_id: last_message_id.map(|m| m.0),
150                        token_estimate,
151                    }
152                },
153            )
154            .collect();
155
156        conversations.push(ConversationSnapshot {
157            id: cid_raw,
158            messages,
159            summaries,
160        });
161    }
162
163    Ok(MemorySnapshot {
164        version: 1,
165        exported_at,
166        conversations,
167    })
168}
169
170/// Import a snapshot into `SQLite`, skipping duplicate entries.
171///
172/// Returns stats about what was imported.
173///
174/// # Errors
175///
176/// Returns an error if any database operation fails.
177pub async fn import_snapshot(
178    sqlite: &SqliteStore,
179    snapshot: MemorySnapshot,
180) -> Result<ImportStats, MemoryError> {
181    if snapshot.version != 1 {
182        return Err(MemoryError::Snapshot(format!(
183            "unsupported snapshot version {}: only version 1 is supported",
184            snapshot.version
185        )));
186    }
187    let mut stats = ImportStats::default();
188
189    for conv in snapshot.conversations {
190        let exists: Option<(i64,)> =
191            zeph_db::query_as(sql!("SELECT id FROM conversations WHERE id = ?"))
192                .bind(conv.id)
193                .fetch_optional(sqlite.pool())
194                .await?;
195
196        if exists.is_none() {
197            zeph_db::query(sql!("INSERT INTO conversations (id) VALUES (?)"))
198                .bind(conv.id)
199                .execute(sqlite.pool())
200                .await?;
201            stats.conversations_imported += 1;
202        } else {
203            stats.skipped += 1;
204        }
205
206        for msg in conv.messages {
207            let msg_sql = format!(
208                "{} INTO messages (id, conversation_id, role, content, parts) VALUES (?, ?, ?, ?, ?){}",
209                <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
210                <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
211            );
212            let result = zeph_db::query(&msg_sql)
213                .bind(msg.id)
214                .bind(msg.conversation_id)
215                .bind(&msg.role)
216                .bind(&msg.content)
217                .bind(&msg.parts_json)
218                .execute(sqlite.pool())
219                .await?;
220
221            if result.rows_affected() > 0 {
222                stats.messages_imported += 1;
223            } else {
224                stats.skipped += 1;
225            }
226        }
227
228        for sum in conv.summaries {
229            let sum_sql = format!(
230                "{} INTO summaries (id, conversation_id, content, first_message_id, last_message_id, token_estimate) VALUES (?, ?, ?, ?, ?, ?){}",
231                <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
232                <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
233            );
234            let result = zeph_db::query(&sum_sql)
235                .bind(sum.id)
236                .bind(sum.conversation_id)
237                .bind(&sum.content)
238                .bind(sum.first_message_id)
239                .bind(sum.last_message_id)
240                .bind(sum.token_estimate)
241                .execute(sqlite.pool())
242                .await?;
243
244            if result.rows_affected() > 0 {
245                stats.summaries_imported += 1;
246            } else {
247                stats.skipped += 1;
248            }
249        }
250    }
251
252    Ok(stats)
253}
254
255fn chrono_now() -> String {
256    use std::time::{SystemTime, UNIX_EPOCH};
257    let secs = SystemTime::now()
258        .duration_since(UNIX_EPOCH)
259        .unwrap_or_default()
260        .as_secs();
261    // Format as ISO-8601 approximation without chrono dependency
262    let (year, month, day, hour, min, sec) = unix_to_parts(secs);
263    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
264}
265
/// Split seconds since the Unix epoch into UTC `(year, month, day, hour, minute, second)`.
///
/// The date uses Howard Hinnant's `civil_from_days` algorithm on the proleptic Gregorian
/// calendar. Month and day are 1-based.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_DAY: u64 = 86_400;

    let days = secs / SECS_PER_DAY;
    let time_of_day = secs % SECS_PER_DAY;
    let hour = time_of_day / 3_600;
    let minute = time_of_day % 3_600 / 60;
    let second = time_of_day % 60;

    // Re-base day 0 to 0000-03-01 so that any leap day is the last day of a computed year.
    let shifted = days + 719_468;
    let era = shifted / 146_097; // 400-year cycles
    let day_of_era = shifted % 146_097;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March, 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = era * 400 + year_of_era + u64::from(month <= 2);

    (year, month, day, hour, minute, second)
}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        // One skipped conversation plus one skipped message.
        assert_eq!(stats2.skipped, 2);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        // Import once — conversation is new.
        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        // Import again with no new messages — conversation already exists, must be counted as skipped.
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        // The conversation itself contributes one skipped, plus the duplicate message.
        assert!(
            stats2.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store
            .save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    #[tokio::test]
    async fn export_import_roundtrip_preserves_summaries() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        let m1 = src.save_message(cid, "user", "a").await.unwrap();
        let m2 = src.save_message(cid, "assistant", "b").await.unwrap();
        src.save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, export_snapshot(&src).await.unwrap())
            .await
            .unwrap();
        assert_eq!(stats.summaries_imported, 1);
        assert_eq!(stats.skipped, 0);

        let summaries = dst.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].2, "summary");
        assert_eq!(summaries[0].5, 5);

        // Re-importing must skip everything: 1 conversation + 2 messages + 1 summary.
        let stats2 = import_snapshot(&dst, export_snapshot(&src).await.unwrap())
            .await
            .unwrap();
        assert_eq!(stats2.summaries_imported, 0);
        assert_eq!(stats2.skipped, 4);
    }

    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
        // `YYYY-MM-DDTHH:MM:SSZ`
        assert_eq!(ts.len(), 20);
    }

    #[test]
    fn unix_to_parts_known_dates() {
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(86_399), (1970, 1, 1, 23, 59, 59));
        // Leap day and the day after it.
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(951_868_800), (2000, 3, 1, 0, 0, 0));
        // Year boundary.
        assert_eq!(unix_to_parts(946_684_799), (1999, 12, 31, 23, 59, 59));
        assert_eq!(unix_to_parts(1_700_000_000), (2023, 11, 14, 22, 13, 20));
    }

    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // skipped includes the existing conversation (1) plus the duplicate message (1).
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}