1use std::collections::{HashMap, HashSet};
10use std::sync::atomic::{AtomicU64, Ordering};
11use super::facts::{FactValue, TypedFacts};
12
/// Opaque, copyable identifier for a fact stored in working memory.
///
/// Two handles are equal (and hash identically) exactly when their numeric ids match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wraps the raw numeric `id` in a handle.
    pub fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Returns the raw numeric id carried by this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Renders the handle as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("FactHandle({})", self.id()))
    }
}
34
/// A fact held in working memory, bundled with its handle, type name and bookkeeping.
#[derive(Debug, Clone)]
pub struct WorkingMemoryFact {
    /// Handle under which this fact is stored.
    pub handle: FactHandle,
    /// Type name the fact is indexed under (e.g. `"Person"`).
    pub fact_type: String,
    /// The fact's field values.
    pub data: TypedFacts,
    /// Timestamps, update counter and retraction flag.
    pub metadata: FactMetadata,
}
47
/// Bookkeeping attached to every fact held in working memory.
#[derive(Debug, Clone)]
pub struct FactMetadata {
    /// Moment the metadata was created, i.e. when the fact was inserted.
    pub inserted_at: std::time::Instant,
    /// Moment of the latest modification; equals `inserted_at` for fresh metadata.
    pub updated_at: std::time::Instant,
    /// Number of updates applied since insertion.
    pub update_count: usize,
    /// Set once the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Fresh metadata: both timestamps share a single `Instant::now()` reading,
    /// the update counter is zero and the fact is not retracted.
    fn default() -> Self {
        let created = std::time::Instant::now();
        FactMetadata {
            retracted: false,
            update_count: 0,
            updated_at: created,
            inserted_at: created,
        }
    }
}
72
/// Store of facts addressed by [`FactHandle`], with a per-type index and tracking of
/// which handles were inserted/updated or retracted since the last tracking reset.
///
/// Retracted facts remain in `facts` (flagged via `FactMetadata::retracted`) but are
/// dropped from `type_index` and hidden from the lookup methods.
pub struct WorkingMemory {
    /// Every fact inserted since the last `clear`, including retracted ones.
    facts: HashMap<FactHandle, WorkingMemoryFact>,
    /// Fact type name -> handles of that type; handles are removed on retraction.
    type_index: HashMap<String, HashSet<FactHandle>>,
    /// Next raw id to hand out; starts at 1 and is not reset by `clear`.
    next_id: AtomicU64,
    /// Handles inserted or updated since the last `clear_modification_tracking`.
    modified_handles: HashSet<FactHandle>,
    /// Handles retracted since the last `clear_modification_tracking`.
    retracted_handles: HashSet<FactHandle>,
}
86
87impl WorkingMemory {
88 pub fn new() -> Self {
90 Self {
91 facts: HashMap::new(),
92 type_index: HashMap::new(),
93 next_id: AtomicU64::new(1),
94 modified_handles: HashSet::new(),
95 retracted_handles: HashSet::new(),
96 }
97 }
98
99 pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
101 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
102 let handle = FactHandle::new(id);
103
104 let fact = WorkingMemoryFact {
105 handle,
106 fact_type: fact_type.clone(),
107 data,
108 metadata: FactMetadata::default(),
109 };
110
111 self.facts.insert(handle, fact);
112 self.type_index
113 .entry(fact_type)
114 .or_insert_with(HashSet::new)
115 .insert(handle);
116 self.modified_handles.insert(handle);
117
118 handle
119 }
120
121 pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
123 let fact = self.facts.get_mut(&handle)
124 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
125
126 if fact.metadata.retracted {
127 return Err(format!("FactHandle {} is retracted", handle));
128 }
129
130 fact.data = data;
131 fact.metadata.updated_at = std::time::Instant::now();
132 fact.metadata.update_count += 1;
133 self.modified_handles.insert(handle);
134
135 Ok(())
136 }
137
138 pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
140 let fact = self.facts.get_mut(&handle)
141 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
142
143 if fact.metadata.retracted {
144 return Err(format!("FactHandle {} already retracted", handle));
145 }
146
147 fact.metadata.retracted = true;
148 self.retracted_handles.insert(handle);
149
150 if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
152 handles.remove(&handle);
153 }
154
155 Ok(())
156 }
157
158 pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
160 self.facts.get(handle).filter(|f| !f.metadata.retracted)
161 }
162
163 pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
165 if let Some(handles) = self.type_index.get(fact_type) {
166 handles
167 .iter()
168 .filter_map(|h| self.facts.get(h))
169 .filter(|f| !f.metadata.retracted)
170 .collect()
171 } else {
172 Vec::new()
173 }
174 }
175
176 pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
178 self.facts
179 .values()
180 .filter(|f| !f.metadata.retracted)
181 .collect()
182 }
183
184 pub fn get_all_handles(&self) -> Vec<FactHandle> {
186 self.facts
187 .values()
188 .filter(|f| !f.metadata.retracted)
189 .map(|f| f.handle)
190 .collect()
191 }
192
193 pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
195 &self.modified_handles
196 }
197
198 pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
200 &self.retracted_handles
201 }
202
203 pub fn clear_modification_tracking(&mut self) {
205 self.modified_handles.clear();
206 self.retracted_handles.clear();
207 }
208
209 pub fn stats(&self) -> WorkingMemoryStats {
211 let active_facts = self.facts.values().filter(|f| !f.metadata.retracted).count();
212 let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
213
214 WorkingMemoryStats {
215 total_facts: self.facts.len(),
216 active_facts,
217 retracted_facts,
218 types: self.type_index.len(),
219 modified_pending: self.modified_handles.len(),
220 retracted_pending: self.retracted_handles.len(),
221 }
222 }
223
224 pub fn clear(&mut self) {
226 self.facts.clear();
227 self.type_index.clear();
228 self.modified_handles.clear();
229 self.retracted_handles.clear();
230 }
231
232 pub fn to_typed_facts(&self) -> TypedFacts {
235 let mut result = TypedFacts::new();
236
237 for fact in self.get_all_facts() {
238 let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
239 for (key, value) in fact.data.get_all() {
240 result.set(format!("{}.{}", prefix, key), value.clone());
241 }
242
243 for (key, value) in fact.data.get_all() {
245 result.set(format!("{}.{}", fact.fact_type, key), value.clone());
246 }
247
248 result.set_fact_handle(fact.fact_type.clone(), fact.handle);
251 }
252
253 result
254 }
255}
256
257impl Default for WorkingMemory {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
/// Point-in-time snapshot of working-memory counters.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// All stored facts, active and retracted.
    pub total_facts: usize,
    /// Facts that have not been retracted.
    pub active_facts: usize,
    /// Facts flagged as retracted but still stored.
    pub retracted_facts: usize,
    /// Number of distinct fact types in the type index.
    pub types: usize,
    /// Handles inserted/updated since the last tracking reset.
    pub modified_pending: usize,
    /// Handles retracted since the last tracking reset.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// One-line human-readable summary; `total_facts` is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WorkingMemoryStats {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut wm = WorkingMemory::new();
        let mut person_data = TypedFacts::new();
        person_data.set("name", "John");
        person_data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), person_data);

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.fact_type, "Person");
        assert_eq!(fact.data.get("name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_update() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), data);

        let mut updated_data = TypedFacts::new();
        updated_data.set("age", 26i64);
        wm.update(handle, updated_data).unwrap();

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(fact.metadata.update_count, 1);
    }

    #[test]
    fn test_retract() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let handle = wm.insert("Person".to_string(), data);

        wm.retract(handle).unwrap();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.get_all_facts().len(), 0);
    }

    #[test]
    fn test_type_index() {
        let mut wm = WorkingMemory::new();

        for i in 0..5 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Person".to_string(), data);
        }

        for i in 0..3 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Order".to_string(), data);
        }

        assert_eq!(wm.get_by_type("Person").len(), 5);
        assert_eq!(wm.get_by_type("Order").len(), 3);
        assert_eq!(wm.get_by_type("Unknown").len(), 0);
    }

    #[test]
    fn test_modification_tracking() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let h1 = wm.insert("Person".to_string(), data.clone());
        let h2 = wm.insert("Person".to_string(), data.clone());

        assert_eq!(wm.get_modified_handles().len(), 2);

        wm.clear_modification_tracking();
        assert_eq!(wm.get_modified_handles().len(), 0);

        wm.update(h1, data.clone()).unwrap();
        assert_eq!(wm.get_modified_handles().len(), 1);

        wm.retract(h2).unwrap();
        assert_eq!(wm.get_retracted_handles().len(), 1);
    }

    /// Unknown handles, double retraction and updating a retracted fact are all errors.
    #[test]
    fn test_update_and_retract_errors() {
        let mut wm = WorkingMemory::new();
        let missing = FactHandle::new(999);
        assert!(wm.update(missing, TypedFacts::new()).is_err());
        assert!(wm.retract(missing).is_err());

        let handle = wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(handle).unwrap();
        assert!(wm.retract(handle).is_err(), "double retract must fail");
        assert!(
            wm.update(handle, TypedFacts::new()).is_err(),
            "updating a retracted fact must fail"
        );
    }

    /// Retracted facts disappear from type lookups but are still counted in stats.
    #[test]
    fn test_get_by_type_excludes_retracted() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        let h2 = wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(h1).unwrap();

        let remaining = wm.get_by_type("Person");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].handle, h2);

        let stats = wm.stats();
        assert_eq!(stats.total_facts, 2);
        assert_eq!(stats.active_facts, 1);
        assert_eq!(stats.retracted_facts, 1);
    }

    /// A single fact is exposed under both its handle-qualified and unqualified keys.
    #[test]
    fn test_to_typed_facts_keys() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 30i64);
        let handle = wm.insert("Person".to_string(), data);

        let flat = wm.to_typed_facts();
        let qualified = format!("Person.{}.age", handle.id());
        assert_eq!(flat.get(qualified.as_str()).unwrap().as_integer(), Some(30));
        assert_eq!(flat.get("Person.age").unwrap().as_integer(), Some(30));
    }
}
371}