1use serde::{Deserialize, Serialize};
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use zeph_db::ActiveDialect;
9
10use crate::error::MemoryError;
11use crate::store::SqliteStore;
12use crate::types::ConversationId;
13
/// Top-level container for an export of the conversation memory.
///
/// Built by [`export_snapshot`] and consumed by [`import_snapshot`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MemorySnapshot {
    /// Snapshot format version. [`import_snapshot`] rejects anything other than `1`.
    pub version: u32,
    /// Time of export. [`export_snapshot`] writes it as `YYYY-MM-DDTHH:MM:SSZ`.
    pub exported_at: String,
    /// Exported conversations. [`export_snapshot`] emits them in ascending id order.
    pub conversations: Vec<ConversationSnapshot>,
}
20
/// One conversation together with the messages and summaries that belong to it.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConversationSnapshot {
    /// Value of the `id` column of the `conversations` row.
    pub id: i64,
    /// Messages of this conversation. [`export_snapshot`] emits them in ascending id order.
    pub messages: Vec<MessageSnapshot>,
    /// Summaries of this conversation, in the order the store's `load_summaries` returns them.
    pub summaries: Vec<SummarySnapshot>,
}
27
/// A single row of the `messages` table as it appears in a snapshot.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageSnapshot {
    /// Value of the message's `id` column.
    pub id: i64,
    /// Id of the conversation the message belongs to.
    pub conversation_id: i64,
    /// Message role as stored in the `role` column (the tests use `"user"` and `"assistant"`).
    pub role: String,
    /// Message text from the `content` column.
    pub content: String,
    /// Raw value of the `parts` column.
    // NOTE(review): presumably a JSON document, going by the field name — confirm.
    pub parts_json: String,
    /// Result of the dialect's epoch expression over the `created_at` column.
    // NOTE(review): presumably Unix seconds — confirm against `Dialect::epoch_from_col`.
    // `import_snapshot` does not write this value back; its INSERT omits `created_at`.
    pub created_at: i64,
}
37
/// A single row of the `summaries` table as it appears in a snapshot.
#[derive(Debug, Serialize, Deserialize)]
pub struct SummarySnapshot {
    /// Value of the summary's `id` column.
    pub id: i64,
    /// Id of the conversation the summary belongs to.
    pub conversation_id: i64,
    /// Summary text.
    pub content: String,
    /// Optional id stored in `first_message_id`.
    // NOTE(review): presumably the first message covered by the summary — confirm.
    pub first_message_id: Option<i64>,
    /// Optional id stored in `last_message_id`.
    // NOTE(review): presumably the last message covered by the summary — confirm.
    pub last_message_id: Option<i64>,
    /// Value of the `token_estimate` column.
    pub token_estimate: i64,
}
47
/// Counters describing what a call to `import_snapshot` did.
///
/// The type is a plain bag of counters, so it derives `Clone`, `PartialEq` and `Eq`
/// in addition to `Debug`/`Default`; callers and tests can compare whole results
/// instead of checking each field separately.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that did not exist in the target and were inserted.
    pub conversations_imported: usize,
    /// Messages whose insert affected at least one row.
    pub messages_imported: usize,
    /// Summaries whose insert affected at least one row.
    pub summaries_imported: usize,
    /// Conversations that already existed, plus messages and summaries whose
    /// insert affected no rows (ignored on conflict).
    pub skipped: usize,
}
55
56pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
62 let conv_ids: Vec<(i64,)> =
63 zeph_db::query_as(sql!("SELECT id FROM conversations ORDER BY id ASC"))
64 .fetch_all(sqlite.pool())
65 .await?;
66
67 let exported_at = chrono_now();
68 let mut conversations = Vec::with_capacity(conv_ids.len());
69
70 for (cid_raw,) in conv_ids {
71 let cid = ConversationId(cid_raw);
72
73 let epoch_expr = <ActiveDialect as zeph_db::dialect::Dialect>::epoch_from_col("created_at");
74 let msg_sql = zeph_db::rewrite_placeholders(&format!(
75 "SELECT id, role, content, parts, {epoch_expr} \
76 FROM messages WHERE conversation_id = ? ORDER BY id ASC"
77 ));
78 let msg_rows: Vec<(i64, String, String, String, i64)> = zeph_db::query_as(&msg_sql)
79 .bind(cid)
80 .fetch_all(sqlite.pool())
81 .await?;
82
83 let messages = msg_rows
84 .into_iter()
85 .map(
86 |(id, role, content, parts_json, created_at)| MessageSnapshot {
87 id,
88 conversation_id: cid_raw,
89 role,
90 content,
91 parts_json,
92 created_at,
93 },
94 )
95 .collect();
96
97 let sum_rows = sqlite.load_summaries(cid).await?;
98 let summaries = sum_rows
99 .into_iter()
100 .map(
101 |(
102 id,
103 conversation_id,
104 content,
105 first_message_id,
106 last_message_id,
107 token_estimate,
108 )| {
109 SummarySnapshot {
110 id,
111 conversation_id: conversation_id.0,
112 content,
113 first_message_id: first_message_id.map(|m| m.0),
114 last_message_id: last_message_id.map(|m| m.0),
115 token_estimate,
116 }
117 },
118 )
119 .collect();
120
121 conversations.push(ConversationSnapshot {
122 id: cid_raw,
123 messages,
124 summaries,
125 });
126 }
127
128 Ok(MemorySnapshot {
129 version: 1,
130 exported_at,
131 conversations,
132 })
133}
134
135pub async fn import_snapshot(
143 sqlite: &SqliteStore,
144 snapshot: MemorySnapshot,
145) -> Result<ImportStats, MemoryError> {
146 if snapshot.version != 1 {
147 return Err(MemoryError::Snapshot(format!(
148 "unsupported snapshot version {}: only version 1 is supported",
149 snapshot.version
150 )));
151 }
152 let mut stats = ImportStats::default();
153
154 for conv in snapshot.conversations {
155 let exists: Option<(i64,)> =
156 zeph_db::query_as(sql!("SELECT id FROM conversations WHERE id = ?"))
157 .bind(conv.id)
158 .fetch_optional(sqlite.pool())
159 .await?;
160
161 if exists.is_none() {
162 zeph_db::query(sql!("INSERT INTO conversations (id) VALUES (?)"))
163 .bind(conv.id)
164 .execute(sqlite.pool())
165 .await?;
166 stats.conversations_imported += 1;
167 } else {
168 stats.skipped += 1;
169 }
170
171 for msg in conv.messages {
172 let msg_sql = format!(
173 "{} INTO messages (id, conversation_id, role, content, parts) VALUES (?, ?, ?, ?, ?){}",
174 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
175 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
176 );
177 let result = zeph_db::query(&msg_sql)
178 .bind(msg.id)
179 .bind(msg.conversation_id)
180 .bind(&msg.role)
181 .bind(&msg.content)
182 .bind(&msg.parts_json)
183 .execute(sqlite.pool())
184 .await?;
185
186 if result.rows_affected() > 0 {
187 stats.messages_imported += 1;
188 } else {
189 stats.skipped += 1;
190 }
191 }
192
193 for sum in conv.summaries {
194 let sum_sql = format!(
195 "{} INTO summaries (id, conversation_id, content, first_message_id, last_message_id, token_estimate) VALUES (?, ?, ?, ?, ?, ?){}",
196 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
197 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
198 );
199 let result = zeph_db::query(&sum_sql)
200 .bind(sum.id)
201 .bind(sum.conversation_id)
202 .bind(&sum.content)
203 .bind(sum.first_message_id)
204 .bind(sum.last_message_id)
205 .bind(sum.token_estimate)
206 .execute(sqlite.pool())
207 .await?;
208
209 if result.rows_affected() > 0 {
210 stats.summaries_imported += 1;
211 } else {
212 stats.skipped += 1;
213 }
214 }
215 }
216
217 Ok(stats)
218}
219
220fn chrono_now() -> String {
221 use std::time::{SystemTime, UNIX_EPOCH};
222 let secs = SystemTime::now()
223 .duration_since(UNIX_EPOCH)
224 .unwrap_or_default()
225 .as_secs();
226 let (year, month, day, hour, min, sec) = unix_to_parts(secs);
228 format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
229}
230
/// Splits a count of whole seconds since 1970-01-01T00:00:00 into
/// `(year, month, day, hour, minute, second)` on the Gregorian calendar.
///
/// Months and days are 1-based. The date part works on a day count re-based to
/// 1 March of year 0, so that a leap day, when present, is the last day of the
/// shifted year; the year is corrected for January and February at the end.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    // Days in one 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: u64 = 146_097;
    // Days between 0000-03-01 and 1970-01-01.
    const EPOCH_SHIFT_DAYS: u64 = 719_468;

    // Time of day.
    let sec = secs % 60;
    let min = secs / 60 % 60;
    let hour = secs / 3_600 % 24;
    let days_since_epoch = secs / 86_400;

    // Position of the day inside its 400-year era.
    let shifted_days = days_since_epoch + EPOCH_SHIFT_DAYS;
    let era = shifted_days / DAYS_PER_ERA;
    let day_of_era = shifted_days % DAYS_PER_ERA;

    // Year within the era, compensating for the leap days at 4/100/400-year marks.
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March: 0 = March, ..., 11 = February.
    let month_from_march = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_from_march + 2) / 5 + 1;

    // January and February belong to the next calendar year of the shifted count.
    let (month, year_carry) = if month_from_march < 10 {
        (month_from_march + 3, 0)
    } else {
        (month_from_march - 9, 1)
    };
    let year = year_of_era + era * 400 + year_carry;

    (year, month, day, hour, min, sec)
}
252
// Tests for snapshot export/import. The async tests run against in-memory stores
// created with `SqliteStore::new(":memory:")`.
#[cfg(test)]
mod tests {
    use super::*;

    // Exporting a store without conversations yields a version-1 snapshot with an
    // empty conversation list and a non-empty timestamp.
    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    // A snapshot exported from one store can be imported into a fresh store, and the
    // messages are then readable there under the same conversation id, in order.
    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        // `cid` comes from the source store; the import keeps the original ids.
        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    // Importing the same data twice inserts nothing the second time and reports
    // the duplicates as skipped.
    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        // `import_snapshot` consumes its snapshot, so export again for the second run.
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        assert!(stats2.skipped > 0);
    }

    // A conversation that already exists in the target is counted in `skipped`,
    // not in `conversations_imported`.
    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        // First import: the conversation and its message are new to `dst`.
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        // Second import: the conversation is already present in `dst`.
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        assert!(
            stats2.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    // Summaries saved for a conversation appear in the exported snapshot.
    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store
            .save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    // The timestamp has the `T` separator and the trailing `Z` of the expected format.
    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
    }

    // Text that is not valid JSON fails to deserialize into a `MemorySnapshot`.
    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    // Any version other than 1 is rejected with an error that names the offending version.
    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    // Importing a newer snapshot over an older import adds only the rows that are
    // missing from the target.
    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        // Grow the source after the first import, then import the larger snapshot.
        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // One skip for the existing conversation and one for the duplicate message.
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}