1use serde::{Deserialize, Serialize};
13#[allow(unused_imports)]
14use zeph_db::sql;
15
16use zeph_db::ActiveDialect;
17
18use crate::error::MemoryError;
19use crate::store::SqliteStore;
20use crate::types::ConversationId;
21
/// Serializable dump of all conversations, together with their messages and summaries.
///
/// Produced by [`export_snapshot`] and consumed by [`import_snapshot`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MemorySnapshot {
    /// Snapshot format version. `export_snapshot` always writes `1`, and
    /// `import_snapshot` rejects any other value.
    pub version: u32,
    /// Export time in UTC, formatted as `YYYY-MM-DDTHH:MM:SSZ`.
    pub exported_at: String,
    /// Every conversation present at export time, ordered by ascending id.
    pub conversations: Vec<ConversationSnapshot>,
}
32
/// One conversation together with its messages and summaries.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConversationSnapshot {
    /// Row id of the conversation; `import_snapshot` inserts it with this exact id.
    pub id: i64,
    /// Messages of this conversation, ordered by ascending message id.
    pub messages: Vec<MessageSnapshot>,
    /// Summaries of this conversation, in the order returned by
    /// `SqliteStore::load_summaries`.
    pub summaries: Vec<SummarySnapshot>,
}
43
/// A single row of the `messages` table.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageSnapshot {
    /// Row id of the message; re-used verbatim on import.
    pub id: i64,
    /// Id of the owning conversation.
    pub conversation_id: i64,
    /// Role string of the message author (e.g. `"user"`, `"assistant"`).
    pub role: String,
    /// Message text.
    pub content: String,
    /// Raw value of the `parts` column (presumably JSON, as the name suggests — TODO confirm).
    pub parts_json: String,
    /// `created_at` converted through the dialect's `epoch_from_col`
    /// (presumably seconds since the Unix epoch — TODO confirm).
    /// NOTE(review): this value is exported but `import_snapshot` does not write it back.
    pub created_at: i64,
}
60
/// A single summary row, as returned by `SqliteStore::load_summaries`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SummarySnapshot {
    /// Row id of the summary; re-used verbatim on import.
    pub id: i64,
    /// Id of the owning conversation.
    pub conversation_id: i64,
    /// Summary text.
    pub content: String,
    /// Id of the first message the summary refers to, if one was recorded
    /// (presumably the start of the summarized range — TODO confirm).
    pub first_message_id: Option<i64>,
    /// Id of the last message the summary refers to, if one was recorded
    /// (presumably the end of the summarized range — TODO confirm).
    pub last_message_id: Option<i64>,
    /// Token estimate stored with the summary (presumably for `content`).
    pub token_estimate: i64,
}
77
/// Counters describing the outcome of [`import_snapshot`].
///
/// This is a plain bundle of counters, so it is `Copy` and comparable. Callers and tests can
/// compare whole results instead of checking field by field.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that did not exist yet and were created.
    pub conversations_imported: usize,
    /// Messages that were newly inserted.
    pub messages_imported: usize,
    /// Summaries that were newly inserted.
    pub summaries_imported: usize,
    /// Conversations, messages and summaries (combined) that were not inserted.
    /// Conversations count here when they already existed; messages and summaries count
    /// here when the insert-or-ignore statement affected no rows.
    pub skipped: usize,
}
90
91pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
97 let conv_ids: Vec<(i64,)> =
98 zeph_db::query_as(sql!("SELECT id FROM conversations ORDER BY id ASC"))
99 .fetch_all(sqlite.pool())
100 .await?;
101
102 let exported_at = chrono_now();
103 let mut conversations = Vec::with_capacity(conv_ids.len());
104
105 for (cid_raw,) in conv_ids {
106 let cid = ConversationId(cid_raw);
107
108 let epoch_expr = <ActiveDialect as zeph_db::dialect::Dialect>::epoch_from_col("created_at");
109 let msg_sql = zeph_db::rewrite_placeholders(&format!(
110 "SELECT id, role, content, parts, {epoch_expr} \
111 FROM messages WHERE conversation_id = ? ORDER BY id ASC"
112 ));
113 let msg_rows: Vec<(i64, String, String, String, i64)> = zeph_db::query_as(&msg_sql)
114 .bind(cid)
115 .fetch_all(sqlite.pool())
116 .await?;
117
118 let messages = msg_rows
119 .into_iter()
120 .map(
121 |(id, role, content, parts_json, created_at)| MessageSnapshot {
122 id,
123 conversation_id: cid_raw,
124 role,
125 content,
126 parts_json,
127 created_at,
128 },
129 )
130 .collect();
131
132 let sum_rows = sqlite.load_summaries(cid).await?;
133 let summaries = sum_rows
134 .into_iter()
135 .map(
136 |(
137 id,
138 conversation_id,
139 content,
140 first_message_id,
141 last_message_id,
142 token_estimate,
143 )| {
144 SummarySnapshot {
145 id,
146 conversation_id: conversation_id.0,
147 content,
148 first_message_id: first_message_id.map(|m| m.0),
149 last_message_id: last_message_id.map(|m| m.0),
150 token_estimate,
151 }
152 },
153 )
154 .collect();
155
156 conversations.push(ConversationSnapshot {
157 id: cid_raw,
158 messages,
159 summaries,
160 });
161 }
162
163 Ok(MemorySnapshot {
164 version: 1,
165 exported_at,
166 conversations,
167 })
168}
169
170pub async fn import_snapshot(
178 sqlite: &SqliteStore,
179 snapshot: MemorySnapshot,
180) -> Result<ImportStats, MemoryError> {
181 if snapshot.version != 1 {
182 return Err(MemoryError::Snapshot(format!(
183 "unsupported snapshot version {}: only version 1 is supported",
184 snapshot.version
185 )));
186 }
187 let mut stats = ImportStats::default();
188
189 for conv in snapshot.conversations {
190 let exists: Option<(i64,)> =
191 zeph_db::query_as(sql!("SELECT id FROM conversations WHERE id = ?"))
192 .bind(conv.id)
193 .fetch_optional(sqlite.pool())
194 .await?;
195
196 if exists.is_none() {
197 zeph_db::query(sql!("INSERT INTO conversations (id) VALUES (?)"))
198 .bind(conv.id)
199 .execute(sqlite.pool())
200 .await?;
201 stats.conversations_imported += 1;
202 } else {
203 stats.skipped += 1;
204 }
205
206 for msg in conv.messages {
207 let msg_sql = format!(
208 "{} INTO messages (id, conversation_id, role, content, parts) VALUES (?, ?, ?, ?, ?){}",
209 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
210 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
211 );
212 let result = zeph_db::query(&msg_sql)
213 .bind(msg.id)
214 .bind(msg.conversation_id)
215 .bind(&msg.role)
216 .bind(&msg.content)
217 .bind(&msg.parts_json)
218 .execute(sqlite.pool())
219 .await?;
220
221 if result.rows_affected() > 0 {
222 stats.messages_imported += 1;
223 } else {
224 stats.skipped += 1;
225 }
226 }
227
228 for sum in conv.summaries {
229 let sum_sql = format!(
230 "{} INTO summaries (id, conversation_id, content, first_message_id, last_message_id, token_estimate) VALUES (?, ?, ?, ?, ?, ?){}",
231 <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
232 <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
233 );
234 let result = zeph_db::query(&sum_sql)
235 .bind(sum.id)
236 .bind(sum.conversation_id)
237 .bind(&sum.content)
238 .bind(sum.first_message_id)
239 .bind(sum.last_message_id)
240 .bind(sum.token_estimate)
241 .execute(sqlite.pool())
242 .await?;
243
244 if result.rows_affected() > 0 {
245 stats.summaries_imported += 1;
246 } else {
247 stats.skipped += 1;
248 }
249 }
250 }
251
252 Ok(stats)
253}
254
255fn chrono_now() -> String {
256 use std::time::{SystemTime, UNIX_EPOCH};
257 let secs = SystemTime::now()
258 .duration_since(UNIX_EPOCH)
259 .unwrap_or_default()
260 .as_secs();
261 let (year, month, day, hour, min, sec) = unix_to_parts(secs);
263 format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
264}
265
266fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
267 let sec = secs % 60;
268 let total_mins = secs / 60;
269 let min = total_mins % 60;
270 let total_hours = total_mins / 60;
271 let hour = total_hours % 24;
272 let total_days = total_hours / 24;
273
274 let adjusted = total_days + 719_468;
276 let era = adjusted / 146_097;
277 let doe = adjusted - era * 146_097;
278 let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146_096) / 365;
279 let year = yoe + era * 400;
280 let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
281 let mp = (5 * doy + 2) / 153;
282 let day = doy - (153 * mp + 2) / 5 + 1;
283 let month = if mp < 10 { mp + 3 } else { mp - 9 };
284 let year = if month <= 2 { year + 1 } else { year };
285 (year, month, day, hour, min, sec)
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        assert!(stats2.skipped > 0);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        assert!(
            stats2.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store
            .save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    /// Summaries must survive a full export -> import cycle, not just the export half.
    #[tokio::test]
    async fn import_roundtrips_summaries() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        let m1 = src.save_message(cid, "user", "a").await.unwrap();
        let m2 = src.save_message(cid, "assistant", "b").await.unwrap();
        src.save_summary(cid, "summary", Some(m1), Some(m2), 5)
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.summaries_imported, 1);
        assert_eq!(stats.skipped, 0);

        let summaries = dst.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].2, "summary");
        assert_eq!(summaries[0].5, 5);
    }

    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
    }

    /// Pins the calendar arithmetic on the epoch, a leap day, and a known recent instant.
    #[test]
    fn unix_to_parts_known_dates() {
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(946_684_799), (1999, 12, 31, 23, 59, 59));
        assert_eq!(unix_to_parts(1_700_000_000), (2023, 11, 14, 22, 13, 20));
    }

    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}