1use crate::types::{Layer3Result, MemoryEntry};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::path::{Path, PathBuf};
12use tokio::fs;
13use tokio::io::AsyncWriteExt;
14
/// Format version stamped into every [`StorageContainer`] created by
/// [`StorageContainer::new`], and the default value reported by
/// [`FileBackend::version`].
pub const STORAGE_VERSION: u32 = 1;
17
/// Versioned envelope that wraps the persisted memory entries together with
/// session and timestamp metadata.
///
/// The field names are part of the serialized format (they are emitted and
/// read by serde as-is).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageContainer {
    /// Format version of this container; `migrate_from` inspects this key to
    /// decide how to decode a stored document.
    pub version: u32,
    /// Identifier of the session the entries belong to.
    pub session_id: String,
    /// Timestamp (UTC) taken when the container value was constructed.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Timestamp (UTC) of the last modification; refreshed by `touch`.
    pub modified_at: chrono::DateTime<chrono::Utc>,
    /// The stored memory entries.
    pub entries: Vec<MemoryEntry>,
}
32
33impl StorageContainer {
34 pub fn new(session_id: impl Into<String>, entries: Vec<MemoryEntry>) -> Self {
36 let now = chrono::Utc::now();
37 Self {
38 version: STORAGE_VERSION,
39 session_id: session_id.into(),
40 created_at: now,
41 modified_at: now,
42 entries,
43 }
44 }
45
46 pub fn touch(&mut self) {
48 self.modified_at = chrono::Utc::now();
49 }
50
51 pub fn migrate_from(value: serde_json::Value) -> Layer3Result<Self> {
53 let version = value.get("version").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
54
55 match version {
56 0 => {
57 let entries: Vec<MemoryEntry> = serde_json::from_value(value)?;
58 Ok(Self::new("migrated", entries))
59 }
60 1 => Ok(serde_json::from_value(value)?),
61 v => anyhow::bail!("Unsupported storage version: {}", v),
62 }
63 }
64}
65
66impl Default for StorageContainer {
67 fn default() -> Self {
68 Self::new("default", Vec::new())
69 }
70}
71
/// Asynchronous persistence interface for a collection of [`MemoryEntry`]
/// values stored at a single filesystem path.
#[async_trait]
pub trait FileBackend: Send + Sync {
    /// Persists `entries`.
    ///
    /// The backends in this file use their configured session id, falling
    /// back to `"unknown"` when none is set.
    async fn save(&self, entries: &[MemoryEntry]) -> Layer3Result<()>;

    /// Persists `entries` tagged with an explicit `session_id`.
    async fn save_with_session(
        &self,
        session_id: &str,
        entries: &[MemoryEntry],
    ) -> Layer3Result<()>;

    /// Loads only the stored entries.
    async fn load(&self) -> Layer3Result<Vec<MemoryEntry>>;

    /// Loads the stored entries together with their container metadata.
    async fn load_container(&self) -> Layer3Result<StorageContainer>;

    /// Reports whether the backing file is present.
    async fn exists(&self) -> bool;

    /// Removes the stored data.
    async fn clear(&self) -> Layer3Result<()>;

    /// Returns the path of the backing file.
    fn path(&self) -> &Path;

    /// Returns the storage format version; defaults to [`STORAGE_VERSION`].
    fn version(&self) -> u32 {
        STORAGE_VERSION
    }
}
105
/// [`FileBackend`] that stores a [`StorageContainer`] as a JSON document.
pub struct JsonFileBackend {
    /// Location of the JSON file.
    path: PathBuf,
    /// When `true` the JSON is pretty-printed; otherwise it is written compactly.
    pretty: bool,
    /// Session id used by `save`; `save` falls back to `"unknown"` when `None`.
    session_id: Option<String>,
}
115
116impl JsonFileBackend {
117 pub fn new(path: impl Into<PathBuf>) -> Self {
119 Self {
120 path: path.into(),
121 pretty: true,
122 session_id: None,
123 }
124 }
125
126 pub fn with_pretty(path: impl Into<PathBuf>, pretty: bool) -> Self {
128 Self {
129 path: path.into(),
130 pretty,
131 session_id: None,
132 }
133 }
134
135 pub fn with_session_id(path: impl Into<PathBuf>, session_id: impl Into<String>) -> Self {
137 Self {
138 path: path.into(),
139 pretty: true,
140 session_id: Some(session_id.into()),
141 }
142 }
143
144 pub fn set_session_id(&mut self, session_id: impl Into<String>) {
146 self.session_id = Some(session_id.into());
147 }
148
149 pub async fn get_stored_session_id(&self) -> Layer3Result<Option<String>> {
151 if !self.path.exists() {
152 return Ok(None);
153 }
154
155 let content = fs::read_to_string(&self.path).await?;
156 if content.trim().is_empty() {
157 return Ok(None);
158 }
159
160 let value: serde_json::Value = serde_json::from_str(&content)?;
161 if let Some(session_id) = value.get("session_id").and_then(|v| v.as_str()) {
162 Ok(Some(session_id.to_string()))
163 } else {
164 Ok(None)
165 }
166 }
167
168 fn temp_path(&self) -> PathBuf {
170 let mut temp = self.path.clone();
171 let file_name = temp.file_name().and_then(|n| n.to_str()).unwrap_or("temp");
172 let temp_name = format!("{}.tmp.{}", file_name, std::process::id());
173 temp.set_file_name(temp_name);
174 temp
175 }
176
177 async fn ensure_parent(&self) -> Layer3Result<()> {
179 if let Some(parent) = self.path.parent() {
180 if !parent.exists() {
181 fs::create_dir_all(parent).await?;
182 }
183 }
184 Ok(())
185 }
186
187 async fn atomic_write(&self, content: &str) -> Layer3Result<()> {
189 self.ensure_parent().await?;
190
191 let temp_path = self.temp_path();
192
193 let mut file = fs::File::create(&temp_path).await?;
195 file.write_all(content.as_bytes()).await?;
196 file.sync_all().await?;
197 drop(file);
198
199 fs::rename(&temp_path, &self.path).await?;
201
202 tracing::debug!("Atomically saved to {:?}", self.path);
203 Ok(())
204 }
205}
206
207#[async_trait]
208impl FileBackend for JsonFileBackend {
209 async fn save(&self, entries: &[MemoryEntry]) -> Layer3Result<()> {
210 let session_id = self.session_id.as_deref().unwrap_or("unknown");
211 self.save_with_session(session_id, entries).await
212 }
213
214 async fn save_with_session(
215 &self,
216 session_id: &str,
217 entries: &[MemoryEntry],
218 ) -> Layer3Result<()> {
219 let container = StorageContainer::new(session_id, entries.to_vec());
220
221 let json = if self.pretty {
222 serde_json::to_string_pretty(&container)?
223 } else {
224 serde_json::to_string(&container)?
225 };
226
227 self.atomic_write(&json).await?;
228 tracing::debug!(
229 "Saved {} entries to {:?} (session: {})",
230 entries.len(),
231 self.path,
232 session_id
233 );
234 Ok(())
235 }
236
237 async fn load(&self) -> Layer3Result<Vec<MemoryEntry>> {
238 let container = self.load_container().await?;
239 Ok(container.entries)
240 }
241
242 async fn load_container(&self) -> Layer3Result<StorageContainer> {
243 if !self.path.exists() {
244 return Ok(StorageContainer::default());
245 }
246
247 let content = fs::read_to_string(&self.path).await?;
248
249 if content.trim().is_empty() {
250 return Ok(StorageContainer::default());
251 }
252
253 let value: serde_json::Value = serde_json::from_str(&content)?;
254 let container = StorageContainer::migrate_from(value)?;
255 tracing::debug!(
256 "Loaded {} entries from {:?} (version: {})",
257 container.entries.len(),
258 self.path,
259 container.version
260 );
261 Ok(container)
262 }
263
264 async fn exists(&self) -> bool {
265 self.path.exists()
266 }
267
268 async fn clear(&self) -> Layer3Result<()> {
269 if self.path.exists() {
270 let temp_path = self.temp_path();
272 if temp_path.exists() {
273 fs::remove_file(&temp_path).await?;
274 }
275 fs::remove_file(&self.path).await?;
276 tracing::debug!("Cleared backend at {:?}", self.path);
277 }
278 Ok(())
279 }
280
281 fn path(&self) -> &Path {
282 &self.path
283 }
284}
285
/// [`FileBackend`] that stores a [`StorageContainer`] in MessagePack form
/// (via `rmp_serde`); only compiled with the `msgpack` feature.
#[cfg(feature = "msgpack")]
pub struct MsgPackFileBackend {
    /// Location of the MessagePack file.
    path: PathBuf,
    /// Session id used by `save`; `save` falls back to `"unknown"` when `None`.
    session_id: Option<String>,
}
292
#[cfg(feature = "msgpack")]
impl MsgPackFileBackend {
    /// Creates a backend for `path` with no session id configured.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            session_id: None,
        }
    }

    /// Creates a backend for `path` bound to `session_id`.
    pub fn with_session_id(path: impl Into<PathBuf>, session_id: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            session_id: Some(session_id.into()),
        }
    }

    /// Builds the sibling temp-file path `<file name>.tmp.<process id>`.
    ///
    /// Falls back to the stem `"temp"` when the target has no UTF-8 file name.
    fn temp_path(&self) -> PathBuf {
        let mut temp = self.path.clone();
        let file_name = temp.file_name().and_then(|n| n.to_str()).unwrap_or("temp");
        let temp_name = format!("{}.tmp.{}", file_name, std::process::id());
        temp.set_file_name(temp_name);
        temp
    }

    /// Creates the parent directory of the target file (and any missing
    /// ancestors).
    async fn ensure_parent(&self) -> Layer3Result<()> {
        if let Some(parent) = self.path.parent() {
            // `create_dir_all` succeeds when the directory already exists, so
            // no separate (blocking, racy) existence check is needed.
            fs::create_dir_all(parent).await?;
        }
        Ok(())
    }

    /// Writes `bytes` to the temp file, syncs it to disk, then renames it over
    /// the target path.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from directory creation, file creation, writing,
    /// syncing or renaming. On failure after the temp file may have been
    /// created, the temp file is removed on a best-effort basis.
    async fn atomic_write(&self, bytes: &[u8]) -> Layer3Result<()> {
        self.ensure_parent().await?;

        let temp_path = self.temp_path();

        let written: std::io::Result<()> = async {
            let mut file = fs::File::create(&temp_path).await?;
            file.write_all(bytes).await?;
            file.sync_all().await?;
            // Close the handle before renaming.
            drop(file);
            fs::rename(&temp_path, &self.path).await
        }
        .await;

        if written.is_err() {
            // Do not leave a partial temp file behind; the original error is
            // what matters, so a cleanup failure is deliberately ignored.
            let _ = fs::remove_file(&temp_path).await;
        }
        written?;
        Ok(())
    }
}
340
#[cfg(feature = "msgpack")]
#[async_trait]
impl FileBackend for MsgPackFileBackend {
    /// Saves `entries` under the configured session id, or `"unknown"` when
    /// none is configured.
    async fn save(&self, entries: &[MemoryEntry]) -> Layer3Result<()> {
        let session_id = self.session_id.as_deref().unwrap_or("unknown");
        self.save_with_session(session_id, entries).await
    }

    /// Encodes `entries` into a new [`StorageContainer`] for `session_id` and
    /// writes it via the temp-file-and-rename path.
    async fn save_with_session(
        &self,
        session_id: &str,
        entries: &[MemoryEntry],
    ) -> Layer3Result<()> {
        let container = StorageContainer::new(session_id, entries.to_vec());
        let bytes = rmp_serde::to_vec(&container)?;
        self.atomic_write(&bytes).await?;
        Ok(())
    }

    /// Loads the container and returns only its entries.
    async fn load(&self) -> Layer3Result<Vec<MemoryEntry>> {
        // Share one decoding path with `load_container` instead of duplicating it.
        Ok(self.load_container().await?.entries)
    }

    /// Reads and decodes the file.
    ///
    /// A missing or zero-length file yields `StorageContainer::default()`,
    /// matching how the JSON backend treats a missing or empty file.
    ///
    /// # Errors
    ///
    /// Returns an error when the file cannot be read for a reason other than
    /// not existing, or when its bytes do not decode as a container.
    async fn load_container(&self) -> Layer3Result<StorageContainer> {
        // Handle "not found" on the read itself rather than checking first:
        // no blocking `Path::exists` in async code, and no window in which the
        // file can vanish between the check and the read.
        let bytes = match fs::read(&self.path).await {
            Err(ref err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Ok(StorageContainer::default());
            }
            outcome => outcome?,
        };
        if bytes.is_empty() {
            return Ok(StorageContainer::default());
        }

        let container: StorageContainer = rmp_serde::from_slice(&bytes)?;
        Ok(container)
    }

    /// Returns `true` when metadata for the target path can be read.
    async fn exists(&self) -> bool {
        // Async equivalent of `Path::exists`, which is itself
        // `fs::metadata(path).is_ok()`.
        fs::metadata(&self.path).await.is_ok()
    }

    /// Removes this process's temp file (if any) and the target file.
    ///
    /// Files that are already absent are not an error.
    ///
    /// # Errors
    ///
    /// Returns the I/O error when a removal fails for a reason other than the
    /// file not existing.
    async fn clear(&self) -> Layer3Result<()> {
        let temp_path = self.temp_path();
        for target in [temp_path.as_path(), self.path.as_path()] {
            match fs::remove_file(target).await {
                Err(ref err) if err.kind() == std::io::ErrorKind::NotFound => {}
                outcome => outcome?,
            }
        }
        Ok(())
    }

    /// Returns the path of the MessagePack file.
    fn path(&self) -> &Path {
        &self.path
    }
}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryTier;
    use std::sync::Arc;
    use tempfile::tempdir;

    /// Builds a session-tier entry with the given id and content; the
    /// remaining fields get fixed placeholder values.
    fn create_test_entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_owned(),
            tier: MemoryTier::Session,
            content: content.to_owned(),
            metadata: Default::default(),
            created_at: chrono::Utc::now(),
            last_accessed: chrono::Utc::now(),
            access_count: 0,
            importance: 0.5,
        }
    }

    /// Saved entries can be read back unchanged.
    #[tokio::test]
    async fn test_json_backend_save_load() {
        let workspace = tempdir().unwrap();
        let file = workspace.path().join("session.json");
        let backend = JsonFileBackend::new(&file);

        let to_store = [create_test_entry("1", "test content")];
        backend.save(&to_store).await.unwrap();
        assert!(file.exists());

        let round_tripped = backend.load().await.unwrap();
        assert_eq!(round_tripped.len(), 1);
        assert_eq!(round_tripped[0].content, "test content");
    }

    /// Loading from a path that was never written yields no entries.
    #[tokio::test]
    async fn test_json_backend_empty_load() {
        let workspace = tempdir().unwrap();
        let backend = JsonFileBackend::new(workspace.path().join("nonexistent.json"));

        let entries = backend.load().await.unwrap();
        assert!(entries.is_empty());
    }

    /// `clear` removes a previously saved file.
    #[tokio::test]
    async fn test_json_backend_clear() {
        let workspace = tempdir().unwrap();
        let backend = JsonFileBackend::new(workspace.path().join("session.json"));

        let to_store = [create_test_entry("1", "test")];
        backend.save(&to_store).await.unwrap();
        assert!(backend.exists().await);

        backend.clear().await.unwrap();
        assert!(!backend.exists().await);
    }

    /// A successful save leaves no `.tmp.` file in the directory.
    #[tokio::test]
    async fn test_atomic_write_no_temp_file_left() {
        let workspace = tempdir().unwrap();
        let backend = JsonFileBackend::new(workspace.path().join("atomic_test.json"));

        let to_store = [create_test_entry("1", "test")];
        backend.save(&to_store).await.unwrap();

        let names = std::fs::read_dir(workspace.path())
            .unwrap()
            .map(|item| item.unwrap().file_name().to_string_lossy().into_owned());
        for name in names {
            assert!(!name.contains(".tmp."), "Temp file left behind: {}", name);
        }
    }

    /// The saved container carries the current version and the session id.
    #[tokio::test]
    async fn test_version_container() {
        let workspace = tempdir().unwrap();
        let file = workspace.path().join("versioned.json");
        let backend = JsonFileBackend::with_session_id(&file, "test-session-123");

        let to_store = [create_test_entry("1", "test")];
        backend.save(&to_store).await.unwrap();

        let stored = backend.load_container().await.unwrap();
        assert_eq!(stored.version, STORAGE_VERSION);
        assert_eq!(stored.session_id, "test-session-123");
        assert_eq!(stored.entries.len(), 1);
    }

    /// A legacy file holding a bare entry list is still loadable.
    #[tokio::test]
    async fn test_migration_from_v0() {
        let workspace = tempdir().unwrap();
        let file = workspace.path().join("legacy.json");

        // Write the pre-container layout: just the serialized list of entries.
        let bare_list = vec![create_test_entry("legacy-1", "legacy content")];
        std::fs::write(&file, serde_json::to_string_pretty(&bare_list).unwrap()).unwrap();

        let migrated = JsonFileBackend::new(&file).load().await.unwrap();

        assert_eq!(migrated.len(), 1);
        assert_eq!(migrated[0].content, "legacy content");
    }

    /// The stored session id is absent before the first save and present after.
    #[tokio::test]
    async fn test_session_id_retrieval() {
        let workspace = tempdir().unwrap();
        let file = workspace.path().join("session_id.json");
        let backend = JsonFileBackend::with_session_id(&file, "session-abc");

        let before_save = backend.get_stored_session_id().await.unwrap();
        assert!(before_save.is_none());

        let to_store = [create_test_entry("1", "test")];
        backend.save(&to_store).await.unwrap();

        let after_save = backend.get_stored_session_id().await.unwrap();
        assert_eq!(after_save, Some("session-abc".to_string()));
    }

    /// Repeated saves through a shared handle leave a loadable, non-empty file.
    #[tokio::test]
    async fn test_concurrent_safe_operations() {
        let workspace = tempdir().unwrap();
        let file = workspace.path().join("concurrent.json");
        let backend = Arc::new(JsonFileBackend::new(&file));

        // The saves run one after another; each replaces the previous file.
        for round in 0..5 {
            let entry_id = format!("entry-{}", round);
            let to_store = [create_test_entry(&entry_id, "content")];
            backend.save(&to_store).await.unwrap();
        }

        assert!(file.exists());
        let final_entries = backend.load().await.unwrap();
        assert!(!final_entries.is_empty());
    }
}