1use std::collections::HashMap;
6use std::sync::RwLock;
7
8use chrono::Utc;
9use uuid::Uuid;
10
11use crate::yarli_memory::adapter::MemoryAdapter;
12use crate::yarli_memory::error::MemoryError;
13use crate::yarli_memory::types::{
14 content_may_contain_secrets, InsertMemory, LinkMemories, MemoryQuery, MemoryRecord,
15 RelationshipKind, ScopeId,
16};
17
/// Lifecycle state of one memory scope registered via `create_scope`.
#[derive(Debug, Clone)]
struct ScopeState {
    /// Parent scope passed to `create_scope`, if any.
    // Recorded but never read by this adapter, hence the lint allowance.
    #[allow(dead_code)]
    parent: Option<ScopeId>,
    /// Set by `close_scope`; `store` rejects writes into a closed scope.
    closed: bool,
}
25
/// A directed edge between two memories of the same project.
#[derive(Debug, Clone)]
struct MemoryLink {
    /// Id of the memory the edge starts at.
    from_memory_id: String,
    /// Id of the memory the edge points to.
    to_memory_id: String,
    // Recorded from the link request but never read here: `unlink` matches on
    // the endpoint pair only. Hence the lint allowance.
    #[allow(dead_code)]
    relationship: RelationshipKind,
    // Free-form metadata from the link request; likewise stored but unread.
    #[allow(dead_code)]
    metadata: HashMap<String, String>,
}
36
/// All mutable adapter data, guarded as a single unit by `InMemoryAdapter::state`.
struct InMemoryState {
    /// project -> scope id (`ScopeId::as_str`) -> scope state.
    scopes: HashMap<String, HashMap<String, ScopeState>>,
    /// project -> memory id -> stored record.
    memories: HashMap<String, HashMap<String, MemoryRecord>>,
    /// project -> links between that project's memories, in insertion order.
    links: HashMap<String, Vec<MemoryLink>>,
}
46
/// [`MemoryAdapter`] implementation that keeps every project's scopes,
/// memories and links in process-local hash maps behind one [`RwLock`].
/// Nothing is written outside the process.
pub struct InMemoryAdapter {
    /// All adapter data; every operation takes this lock once.
    state: RwLock<InMemoryState>,
}
53
54impl InMemoryAdapter {
55 pub fn new() -> Self {
57 Self {
58 state: RwLock::new(InMemoryState {
59 scopes: HashMap::new(),
60 memories: HashMap::new(),
61 links: HashMap::new(),
62 }),
63 }
64 }
65
66 pub fn memory_count(&self) -> usize {
68 let state = self.state.read().unwrap();
69 state.memories.values().map(|m| m.len()).sum()
70 }
71
72 pub fn link_count(&self, project: &str) -> usize {
74 let state = self.state.read().unwrap();
75 state.links.get(project).map(|l| l.len()).unwrap_or(0)
76 }
77}
78
79impl Default for InMemoryAdapter {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl MemoryAdapter for InMemoryAdapter {
86 async fn store(
87 &self,
88 project: &str,
89 request: InsertMemory,
90 ) -> Result<MemoryRecord, MemoryError> {
91 if content_may_contain_secrets(&request.content) {
93 return Err(MemoryError::RedactionRequired);
94 }
95
96 let mut state = self.state.write().unwrap();
97
98 if let Some(scopes) = state.scopes.get(project) {
100 if let Some(scope) = scopes.get(request.scope_id.as_str()) {
101 if scope.closed {
102 return Err(MemoryError::ScopeClosed(
103 request.scope_id.as_str().to_string(),
104 ));
105 }
106 }
107 }
108
109 let now = Utc::now();
110 let memory_id = Uuid::now_v7().to_string();
111
112 let record = MemoryRecord {
113 memory_id: memory_id.clone(),
114 scope_id: request.scope_id,
115 memory_class: request.memory_class,
116 content: request.content,
117 metadata: request.metadata,
118 relevance_score: 0.0,
119 retrieval_count: 0,
120 created_at: now,
121 updated_at: now,
122 };
123
124 state
125 .memories
126 .entry(project.to_string())
127 .or_default()
128 .insert(memory_id, record.clone());
129
130 Ok(record)
131 }
132
133 async fn query(
134 &self,
135 project: &str,
136 query: MemoryQuery,
137 ) -> Result<Vec<MemoryRecord>, MemoryError> {
138 let state = self.state.read().unwrap();
139
140 let memories = match state.memories.get(project) {
141 Some(m) => m,
142 None => return Ok(vec![]),
143 };
144
145 let query_lower = query.query_text.to_lowercase();
146 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
147
148 let mut results: Vec<MemoryRecord> = memories
149 .values()
150 .filter(|m| {
151 if m.scope_id != query.scope_id {
153 return false;
154 }
155 if let Some(class) = query.memory_class {
157 if m.memory_class != class {
158 return false;
159 }
160 }
161 if query_terms.is_empty() {
163 return true;
164 }
165 let content_lower = m.content.to_lowercase();
166 query_terms.iter().any(|term| content_lower.contains(term))
167 })
168 .cloned()
169 .enumerate()
170 .map(|(i, mut m)| {
171 let content_lower = m.content.to_lowercase();
173 let match_count = query_terms
174 .iter()
175 .filter(|t| content_lower.contains(**t))
176 .count();
177 m.relevance_score = match_count as f64 / query_terms.len().max(1) as f64;
178 m.retrieval_count += 1;
179 let _ = i; m
181 })
182 .collect();
183
184 results.sort_by(|a, b| {
186 b.relevance_score
187 .partial_cmp(&a.relevance_score)
188 .unwrap_or(std::cmp::Ordering::Equal)
189 });
190
191 results.truncate(query.limit as usize);
193
194 Ok(results)
195 }
196
197 async fn get(&self, project: &str, memory_id: &str) -> Result<MemoryRecord, MemoryError> {
198 let state = self.state.read().unwrap();
199
200 state
201 .memories
202 .get(project)
203 .and_then(|m| m.get(memory_id))
204 .cloned()
205 .ok_or_else(|| MemoryError::NotFound(memory_id.to_string()))
206 }
207
208 async fn delete(&self, project: &str, memory_id: &str) -> Result<(), MemoryError> {
209 let mut state = self.state.write().unwrap();
210
211 let removed = state
212 .memories
213 .get_mut(project)
214 .and_then(|m| m.remove(memory_id));
215
216 if removed.is_some() {
217 if let Some(links) = state.links.get_mut(project) {
219 links.retain(|l| l.from_memory_id != memory_id && l.to_memory_id != memory_id);
220 }
221 Ok(())
222 } else {
223 Err(MemoryError::NotFound(memory_id.to_string()))
224 }
225 }
226
227 async fn link(&self, project: &str, link: LinkMemories) -> Result<(), MemoryError> {
228 let mut state = self.state.write().unwrap();
229
230 let memories = state
232 .memories
233 .get(project)
234 .ok_or_else(|| MemoryError::NotFound(link.from_memory_id.clone()))?;
235
236 if !memories.contains_key(&link.from_memory_id) {
237 return Err(MemoryError::NotFound(link.from_memory_id));
238 }
239 if !memories.contains_key(&link.to_memory_id) {
240 return Err(MemoryError::NotFound(link.to_memory_id));
241 }
242
243 state
244 .links
245 .entry(project.to_string())
246 .or_default()
247 .push(MemoryLink {
248 from_memory_id: link.from_memory_id,
249 to_memory_id: link.to_memory_id,
250 relationship: link.relationship,
251 metadata: link.metadata,
252 });
253
254 Ok(())
255 }
256
257 async fn unlink(
258 &self,
259 project: &str,
260 from_memory_id: &str,
261 to_memory_id: &str,
262 ) -> Result<(), MemoryError> {
263 let mut state = self.state.write().unwrap();
264
265 if let Some(links) = state.links.get_mut(project) {
266 let before = links.len();
267 links.retain(|l| {
268 !(l.from_memory_id == from_memory_id && l.to_memory_id == to_memory_id)
269 });
270 if links.len() == before {
271 return Err(MemoryError::NotFound(format!(
272 "link {from_memory_id} -> {to_memory_id}"
273 )));
274 }
275 } else {
276 return Err(MemoryError::NotFound(format!(
277 "link {from_memory_id} -> {to_memory_id}"
278 )));
279 }
280
281 Ok(())
282 }
283
284 async fn create_scope(
285 &self,
286 project: &str,
287 scope_id: &ScopeId,
288 parent: Option<&ScopeId>,
289 ) -> Result<(), MemoryError> {
290 let mut state = self.state.write().unwrap();
291
292 state.scopes.entry(project.to_string()).or_default().insert(
293 scope_id.as_str().to_string(),
294 ScopeState {
295 parent: parent.cloned(),
296 closed: false,
297 },
298 );
299
300 Ok(())
301 }
302
303 async fn close_scope(&self, project: &str, scope_id: &ScopeId) -> Result<(), MemoryError> {
304 let mut state = self.state.write().unwrap();
305
306 let scope = state
307 .scopes
308 .get_mut(project)
309 .and_then(|s| s.get_mut(scope_id.as_str()))
310 .ok_or_else(|| MemoryError::ScopeNotFound {
311 project: project.to_string(),
312 scope: scope_id.as_str().to_string(),
313 })?;
314
315 scope.closed = true;
316 Ok(())
317 }
318
319 async fn health_check(&self) -> Result<bool, MemoryError> {
320 Ok(true)
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::yarli_memory::types::MemoryClass;

    /// Scope used by most tests: the run scope for the nil UUID.
    fn test_scope() -> ScopeId {
        ScopeId::for_run(uuid::Uuid::nil())
    }

    /// Stores `content` in project `"proj"` and returns the record, panicking
    /// if the store is rejected.
    async fn put(
        adapter: &InMemoryAdapter,
        scope: &ScopeId,
        class: MemoryClass,
        content: &str,
    ) -> MemoryRecord {
        adapter
            .store("proj", InsertMemory::new(scope.clone(), class, content))
            .await
            .unwrap()
    }

    /// Builds a metadata-free link request between two stored records.
    fn link_request(
        from: &MemoryRecord,
        to: &MemoryRecord,
        relationship: RelationshipKind,
    ) -> LinkMemories {
        LinkMemories {
            from_memory_id: from.memory_id.clone(),
            to_memory_id: to.memory_id.clone(),
            relationship,
            metadata: HashMap::new(),
        }
    }

    /// Asserts that storing `content` fails with `RedactionRequired` and that
    /// nothing was written.
    async fn assert_redaction_required(content: &str) {
        let adapter = InMemoryAdapter::new();
        let req = InsertMemory::new(test_scope(), MemoryClass::Working, content);
        let err = adapter.store("proj", req).await.unwrap_err();
        assert!(matches!(err, MemoryError::RedactionRequired));
        assert_eq!(adapter.memory_count(), 0);
    }

    #[tokio::test]
    async fn store_and_get() {
        let adapter = InMemoryAdapter::new();
        let record = put(&adapter, &test_scope(), MemoryClass::Working, "test content").await;
        assert_eq!(record.content, "test content");
        assert_eq!(record.memory_class, MemoryClass::Working);

        let fetched = adapter.get("proj", &record.memory_id).await.unwrap();
        assert_eq!(fetched.memory_id, record.memory_id);
    }

    #[tokio::test]
    async fn store_rejects_secrets() {
        assert_redaction_required("my password=hunter2").await;
    }

    #[tokio::test]
    async fn store_rejects_private_key() {
        assert_redaction_required("-----BEGIN RSA PRIVATE KEY-----\nMIIE...").await;
    }

    #[tokio::test]
    async fn store_rejects_api_key() {
        assert_redaction_required("config: api_key=sk-proj-abc123").await;
    }

    #[tokio::test]
    async fn query_returns_matching() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        put(&adapter, &scope, MemoryClass::Semantic, "cargo build failed").await;
        put(&adapter, &scope, MemoryClass::Semantic, "tests passed ok").await;

        let results = adapter
            .query("proj", MemoryQuery::new(scope, "cargo"))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("cargo"));
    }

    #[tokio::test]
    async fn query_is_case_insensitive() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        put(&adapter, &scope, MemoryClass::Semantic, "Cargo BUILD failed").await;

        let results = adapter
            .query("proj", MemoryQuery::new(scope, "cargo Build"))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].relevance_score, 1.0);
    }

    #[tokio::test]
    async fn query_empty_text_matches_whole_scope() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        let other = ScopeId::for_agent("other");
        put(&adapter, &scope, MemoryClass::Working, "first").await;
        put(&adapter, &scope, MemoryClass::Working, "second").await;
        put(&adapter, &other, MemoryClass::Working, "elsewhere").await;

        let results = adapter
            .query("proj", MemoryQuery::new(scope, ""))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.relevance_score == 0.0));
    }

    #[tokio::test]
    async fn query_filters_by_class() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        put(&adapter, &scope, MemoryClass::Working, "working note").await;
        put(&adapter, &scope, MemoryClass::Semantic, "semantic lesson").await;

        let results = adapter
            .query(
                "proj",
                MemoryQuery::new(scope, "note lesson").with_class(MemoryClass::Working),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("working"));
    }

    #[tokio::test]
    async fn query_respects_limit() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        for i in 0..10 {
            put(&adapter, &scope, MemoryClass::Working, &format!("item {i}")).await;
        }

        let results = adapter
            .query("proj", MemoryQuery::new(scope, "item").with_limit(3))
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn query_empty_project() {
        let adapter = InMemoryAdapter::new();
        let results = adapter
            .query("nonexistent", MemoryQuery::new(test_scope(), "anything"))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn delete_memory() {
        let adapter = InMemoryAdapter::new();
        let record = put(&adapter, &test_scope(), MemoryClass::Working, "to delete").await;

        adapter.delete("proj", &record.memory_id).await.unwrap();
        let err = adapter.get("proj", &record.memory_id).await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn delete_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter.delete("proj", "nonexistent").await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn link_and_unlink() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        let m1 = put(&adapter, &scope, MemoryClass::Episodic, "incident occurred").await;
        let m2 = put(&adapter, &scope, MemoryClass::Semantic, "fix applied").await;

        adapter
            .link("proj", link_request(&m1, &m2, RelationshipKind::CauseEffect))
            .await
            .unwrap();
        assert_eq!(adapter.link_count("proj"), 1);

        adapter
            .unlink("proj", &m1.memory_id, &m2.memory_id)
            .await
            .unwrap();
        assert_eq!(adapter.link_count("proj"), 0);
    }

    #[tokio::test]
    async fn link_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter
            .link(
                "proj",
                LinkMemories {
                    from_memory_id: "no-such".to_string(),
                    to_memory_id: "no-such-2".to_string(),
                    relationship: RelationshipKind::RelatesTo,
                    metadata: HashMap::new(),
                },
            )
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn link_rejects_missing_target() {
        let adapter = InMemoryAdapter::new();
        let m1 = put(&adapter, &test_scope(), MemoryClass::Working, "exists").await;

        let err = adapter
            .link(
                "proj",
                LinkMemories {
                    from_memory_id: m1.memory_id.clone(),
                    to_memory_id: "missing".to_string(),
                    relationship: RelationshipKind::RelatesTo,
                    metadata: HashMap::new(),
                },
            )
            .await
            .unwrap_err();
        assert!(matches!(&err, MemoryError::NotFound(id) if id == "missing"));
        assert_eq!(adapter.link_count("proj"), 0);
    }

    #[tokio::test]
    async fn unlink_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter.unlink("proj", "a", "b").await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn scope_lifecycle() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        adapter.create_scope("proj", &scope, None).await.unwrap();
        put(&adapter, &scope, MemoryClass::Working, "before close").await;

        adapter.close_scope("proj", &scope).await.unwrap();

        let err = adapter
            .store(
                "proj",
                InsertMemory::new(scope, MemoryClass::Working, "after close"),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::ScopeClosed(_)));
        // The rejected write must not have been stored.
        assert_eq!(adapter.memory_count(), 1);
    }

    #[tokio::test]
    async fn close_nonexistent_scope() {
        let adapter = InMemoryAdapter::new();
        let scope = ScopeId::for_agent("ghost");
        let err = adapter.close_scope("proj", &scope).await.unwrap_err();
        assert!(matches!(err, MemoryError::ScopeNotFound { .. }));
    }

    #[tokio::test]
    async fn health_check_ok() {
        let adapter = InMemoryAdapter::new();
        assert!(adapter.health_check().await.unwrap());
    }

    #[test]
    fn scope_id_constructors() {
        let run_scope = ScopeId::for_run(uuid::Uuid::nil());
        assert!(run_scope.as_str().starts_with("run/"));

        let task_scope = ScopeId::for_task(uuid::Uuid::nil());
        assert!(task_scope.as_str().starts_with("task/"));

        let agent_scope = ScopeId::for_agent("my-agent");
        assert_eq!(agent_scope.as_str(), "agent/my-agent");
    }

    #[test]
    fn memory_class_roundtrip() {
        let json = serde_json::to_string(&MemoryClass::Semantic).unwrap();
        assert_eq!(json, "\"semantic\"");
        let decoded: MemoryClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, MemoryClass::Semantic);
    }

    #[tokio::test]
    async fn insert_with_metadata() {
        let adapter = InMemoryAdapter::new();
        let req = InsertMemory::new(test_scope(), MemoryClass::Episodic, "incident timeline")
            .with_metadata("run_id", "abc123")
            .with_metadata("severity", "high");
        let record = adapter.store("proj", req).await.unwrap();
        assert_eq!(record.metadata.get("run_id").unwrap(), "abc123");
        assert_eq!(record.metadata.get("severity").unwrap(), "high");
    }

    #[tokio::test]
    async fn delete_removes_associated_links() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        let m1 = put(&adapter, &scope, MemoryClass::Working, "first").await;
        let m2 = put(&adapter, &scope, MemoryClass::Working, "second").await;
        let m3 = put(&adapter, &scope, MemoryClass::Working, "third").await;

        adapter
            .link("proj", link_request(&m1, &m2, RelationshipKind::DependsOn))
            .await
            .unwrap();
        adapter
            .link("proj", link_request(&m2, &m3, RelationshipKind::DependsOn))
            .await
            .unwrap();
        assert_eq!(adapter.link_count("proj"), 2);

        // Deleting a source endpoint drops only the links touching it.
        adapter.delete("proj", &m1.memory_id).await.unwrap();
        assert_eq!(adapter.link_count("proj"), 1);

        // Deleting a target endpoint drops its incoming link too.
        adapter.delete("proj", &m3.memory_id).await.unwrap();
        assert_eq!(adapter.link_count("proj"), 0);
    }

    #[tokio::test]
    async fn query_relevance_scoring() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();
        put(&adapter, &scope, MemoryClass::Semantic, "rust cargo build test").await;
        put(&adapter, &scope, MemoryClass::Semantic, "just cargo").await;

        let results = adapter
            .query("proj", MemoryQuery::new(scope, "cargo build"))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        // Both terms match the first record, one term matches the second.
        assert_eq!(results[0].content, "rust cargo build test");
        assert_eq!(results[0].relevance_score, 1.0);
        assert_eq!(results[1].content, "just cargo");
        assert_eq!(results[1].relevance_score, 0.5);
    }

    #[tokio::test]
    async fn memory_count() {
        let adapter = InMemoryAdapter::new();
        assert_eq!(adapter.memory_count(), 0);

        put(&adapter, &test_scope(), MemoryClass::Working, "one").await;
        assert_eq!(adapter.memory_count(), 1);
    }

    #[tokio::test]
    async fn scope_with_parent() {
        let adapter = InMemoryAdapter::new();
        let parent = ScopeId::for_run(uuid::Uuid::nil());
        let child = ScopeId::for_task(uuid::Uuid::nil());

        adapter.create_scope("proj", &parent, None).await.unwrap();
        adapter
            .create_scope("proj", &child, Some(&parent))
            .await
            .unwrap();

        put(&adapter, &parent, MemoryClass::Working, "parent note").await;
        put(&adapter, &child, MemoryClass::Working, "child note").await;

        assert_eq!(adapter.memory_count(), 2);
    }

    #[test]
    fn content_secret_detection() {
        assert!(content_may_contain_secrets("password=hunter2"));
        assert!(content_may_contain_secrets(
            "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        ));
        assert!(content_may_contain_secrets(
            "-----BEGIN RSA PRIVATE KEY-----"
        ));
        assert!(content_may_contain_secrets("token=abc123"));
        assert!(content_may_contain_secrets(
            "export AWS_SECRET_ACCESS_KEY=xxx"
        ));
        assert!(content_may_contain_secrets(
            "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
        ));
        assert!(content_may_contain_secrets("sk-proj-abc123"));

        assert!(!content_may_contain_secrets("normal text about code"));
        assert!(!content_may_contain_secrets("the password field was empty"));
    }

    #[test]
    fn relationship_kind_roundtrip() {
        let json = serde_json::to_string(&RelationshipKind::CauseEffect).unwrap();
        assert_eq!(json, "\"cause_effect\"");
        let decoded: RelationshipKind = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, RelationshipKind::CauseEffect);
    }

    #[tokio::test]
    async fn query_different_scopes_isolated() {
        let adapter = InMemoryAdapter::new();
        let scope1 = ScopeId::for_run(uuid::Uuid::from_u128(1));
        let scope2 = ScopeId::for_run(uuid::Uuid::from_u128(2));
        put(&adapter, &scope1, MemoryClass::Working, "scope1 data").await;
        put(&adapter, &scope2, MemoryClass::Working, "scope2 data").await;

        let r1 = adapter
            .query("proj", MemoryQuery::new(scope1, "data"))
            .await
            .unwrap();
        assert_eq!(r1.len(), 1);
        assert!(r1[0].content.contains("scope1"));

        let r2 = adapter
            .query("proj", MemoryQuery::new(scope2, "data"))
            .await
            .unwrap();
        assert_eq!(r2.len(), 1);
        assert!(r2[0].content.contains("scope2"));
    }
}