1use super::facts::TypedFacts;
10use std::collections::{HashMap, HashSet};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Opaque identifier for a fact stored in working memory.
///
/// Handles are small `Copy` values that can be compared, hashed and used as
/// map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wraps a raw numeric id in a handle.
    pub fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Returns the raw numeric id carried by this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Renders the handle as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "FactHandle({})", self.id())
    }
}
34
/// A fact held in working memory: its handle, type name, field data and
/// bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct WorkingMemoryFact {
    /// Handle under which this fact is stored.
    pub handle: FactHandle,
    /// Type name of the fact. For facts created by `insert_from_stream` this
    /// is the stream event's `event_type`.
    pub fact_type: String,
    /// Field values of the fact.
    pub data: TypedFacts,
    /// Insertion/update timestamps, update count and retraction flag.
    pub metadata: FactMetadata,
    /// Name of the stream this fact was inserted from (`insert_from_stream`);
    /// `None` for facts added through the plain `insert`.
    #[cfg(feature = "streaming")]
    pub stream_source: Option<String>,
    /// The original stream event this fact was built from, if any.
    #[cfg(feature = "streaming")]
    pub stream_event: Option<crate::streaming::event::StreamEvent>,
}
53
/// Bookkeeping attached to every fact in working memory.
#[derive(Debug, Clone)]
pub struct FactMetadata {
    /// Moment the fact was inserted.
    pub inserted_at: std::time::Instant,
    /// Moment the fact's data was last replaced; equal to `inserted_at` for a
    /// freshly created fact.
    pub updated_at: std::time::Instant,
    /// Number of successful updates applied to the fact.
    pub update_count: usize,
    /// Whether the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Metadata for a brand-new fact: both timestamps set to the same current
    /// instant, zero updates, not retracted.
    fn default() -> Self {
        let created = std::time::Instant::now();
        FactMetadata {
            retracted: false,
            update_count: 0,
            updated_at: created,
            inserted_at: created,
        }
    }
}
78
79pub struct WorkingMemory {
81 facts: HashMap<FactHandle, WorkingMemoryFact>,
83 type_index: HashMap<String, HashSet<FactHandle>>,
85 next_id: AtomicU64,
87 modified_handles: HashSet<FactHandle>,
89 retracted_handles: HashSet<FactHandle>,
91}
92
93impl WorkingMemory {
94 pub fn new() -> Self {
96 Self {
97 facts: HashMap::new(),
98 type_index: HashMap::new(),
99 next_id: AtomicU64::new(1),
100 modified_handles: HashSet::new(),
101 retracted_handles: HashSet::new(),
102 }
103 }
104
105 pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
107 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
108 let handle = FactHandle::new(id);
109
110 let fact = WorkingMemoryFact {
111 handle,
112 fact_type: fact_type.clone(),
113 data,
114 metadata: FactMetadata::default(),
115 #[cfg(feature = "streaming")]
116 stream_source: None,
117 #[cfg(feature = "streaming")]
118 stream_event: None,
119 };
120
121 self.facts.insert(handle, fact);
122 self.type_index.entry(fact_type).or_default().insert(handle);
123 self.modified_handles.insert(handle);
124
125 handle
126 }
127
128 #[cfg(feature = "streaming")]
130 pub fn insert_from_stream(
131 &mut self,
132 stream_name: String,
133 event: crate::streaming::event::StreamEvent,
134 ) -> FactHandle {
135 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
136 let handle = FactHandle::new(id);
137
138 let mut typed_facts = super::facts::TypedFacts::new();
140 for (key, value) in &event.data {
141 let fact_value: super::facts::FactValue = value.clone().into();
142 typed_facts.set(key.clone(), fact_value);
143 }
144
145 let fact = WorkingMemoryFact {
146 handle,
147 fact_type: event.event_type.clone(),
148 data: typed_facts,
149 metadata: FactMetadata::default(),
150 stream_source: Some(stream_name.clone()),
151 stream_event: Some(event),
152 };
153
154 self.facts.insert(handle, fact);
155 self.type_index
156 .entry(stream_name)
157 .or_default()
158 .insert(handle);
159 self.modified_handles.insert(handle);
160
161 handle
162 }
163
164 pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
166 let fact = self
167 .facts
168 .get_mut(&handle)
169 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
170
171 if fact.metadata.retracted {
172 return Err(format!("FactHandle {} is retracted", handle));
173 }
174
175 fact.data = data;
176 fact.metadata.updated_at = std::time::Instant::now();
177 fact.metadata.update_count += 1;
178 self.modified_handles.insert(handle);
179
180 Ok(())
181 }
182
183 pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
185 let fact = self
186 .facts
187 .get_mut(&handle)
188 .ok_or_else(|| format!("FactHandle {} not found", handle))?;
189
190 if fact.metadata.retracted {
191 return Err(format!("FactHandle {} already retracted", handle));
192 }
193
194 fact.metadata.retracted = true;
195 self.retracted_handles.insert(handle);
196
197 if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
199 handles.remove(&handle);
200 }
201
202 Ok(())
203 }
204
205 pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
207 self.facts.get(handle).filter(|f| !f.metadata.retracted)
208 }
209
210 pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
212 if let Some(handles) = self.type_index.get(fact_type) {
213 handles
214 .iter()
215 .filter_map(|h| self.facts.get(h))
216 .filter(|f| !f.metadata.retracted)
217 .collect()
218 } else {
219 Vec::new()
220 }
221 }
222
223 pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
225 self.facts
226 .values()
227 .filter(|f| !f.metadata.retracted)
228 .collect()
229 }
230
231 pub fn get_all_handles(&self) -> Vec<FactHandle> {
233 self.facts
234 .values()
235 .filter(|f| !f.metadata.retracted)
236 .map(|f| f.handle)
237 .collect()
238 }
239
240 pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
242 &self.modified_handles
243 }
244
245 pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
247 &self.retracted_handles
248 }
249
250 pub fn clear_modification_tracking(&mut self) {
252 self.modified_handles.clear();
253 self.retracted_handles.clear();
254 }
255
256 pub fn stats(&self) -> WorkingMemoryStats {
258 let active_facts = self
259 .facts
260 .values()
261 .filter(|f| !f.metadata.retracted)
262 .count();
263 let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
264
265 WorkingMemoryStats {
266 total_facts: self.facts.len(),
267 active_facts,
268 retracted_facts,
269 types: self.type_index.len(),
270 modified_pending: self.modified_handles.len(),
271 retracted_pending: self.retracted_handles.len(),
272 }
273 }
274
275 pub fn clear(&mut self) {
277 self.facts.clear();
278 self.type_index.clear();
279 self.modified_handles.clear();
280 self.retracted_handles.clear();
281 }
282
283 pub fn to_typed_facts(&self) -> TypedFacts {
286 let mut result = TypedFacts::new();
287
288 for fact in self.get_all_facts() {
289 let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
290 for (key, value) in fact.data.get_all() {
291 result.set(format!("{}.{}", prefix, key), value.clone());
292 }
293
294 for (key, value) in fact.data.get_all() {
296 result.set(format!("{}.{}", fact.fact_type, key), value.clone());
297 }
298
299 result.set_fact_handle(fact.fact_type.clone(), fact.handle);
302 }
303
304 result
305 }
306}
307
308impl Default for WorkingMemory {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
/// Snapshot of working-memory counters.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// All stored facts, active and retracted.
    pub total_facts: usize,
    /// Stored facts that are not retracted.
    pub active_facts: usize,
    /// Stored facts that are retracted.
    pub retracted_facts: usize,
    /// Number of entries in the type index.
    pub types: usize,
    /// Handles inserted or updated since tracking was last cleared.
    pub modified_pending: usize,
    /// Handles retracted since tracking was last cleared.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// One-line human-readable summary (`total_facts` is not printed).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WorkingMemoryStats {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut wm = WorkingMemory::new();
        let mut person_data = TypedFacts::new();
        person_data.set("name", "John");
        person_data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), person_data);

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.fact_type, "Person");
        assert_eq!(fact.data.get("name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_update() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), data);

        let mut updated_data = TypedFacts::new();
        updated_data.set("age", 26i64);
        wm.update(handle, updated_data).unwrap();

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(fact.metadata.update_count, 1);
    }

    #[test]
    fn test_retract() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let handle = wm.insert("Person".to_string(), data);

        wm.retract(handle).unwrap();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.get_all_facts().len(), 0);
    }

    #[test]
    fn test_update_and_retract_errors() {
        let mut wm = WorkingMemory::new();
        let handle = wm.insert("Person".to_string(), TypedFacts::new());

        // Unknown handles are rejected by both operations.
        let missing = FactHandle::new(9_999);
        assert!(wm.update(missing, TypedFacts::new()).is_err());
        assert!(wm.retract(missing).is_err());

        // A retracted fact can be neither updated nor retracted again.
        wm.retract(handle).unwrap();
        assert!(wm.update(handle, TypedFacts::new()).is_err());
        assert!(wm.retract(handle).is_err());
    }

    #[test]
    fn test_type_index() {
        let mut wm = WorkingMemory::new();

        for i in 0..5 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Person".to_string(), data);
        }

        for i in 0..3 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Order".to_string(), data);
        }

        assert_eq!(wm.get_by_type("Person").len(), 5);
        assert_eq!(wm.get_by_type("Order").len(), 3);
        assert_eq!(wm.get_by_type("Unknown").len(), 0);
    }

    #[test]
    fn test_retract_removes_from_type_index() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        let _h2 = wm.insert("Person".to_string(), TypedFacts::new());

        wm.retract(h1).unwrap();

        let remaining = wm.get_by_type("Person");
        assert_eq!(remaining.len(), 1);
        assert_ne!(remaining[0].handle, h1);
        assert_eq!(wm.get_all_handles().len(), 1);
    }

    #[test]
    fn test_modification_tracking() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let h1 = wm.insert("Person".to_string(), data.clone());
        let h2 = wm.insert("Person".to_string(), data.clone());

        assert_eq!(wm.get_modified_handles().len(), 2);

        wm.clear_modification_tracking();
        assert_eq!(wm.get_modified_handles().len(), 0);

        wm.update(h1, data.clone()).unwrap();
        assert_eq!(wm.get_modified_handles().len(), 1);

        wm.retract(h2).unwrap();
        assert_eq!(wm.get_retracted_handles().len(), 1);
    }

    #[test]
    fn test_stats_counts() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        wm.insert("Order".to_string(), TypedFacts::new());

        wm.retract(h1).unwrap();

        let stats = wm.stats();
        assert_eq!(stats.total_facts, 2);
        assert_eq!(stats.active_facts, 1);
        assert_eq!(stats.retracted_facts, 1);
        assert_eq!(stats.modified_pending, 2);
        assert_eq!(stats.retracted_pending, 1);
    }

    #[test]
    fn test_clear_keeps_handles_unique() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());

        wm.clear();
        assert!(wm.get(&h1).is_none());
        assert_eq!(wm.stats().total_facts, 0);

        // The id counter survives `clear`, so old handles are never reissued.
        let h2 = wm.insert("Person".to_string(), TypedFacts::new());
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_to_typed_facts_keys() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 30i64);
        let handle = wm.insert("Person".to_string(), data);

        let flat = wm.to_typed_facts();
        let qualified = format!("Person.{}.age", handle.id());
        assert_eq!(flat.get(qualified.as_str()).unwrap().as_integer(), Some(30));
        assert_eq!(flat.get("Person.age").unwrap().as_integer(), Some(30));
    }
}