sh_layer3/memory_system/
working.rs1use crate::memory_system::{DecayPolicy, MemoryStore, TimeBasedDecay};
6use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use std::collections::VecDeque;
10use std::sync::Arc;
11
12#[allow(dead_code)]
16pub struct WorkingMemory {
17 buffer: Arc<RwLock<VecDeque<MemoryEntry>>>,
19 max_size: usize,
21 #[allow(dead_code)]
23 decay_policy: Box<dyn DecayPolicy>,
24}
25
26impl WorkingMemory {
27 pub fn new(max_size: usize) -> Self {
28 Self {
29 buffer: Arc::new(RwLock::new(VecDeque::with_capacity(max_size))),
30 max_size,
31 decay_policy: Box::new(TimeBasedDecay::default()),
32 }
33 }
34}
35
36impl Default for WorkingMemory {
37 fn default() -> Self {
38 Self::new(100)
39 }
40}
41
42#[async_trait]
43impl MemoryStore for WorkingMemory {
44 fn tier(&self) -> MemoryTier {
45 MemoryTier::Working
46 }
47
48 async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
49 let mut buffer = self.buffer.write();
50 if buffer.len() >= self.max_size {
51 buffer.pop_front();
52 }
53 let id = entry.id.clone();
54 buffer.push_back(entry);
55 Ok(id)
56 }
57
58 async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
59 let buffer = self.buffer.read();
60 Ok(buffer.iter().find(|e| e.id == id).cloned())
61 }
62
63 async fn delete(&self, id: &str) -> Layer3Result<bool> {
64 let mut buffer = self.buffer.write();
65 let len_before = buffer.len();
66 buffer.retain(|e| e.id != id);
67 Ok(buffer.len() < len_before)
68 }
69
70 async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
71 let buffer = self.buffer.read();
72 let results: Vec<MemoryEntry> = buffer
73 .iter()
74 .filter(|e| {
75 if let Some(tier) = query.tier {
76 if e.tier != tier {
77 return false;
78 }
79 }
80 e.content.contains(&query.query)
81 })
82 .take(query.limit.unwrap_or(10))
83 .cloned()
84 .collect();
85 Ok(results)
86 }
87
88 async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
89 let buffer = self.buffer.read();
90 Ok(buffer
91 .iter()
92 .take(limit.unwrap_or(usize::MAX))
93 .cloned()
94 .collect())
95 }
96
97 async fn clear(&self) -> Layer3Result<usize> {
98 let mut buffer = self.buffer.write();
99 let count = buffer.len();
100 buffer.clear();
101 Ok(count)
102 }
103
104 async fn count(&self) -> Layer3Result<usize> {
105 Ok(self.buffer.read().len())
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a working-tier entry with the given id and content.
    fn entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            tier: MemoryTier::Working,
            content: content.to_string(),
            metadata: Default::default(),
            created_at: chrono::Utc::now(),
            last_accessed: chrono::Utc::now(),
            access_count: 0,
            importance: 0.5,
        }
    }

    #[tokio::test]
    async fn test_working_memory_store() {
        let memory = WorkingMemory::new(10);
        memory.store(entry("test-1", "test content")).await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_working_memory_get_returns_stored_entry() {
        let memory = WorkingMemory::new(10);
        let id = memory.store(entry("test-1", "hello")).await.unwrap();
        assert_eq!(id, "test-1");

        let found = memory.get("test-1").await.unwrap();
        assert_eq!(found.map(|e| e.content), Some("hello".to_string()));
        assert!(memory.get("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_working_memory_evicts_oldest_when_full() {
        let memory = WorkingMemory::new(2);
        for id in &["a", "b", "c"] {
            memory.store(entry(id, "x")).await.unwrap();
        }

        assert_eq!(memory.count().await.unwrap(), 2);
        assert!(memory.get("a").await.unwrap().is_none());
        assert!(memory.get("b").await.unwrap().is_some());
        assert!(memory.get("c").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_working_memory_delete_and_clear() {
        let memory = WorkingMemory::new(10);
        memory.store(entry("a", "x")).await.unwrap();
        memory.store(entry("b", "y")).await.unwrap();

        assert!(memory.delete("a").await.unwrap());
        assert!(!memory.delete("a").await.unwrap());
        assert_eq!(memory.count().await.unwrap(), 1);

        assert_eq!(memory.clear().await.unwrap(), 1);
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}