1use serde::{Deserialize, Serialize};
2
3use crate::error::MemoryError;
4use crate::sqlite::SqliteStore;
5use crate::types::ConversationId;
6
7#[derive(Debug, Serialize, Deserialize)]
8pub struct MemorySnapshot {
9 pub version: u32,
10 pub exported_at: String,
11 pub conversations: Vec<ConversationSnapshot>,
12}
13
14#[derive(Debug, Serialize, Deserialize)]
15pub struct ConversationSnapshot {
16 pub id: i64,
17 pub messages: Vec<MessageSnapshot>,
18 pub summaries: Vec<SummarySnapshot>,
19}
20
21#[derive(Debug, Serialize, Deserialize)]
22pub struct MessageSnapshot {
23 pub id: i64,
24 pub conversation_id: i64,
25 pub role: String,
26 pub content: String,
27 pub parts_json: String,
28 pub created_at: i64,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32pub struct SummarySnapshot {
33 pub id: i64,
34 pub conversation_id: i64,
35 pub content: String,
36 pub first_message_id: i64,
37 pub last_message_id: i64,
38 pub token_estimate: i64,
39}
40
/// Counters describing what [`import_snapshot`] wrote.
///
/// All fields are plain counters, so the type is `Copy` and comparable. Callers
/// and tests can therefore copy and `assert_eq!` whole results.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that did not exist yet and were created.
    pub conversations_imported: usize,
    /// Messages that were inserted.
    pub messages_imported: usize,
    /// Summaries that were inserted.
    pub summaries_imported: usize,
    /// Rows of any kind that were not written: conversations that already
    /// existed, plus messages/summaries ignored by `INSERT OR IGNORE`
    /// (typically because their id is already present).
    pub skipped: usize,
}
48
49pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
55 let conv_ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM conversations ORDER BY id ASC")
56 .fetch_all(sqlite.pool())
57 .await?;
58
59 let exported_at = chrono_now();
60 let mut conversations = Vec::with_capacity(conv_ids.len());
61
62 for (cid_raw,) in conv_ids {
63 let cid = ConversationId(cid_raw);
64
65 let msg_rows: Vec<(i64, String, String, String, i64)> = sqlx::query_as(
66 "SELECT id, role, content, parts, \
67 COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
68 FROM messages WHERE conversation_id = ? ORDER BY id ASC",
69 )
70 .bind(cid)
71 .fetch_all(sqlite.pool())
72 .await?;
73
74 let messages = msg_rows
75 .into_iter()
76 .map(
77 |(id, role, content, parts_json, created_at)| MessageSnapshot {
78 id,
79 conversation_id: cid_raw,
80 role,
81 content,
82 parts_json,
83 created_at,
84 },
85 )
86 .collect();
87
88 let sum_rows = sqlite.load_summaries(cid).await?;
89 let summaries = sum_rows
90 .into_iter()
91 .map(
92 |(
93 id,
94 conversation_id,
95 content,
96 first_message_id,
97 last_message_id,
98 token_estimate,
99 )| {
100 SummarySnapshot {
101 id,
102 conversation_id: conversation_id.0,
103 content,
104 first_message_id: first_message_id.0,
105 last_message_id: last_message_id.0,
106 token_estimate,
107 }
108 },
109 )
110 .collect();
111
112 conversations.push(ConversationSnapshot {
113 id: cid_raw,
114 messages,
115 summaries,
116 });
117 }
118
119 Ok(MemorySnapshot {
120 version: 1,
121 exported_at,
122 conversations,
123 })
124}
125
126pub async fn import_snapshot(
134 sqlite: &SqliteStore,
135 snapshot: MemorySnapshot,
136) -> Result<ImportStats, MemoryError> {
137 if snapshot.version != 1 {
138 return Err(MemoryError::Snapshot(format!(
139 "unsupported snapshot version {}: only version 1 is supported",
140 snapshot.version
141 )));
142 }
143 let mut stats = ImportStats::default();
144
145 for conv in snapshot.conversations {
146 let exists: Option<(i64,)> = sqlx::query_as("SELECT id FROM conversations WHERE id = ?")
147 .bind(conv.id)
148 .fetch_optional(sqlite.pool())
149 .await?;
150
151 if exists.is_none() {
152 sqlx::query("INSERT INTO conversations (id) VALUES (?)")
153 .bind(conv.id)
154 .execute(sqlite.pool())
155 .await?;
156 stats.conversations_imported += 1;
157 } else {
158 stats.skipped += 1;
159 }
160
161 for msg in conv.messages {
162 let result = sqlx::query(
163 "INSERT OR IGNORE INTO messages (id, conversation_id, role, content, parts) \
164 VALUES (?, ?, ?, ?, ?)",
165 )
166 .bind(msg.id)
167 .bind(msg.conversation_id)
168 .bind(&msg.role)
169 .bind(&msg.content)
170 .bind(&msg.parts_json)
171 .execute(sqlite.pool())
172 .await?;
173
174 if result.rows_affected() > 0 {
175 stats.messages_imported += 1;
176 } else {
177 stats.skipped += 1;
178 }
179 }
180
181 for sum in conv.summaries {
182 let result = sqlx::query(
183 "INSERT OR IGNORE INTO summaries \
184 (id, conversation_id, content, first_message_id, last_message_id, token_estimate) \
185 VALUES (?, ?, ?, ?, ?, ?)",
186 )
187 .bind(sum.id)
188 .bind(sum.conversation_id)
189 .bind(&sum.content)
190 .bind(sum.first_message_id)
191 .bind(sum.last_message_id)
192 .bind(sum.token_estimate)
193 .execute(sqlite.pool())
194 .await?;
195
196 if result.rows_affected() > 0 {
197 stats.summaries_imported += 1;
198 } else {
199 stats.skipped += 1;
200 }
201 }
202 }
203
204 Ok(stats)
205}
206
207fn chrono_now() -> String {
208 use std::time::{SystemTime, UNIX_EPOCH};
209 let secs = SystemTime::now()
210 .duration_since(UNIX_EPOCH)
211 .unwrap_or_default()
212 .as_secs();
213 let (year, month, day, hour, min, sec) = unix_to_parts(secs);
215 format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
216}
217
/// Splits a Unix timestamp (seconds since 1970-01-01T00:00:00Z) into UTC
/// calendar fields `(year, month, day, hour, minute, second)`.
///
/// The date part is Howard Hinnant's `civil_from_days` algorithm. It counts
/// days in 400-year eras of the proleptic Gregorian calendar, with each year
/// starting on March 1 so the leap day falls at the end of the year.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_DAY: u64 = 86_400;
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: u64 = 719_468;

    let days = secs / SECS_PER_DAY;
    let secs_of_day = secs % SECS_PER_DAY;
    let hour = secs_of_day / 3_600;
    let min = secs_of_day % 3_600 / 60;
    let sec = secs_of_day % 60;

    let shifted = days + EPOCH_SHIFT;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA;
    // Compensate for leap days (every 4 years), skipped century leap days,
    // and the era's final day before dividing by 365.
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March, ..., 9 = December, 10 = January, 11 = February.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;

    let (month, year) = if month_index < 10 {
        (month_index + 3, era * 400 + year_of_era)
    } else {
        // January and February belong to the following civil year.
        (month_index - 9, era * 400 + year_of_era + 1)
    };
    (year, month, day, hour, min, sec)
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an isolated, empty in-memory store.
    async fn memory_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn export_empty_database() {
        let empty = memory_store().await;

        let exported = export_snapshot(&empty).await.unwrap();

        assert_eq!(exported.version, 1);
        assert!(exported.conversations.is_empty());
        assert!(!exported.exported_at.is_empty());
    }

    #[tokio::test]
    async fn export_import_roundtrip() {
        let source = memory_store().await;
        let conversation = source.create_conversation().await.unwrap();
        source
            .save_message(conversation, "user", "hello export")
            .await
            .unwrap();
        source
            .save_message(conversation, "assistant", "hi import")
            .await
            .unwrap();

        let exported = export_snapshot(&source).await.unwrap();
        assert_eq!(exported.conversations.len(), 1);
        assert_eq!(exported.conversations[0].messages.len(), 2);

        let target = memory_store().await;
        let stats = import_snapshot(&target, exported).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = target.load_history(conversation, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    #[tokio::test]
    async fn import_duplicate_skips() {
        let source = memory_store().await;
        let conversation = source.create_conversation().await.unwrap();
        source
            .save_message(conversation, "user", "msg")
            .await
            .unwrap();
        let target = memory_store().await;

        // First import brings the message in.
        let first = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();
        assert_eq!(first.messages_imported, 1);

        // Re-importing the same data must only skip.
        let second = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();
        assert_eq!(second.messages_imported, 0);
        assert!(second.skipped > 0);
    }

    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let source = memory_store().await;
        let conversation = source.create_conversation().await.unwrap();
        source
            .save_message(conversation, "user", "only message")
            .await
            .unwrap();
        let target = memory_store().await;

        let first = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();
        assert_eq!(first.conversations_imported, 1);
        assert_eq!(first.messages_imported, 1);

        let second = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();
        assert_eq!(
            second.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        assert!(
            second.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    #[tokio::test]
    async fn export_includes_summaries() {
        let store = memory_store().await;
        let conversation = store.create_conversation().await.unwrap();
        let first = store
            .save_message(conversation, "user", "a")
            .await
            .unwrap();
        let last = store
            .save_message(conversation, "assistant", "b")
            .await
            .unwrap();
        store
            .save_summary(conversation, "summary", first, last, 5)
            .await
            .unwrap();

        let exported = export_snapshot(&store).await.unwrap();

        let summaries = &exported.conversations[0].summaries;
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].content, "summary");
    }

    #[test]
    fn chrono_now_not_empty() {
        let stamp = chrono_now();
        assert!(stamp.contains('T'));
        assert!(stamp.ends_with('Z'));
    }

    #[test]
    fn import_corrupt_json_returns_error() {
        let parsed: Result<MemorySnapshot, _> =
            serde_json::from_str("not valid json at all {{{");
        assert!(parsed.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = memory_store().await;
        let future_snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };

        let message = import_snapshot(&store, future_snapshot)
            .await
            .unwrap_err()
            .to_string();

        assert!(message.contains("unsupported snapshot version 99"));
    }

    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let source = memory_store().await;
        let conversation = source.create_conversation().await.unwrap();
        source
            .save_message(conversation, "user", "existing message")
            .await
            .unwrap();
        let target = memory_store().await;

        let first = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();
        assert_eq!(first.messages_imported, 1);

        // Add one more message at the source, then sync again.
        source
            .save_message(conversation, "assistant", "new reply")
            .await
            .unwrap();
        let second = import_snapshot(&target, export_snapshot(&source).await.unwrap())
            .await
            .unwrap();

        assert_eq!(
            second.messages_imported, 1,
            "only the new message should be imported"
        );
        assert_eq!(
            second.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = target.load_history(conversation, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}