1use serde::{Deserialize, Serialize};
5
6use crate::error::MemoryError;
7use crate::sqlite::SqliteStore;
8use crate::types::ConversationId;
9
/// Top-level, serde-serializable container for an exported memory store.
///
/// Produced by [`export_snapshot`] and consumed by [`import_snapshot`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MemorySnapshot {
    /// Snapshot format version. `export_snapshot` writes `1`, and
    /// `import_snapshot` rejects every other value.
    pub version: u32,
    /// UTC time at which the export ran, formatted as `YYYY-MM-DDTHH:MM:SSZ`.
    pub exported_at: String,
    /// One entry per conversation; exported in ascending conversation-id order.
    pub conversations: Vec<ConversationSnapshot>,
}
16
/// A single conversation together with everything exported for it.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConversationSnapshot {
    /// Row id of the conversation in the `conversations` table.
    pub id: i64,
    /// Messages of this conversation; exported in ascending message-id order.
    pub messages: Vec<MessageSnapshot>,
    /// Summaries of this conversation, as returned by `SqliteStore::load_summaries`.
    pub summaries: Vec<SummarySnapshot>,
}
23
/// Serializable copy of one row of the `messages` table.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageSnapshot {
    /// Row id of the message.
    pub id: i64,
    /// Id of the owning conversation; on export this equals the id of the
    /// enclosing [`ConversationSnapshot`].
    pub conversation_id: i64,
    /// Value of the `role` column (e.g. `"user"` or `"assistant"`).
    pub role: String,
    /// Value of the `content` column.
    pub content: String,
    /// Raw value of the `parts` column (presumably JSON, going by the field name).
    pub parts_json: String,
    /// Creation time in Unix seconds, computed on export with
    /// `strftime('%s', created_at)`; `0` when that conversion yields NULL.
    pub created_at: i64,
}
33
/// Serializable copy of one row of the `summaries` table.
#[derive(Debug, Serialize, Deserialize)]
pub struct SummarySnapshot {
    /// Row id of the summary.
    pub id: i64,
    /// Id of the conversation the summary belongs to.
    pub conversation_id: i64,
    /// Summary text.
    pub content: String,
    /// Id of the first message associated with the summary
    /// (presumably the start of the summarized range — confirm against the schema).
    pub first_message_id: i64,
    /// Id of the last message associated with the summary
    /// (presumably the end of the summarized range — confirm against the schema).
    pub last_message_id: i64,
    /// Stored token estimate for the summary (presumably of `content`).
    pub token_estimate: i64,
}
43
/// Counters describing the outcome of one [`import_snapshot`] call.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers and tests can copy and
/// compare whole results instead of checking field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ImportStats {
    /// Conversations that did not exist yet and were inserted.
    pub conversations_imported: usize,
    /// Messages that did not exist yet and were inserted.
    pub messages_imported: usize,
    /// Summaries that did not exist yet and were inserted.
    pub summaries_imported: usize,
    /// Conversations, messages and summaries that were already present and
    /// therefore left untouched (all three kinds share this counter).
    pub skipped: usize,
}
51
52pub async fn export_snapshot(sqlite: &SqliteStore) -> Result<MemorySnapshot, MemoryError> {
58 let conv_ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM conversations ORDER BY id ASC")
59 .fetch_all(sqlite.pool())
60 .await?;
61
62 let exported_at = chrono_now();
63 let mut conversations = Vec::with_capacity(conv_ids.len());
64
65 for (cid_raw,) in conv_ids {
66 let cid = ConversationId(cid_raw);
67
68 let msg_rows: Vec<(i64, String, String, String, i64)> = sqlx::query_as(
69 "SELECT id, role, content, parts, \
70 COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
71 FROM messages WHERE conversation_id = ? ORDER BY id ASC",
72 )
73 .bind(cid)
74 .fetch_all(sqlite.pool())
75 .await?;
76
77 let messages = msg_rows
78 .into_iter()
79 .map(
80 |(id, role, content, parts_json, created_at)| MessageSnapshot {
81 id,
82 conversation_id: cid_raw,
83 role,
84 content,
85 parts_json,
86 created_at,
87 },
88 )
89 .collect();
90
91 let sum_rows = sqlite.load_summaries(cid).await?;
92 let summaries = sum_rows
93 .into_iter()
94 .map(
95 |(
96 id,
97 conversation_id,
98 content,
99 first_message_id,
100 last_message_id,
101 token_estimate,
102 )| {
103 SummarySnapshot {
104 id,
105 conversation_id: conversation_id.0,
106 content,
107 first_message_id: first_message_id.0,
108 last_message_id: last_message_id.0,
109 token_estimate,
110 }
111 },
112 )
113 .collect();
114
115 conversations.push(ConversationSnapshot {
116 id: cid_raw,
117 messages,
118 summaries,
119 });
120 }
121
122 Ok(MemorySnapshot {
123 version: 1,
124 exported_at,
125 conversations,
126 })
127}
128
129pub async fn import_snapshot(
137 sqlite: &SqliteStore,
138 snapshot: MemorySnapshot,
139) -> Result<ImportStats, MemoryError> {
140 if snapshot.version != 1 {
141 return Err(MemoryError::Snapshot(format!(
142 "unsupported snapshot version {}: only version 1 is supported",
143 snapshot.version
144 )));
145 }
146 let mut stats = ImportStats::default();
147
148 for conv in snapshot.conversations {
149 let exists: Option<(i64,)> = sqlx::query_as("SELECT id FROM conversations WHERE id = ?")
150 .bind(conv.id)
151 .fetch_optional(sqlite.pool())
152 .await?;
153
154 if exists.is_none() {
155 sqlx::query("INSERT INTO conversations (id) VALUES (?)")
156 .bind(conv.id)
157 .execute(sqlite.pool())
158 .await?;
159 stats.conversations_imported += 1;
160 } else {
161 stats.skipped += 1;
162 }
163
164 for msg in conv.messages {
165 let result = sqlx::query(
166 "INSERT OR IGNORE INTO messages (id, conversation_id, role, content, parts) \
167 VALUES (?, ?, ?, ?, ?)",
168 )
169 .bind(msg.id)
170 .bind(msg.conversation_id)
171 .bind(&msg.role)
172 .bind(&msg.content)
173 .bind(&msg.parts_json)
174 .execute(sqlite.pool())
175 .await?;
176
177 if result.rows_affected() > 0 {
178 stats.messages_imported += 1;
179 } else {
180 stats.skipped += 1;
181 }
182 }
183
184 for sum in conv.summaries {
185 let result = sqlx::query(
186 "INSERT OR IGNORE INTO summaries \
187 (id, conversation_id, content, first_message_id, last_message_id, token_estimate) \
188 VALUES (?, ?, ?, ?, ?, ?)",
189 )
190 .bind(sum.id)
191 .bind(sum.conversation_id)
192 .bind(&sum.content)
193 .bind(sum.first_message_id)
194 .bind(sum.last_message_id)
195 .bind(sum.token_estimate)
196 .execute(sqlite.pool())
197 .await?;
198
199 if result.rows_affected() > 0 {
200 stats.summaries_imported += 1;
201 } else {
202 stats.skipped += 1;
203 }
204 }
205 }
206
207 Ok(stats)
208}
209
210fn chrono_now() -> String {
211 use std::time::{SystemTime, UNIX_EPOCH};
212 let secs = SystemTime::now()
213 .duration_since(UNIX_EPOCH)
214 .unwrap_or_default()
215 .as_secs();
216 let (year, month, day, hour, min, sec) = unix_to_parts(secs);
218 format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
219}
220
/// Splits a Unix timestamp (seconds since 1970-01-01T00:00:00Z) into
/// `(year, month, day, hour, minute, second)` in the proleptic Gregorian
/// calendar, UTC.
///
/// `month` is 1-12 and `day` is 1-31. The date part shifts the day count so
/// that each 400-year era starts on March 1st, which puts the leap day at the
/// end of the shifted year.
fn unix_to_parts(secs: u64) -> (u64, u64, u64, u64, u64, u64) {
    const SECS_PER_MINUTE: u64 = 60;
    const SECS_PER_HOUR: u64 = 3_600;
    const SECS_PER_DAY: u64 = 86_400;
    // Days in one 400-year Gregorian cycle.
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT_DAYS: u64 = 719_468;

    let second = secs % SECS_PER_MINUTE;
    let minute = (secs / SECS_PER_MINUTE) % 60;
    let hour_of_day = (secs / SECS_PER_HOUR) % 24;
    let days_since_epoch = secs / SECS_PER_DAY;

    let shifted = days_since_epoch + EPOCH_SHIFT_DAYS;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;

    // January and February belong to the next calendar year of the shifted year.
    let (month, year_carry) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = year_of_era + era * 400 + year_carry;

    (year, month, day, hour_of_day, minute, second)
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Exporting an empty store yields a version-1 snapshot with no conversations.
    #[tokio::test]
    async fn export_empty_database() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.conversations.is_empty());
        assert!(!snapshot.exported_at.is_empty());
    }

    /// Messages exported from one store can be loaded from another after import.
    #[tokio::test]
    async fn export_import_roundtrip() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "hello export").await.unwrap();
        src.save_message(cid, "assistant", "hi import")
            .await
            .unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();
        assert_eq!(snapshot.conversations.len(), 1);
        assert_eq!(snapshot.conversations[0].messages.len(), 2);

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats.conversations_imported, 1);
        assert_eq!(stats.messages_imported, 2);
        assert_eq!(stats.skipped, 0);

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "hello export");
        assert_eq!(history[1].content, "hi import");
    }

    /// Importing the same data twice inserts nothing the second time.
    #[tokio::test]
    async fn import_duplicate_skips() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "msg").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(stats2.messages_imported, 0);
        assert!(stats2.skipped > 0);
    }

    /// A conversation that already exists is counted as skipped, not imported.
    #[tokio::test]
    async fn import_existing_conversation_increments_skipped_not_imported() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "only message").await.unwrap();

        let snapshot = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        // First import creates the conversation.
        let stats1 = import_snapshot(&dst, snapshot).await.unwrap();
        assert_eq!(stats1.conversations_imported, 1);
        assert_eq!(stats1.messages_imported, 1);

        let snapshot2 = export_snapshot(&src).await.unwrap();
        // Second import finds the conversation already present.
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();
        assert_eq!(
            stats2.conversations_imported, 0,
            "existing conversation must not be counted as imported"
        );
        assert!(
            stats2.skipped >= 1,
            "re-importing an existing conversation must increment skipped"
        );
    }

    /// Summaries saved for a conversation show up in its exported snapshot.
    #[tokio::test]
    async fn export_includes_summaries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cid = store.create_conversation().await.unwrap();
        let m1 = store.save_message(cid, "user", "a").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "b").await.unwrap();
        store.save_summary(cid, "summary", m1, m2, 5).await.unwrap();

        let snapshot = export_snapshot(&store).await.unwrap();
        assert_eq!(snapshot.conversations[0].summaries.len(), 1);
        assert_eq!(snapshot.conversations[0].summaries[0].content, "summary");
    }

    #[test]
    fn chrono_now_not_empty() {
        let ts = chrono_now();
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
    }

    /// The timestamp has the fixed `YYYY-MM-DDTHH:MM:SSZ` layout.
    #[test]
    fn chrono_now_has_iso8601_shape() {
        let ts = chrono_now();
        let bytes = ts.as_bytes();
        assert_eq!(bytes.len(), 20, "unexpected timestamp length: {ts}");
        assert_eq!(bytes[4], b'-');
        assert_eq!(bytes[7], b'-');
        assert_eq!(bytes[10], b'T');
        assert_eq!(bytes[13], b':');
        assert_eq!(bytes[16], b':');
        assert_eq!(bytes[19], b'Z');
        for (idx, byte) in bytes.iter().enumerate() {
            if ![4, 7, 10, 13, 16, 19].contains(&idx) {
                assert!(byte.is_ascii_digit(), "non-digit at {idx} in {ts}");
            }
        }
    }

    /// Known timestamps, including leap-day and non-leap-century boundaries.
    #[test]
    fn unix_to_parts_known_values() {
        assert_eq!(unix_to_parts(0), (1970, 1, 1, 0, 0, 0));
        assert_eq!(unix_to_parts(86_399), (1970, 1, 1, 23, 59, 59));
        assert_eq!(unix_to_parts(951_782_400), (2000, 2, 29, 0, 0, 0));
        assert_eq!(unix_to_parts(1_000_000_000), (2001, 9, 9, 1, 46, 40));
        assert_eq!(unix_to_parts(1_700_000_000), (2023, 11, 14, 22, 13, 20));
        assert_eq!(unix_to_parts(4_107_456_000), (2100, 2, 28, 0, 0, 0));
        assert_eq!(unix_to_parts(4_107_542_400), (2100, 3, 1, 0, 0, 0));
    }

    /// Malformed JSON fails at deserialization, before any import is attempted.
    #[test]
    fn import_corrupt_json_returns_error() {
        let result = serde_json::from_str::<MemorySnapshot>("not valid json at all {{{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn import_unsupported_version_returns_error() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let snapshot = MemorySnapshot {
            version: 99,
            exported_at: "2026-01-01T00:00:00Z".into(),
            conversations: vec![],
        };
        let err = import_snapshot(&store, snapshot).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("unsupported snapshot version 99"));
    }

    /// Re-importing a grown conversation adds only the messages that are new.
    #[tokio::test]
    async fn import_partial_overlap_adds_new_messages() {
        let src = SqliteStore::new(":memory:").await.unwrap();
        let cid = src.create_conversation().await.unwrap();
        src.save_message(cid, "user", "existing message")
            .await
            .unwrap();

        let snapshot1 = export_snapshot(&src).await.unwrap();

        let dst = SqliteStore::new(":memory:").await.unwrap();
        let stats1 = import_snapshot(&dst, snapshot1).await.unwrap();
        assert_eq!(stats1.messages_imported, 1);

        src.save_message(cid, "assistant", "new reply")
            .await
            .unwrap();
        let snapshot2 = export_snapshot(&src).await.unwrap();
        let stats2 = import_snapshot(&dst, snapshot2).await.unwrap();

        assert_eq!(
            stats2.messages_imported, 1,
            "only the new message should be imported"
        );
        // One skip for the conversation itself, one for the duplicate message.
        assert_eq!(
            stats2.skipped, 2,
            "existing conversation and duplicate message should be skipped"
        );

        let history = dst.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[1].content, "new reply");
    }
}