1use std::collections::{HashMap, HashSet};
10use std::sync::atomic::{AtomicU64, Ordering};
11use super::facts::{FactValue, TypedFacts};
12
/// Lightweight, copyable identifier for a fact held in [`WorkingMemory`].
///
/// The wrapped `u64` is private to this module; read it with
/// [`FactHandle::id`]. Handles compare, hash and copy by that raw value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FactHandle(u64);
16
17impl FactHandle {
18 fn new(id: u64) -> Self {
20 Self(id)
21 }
22
23 pub fn id(&self) -> u64 {
25 self.0
26 }
27}
28
29impl std::fmt::Display for FactHandle {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "FactHandle({})", self.0)
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct WorkingMemoryFact {
38 pub handle: FactHandle,
40 pub fact_type: String,
42 pub data: TypedFacts,
44 pub metadata: FactMetadata,
46}
47
/// Bookkeeping information attached to every stored fact.
#[derive(Clone, Debug)]
pub struct FactMetadata {
    /// Moment the fact was created.
    pub inserted_at: std::time::Instant,
    /// Moment of the most recent update (equals `inserted_at` until the first update).
    pub updated_at: std::time::Instant,
    /// Number of successful updates applied to the fact.
    pub update_count: usize,
    /// `true` once the fact has been retracted.
    pub retracted: bool,
}
60
61impl Default for FactMetadata {
62 fn default() -> Self {
63 let now = std::time::Instant::now();
64 Self {
65 inserted_at: now,
66 updated_at: now,
67 update_count: 0,
68 retracted: false,
69 }
70 }
71}
72
73pub struct WorkingMemory {
75 facts: HashMap<FactHandle, WorkingMemoryFact>,
77 type_index: HashMap<String, HashSet<FactHandle>>,
79 next_id: AtomicU64,
81 modified_handles: HashSet<FactHandle>,
83 retracted_handles: HashSet<FactHandle>,
85}
86
87impl WorkingMemory {
88 pub fn new() -> Self {
90 Self {
91 facts: HashMap::new(),
92 type_index: HashMap::new(),
93 next_id: AtomicU64::new(1),
94 modified_handles: HashSet::new(),
95 retracted_handles: HashSet::new(),
96 }
97 }
98
99 pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
101 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
102 let handle = FactHandle::new(id);
103
104 let fact = WorkingMemoryFact {
105 handle,
106 fact_type: fact_type.clone(),
107 data,
108 metadata: FactMetadata::default(),
109 };
110
111 self.facts.insert(handle, fact);
112 self.type_index
113 .entry(fact_type)
114 .or_insert_with(HashSet::new)
115 .insert(handle);
116 self.modified_handles.insert(handle);
117
118 handle
119 }
120
121 pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
123 let fact = self.facts.get_mut(&handle)
124 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
125
126 if fact.metadata.retracted {
127 return Err(format!("FactHandle {} is retracted", handle));
128 }
129
130 fact.data = data;
131 fact.metadata.updated_at = std::time::Instant::now();
132 fact.metadata.update_count += 1;
133 self.modified_handles.insert(handle);
134
135 Ok(())
136 }
137
138 pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
140 let fact = self.facts.get_mut(&handle)
141 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
142
143 if fact.metadata.retracted {
144 return Err(format!("FactHandle {} already retracted", handle));
145 }
146
147 fact.metadata.retracted = true;
148 self.retracted_handles.insert(handle);
149
150 if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
152 handles.remove(&handle);
153 }
154
155 Ok(())
156 }
157
158 pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
160 self.facts.get(handle).filter(|f| !f.metadata.retracted)
161 }
162
163 pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
165 if let Some(handles) = self.type_index.get(fact_type) {
166 handles
167 .iter()
168 .filter_map(|h| self.facts.get(h))
169 .filter(|f| !f.metadata.retracted)
170 .collect()
171 } else {
172 Vec::new()
173 }
174 }
175
176 pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
178 self.facts
179 .values()
180 .filter(|f| !f.metadata.retracted)
181 .collect()
182 }
183
184 pub fn get_all_handles(&self) -> Vec<FactHandle> {
186 self.facts
187 .values()
188 .filter(|f| !f.metadata.retracted)
189 .map(|f| f.handle)
190 .collect()
191 }
192
193 pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
195 &self.modified_handles
196 }
197
198 pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
200 &self.retracted_handles
201 }
202
203 pub fn clear_modification_tracking(&mut self) {
205 self.modified_handles.clear();
206 self.retracted_handles.clear();
207 }
208
209 pub fn stats(&self) -> WorkingMemoryStats {
211 let active_facts = self.facts.values().filter(|f| !f.metadata.retracted).count();
212 let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
213
214 WorkingMemoryStats {
215 total_facts: self.facts.len(),
216 active_facts,
217 retracted_facts,
218 types: self.type_index.len(),
219 modified_pending: self.modified_handles.len(),
220 retracted_pending: self.retracted_handles.len(),
221 }
222 }
223
224 pub fn clear(&mut self) {
226 self.facts.clear();
227 self.type_index.clear();
228 self.modified_handles.clear();
229 self.retracted_handles.clear();
230 }
231
232 pub fn to_typed_facts(&self) -> TypedFacts {
235 let mut result = TypedFacts::new();
236
237 for fact in self.get_all_facts() {
238 let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
239 for (key, value) in fact.data.get_all() {
240 result.set(format!("{}.{}", prefix, key), value.clone());
241 }
242
243 for (key, value) in fact.data.get_all() {
245 result.set(format!("{}.{}", fact.fact_type, key), value.clone());
246 }
247 }
248
249 result
250 }
251}
252
253impl Default for WorkingMemory {
254 fn default() -> Self {
255 Self::new()
256 }
257}
258
/// Point-in-time counters describing a working memory.
#[derive(Clone, Debug)]
pub struct WorkingMemoryStats {
    /// Number of stored facts, retracted ones included.
    pub total_facts: usize,
    /// Number of stored facts that are not retracted.
    pub active_facts: usize,
    /// Number of stored facts flagged as retracted.
    pub retracted_facts: usize,
    /// Number of entries in the per-type index.
    pub types: usize,
    /// Number of handles currently recorded as modified.
    pub modified_pending: usize,
    /// Number of handles currently recorded as retracted.
    pub retracted_pending: usize,
}
269
270impl std::fmt::Display for WorkingMemoryStats {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 write!(
273 f,
274 "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
275 self.active_facts, self.retracted_facts, self.types,
276 self.modified_pending, self.retracted_pending
277 )
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored fact can be read back through the handle returned by `insert`.
    #[test]
    fn test_insert_and_get() {
        let mut memory = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        person.set("age", 25i64);
        let person_handle = memory.insert(String::from("Person"), person);

        let stored = memory.get(&person_handle).unwrap();
        assert_eq!(stored.fact_type, "Person");
        assert_eq!(stored.data.get("name").unwrap().as_string(), "John");
    }

    /// `update` swaps the payload and bumps the update counter.
    #[test]
    fn test_update() {
        let mut memory = WorkingMemory::new();

        let mut before = TypedFacts::new();
        before.set("age", 25i64);
        let handle = memory.insert(String::from("Person"), before);

        let mut after = TypedFacts::new();
        after.set("age", 26i64);
        memory.update(handle, after).unwrap();

        let stored = memory.get(&handle).unwrap();
        assert_eq!(stored.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(stored.metadata.update_count, 1);
    }

    /// A retracted fact is no longer visible through the read API.
    #[test]
    fn test_retract() {
        let mut memory = WorkingMemory::new();
        let handle = memory.insert(String::from("Person"), TypedFacts::new());

        memory.retract(handle).unwrap();

        assert!(memory.get(&handle).is_none());
        assert!(memory.get_all_facts().is_empty());
    }

    /// Facts are retrievable per type; unknown types yield nothing.
    #[test]
    fn test_type_index() {
        let mut memory = WorkingMemory::new();

        // (type name, number of facts to insert), in insertion order.
        for &(fact_type, count) in &[("Person", 5i64), ("Order", 3i64)] {
            for id in 0..count {
                let mut data = TypedFacts::new();
                data.set("id", id);
                memory.insert(fact_type.to_string(), data);
            }
        }

        assert_eq!(memory.get_by_type("Person").len(), 5);
        assert_eq!(memory.get_by_type("Order").len(), 3);
        assert_eq!(memory.get_by_type("Unknown").len(), 0);
    }

    /// Inserts and updates are tracked as modifications, retractions separately.
    #[test]
    fn test_modification_tracking() {
        let mut memory = WorkingMemory::new();
        let blank = TypedFacts::new();
        let first = memory.insert(String::from("Person"), blank.clone());
        let second = memory.insert(String::from("Person"), blank.clone());

        assert_eq!(memory.get_modified_handles().len(), 2);

        memory.clear_modification_tracking();
        assert!(memory.get_modified_handles().is_empty());

        memory.update(first, blank).unwrap();
        assert_eq!(memory.get_modified_handles().len(), 1);

        memory.retract(second).unwrap();
        assert_eq!(memory.get_retracted_handles().len(), 1);
    }
}
367}