Skip to main content

yarli_cli/yarli-memory/src/
in_memory.rs

1//! In-memory adapter for testing.
2//!
3//! No gRPC dependency; stores records in a thread-safe HashMap.
4
5use std::collections::HashMap;
6use std::sync::RwLock;
7
8use chrono::Utc;
9use uuid::Uuid;
10
11use crate::yarli_memory::adapter::MemoryAdapter;
12use crate::yarli_memory::error::MemoryError;
13use crate::yarli_memory::types::{
14    content_may_contain_secrets, InsertMemory, LinkMemories, MemoryQuery, MemoryRecord,
15    RelationshipKind, ScopeId,
16};
17
18/// Scope state tracking.
19#[derive(Debug, Clone)]
20struct ScopeState {
21    #[allow(dead_code)]
22    parent: Option<ScopeId>,
23    closed: bool,
24}
25
26/// A link between two memories.
27#[derive(Debug, Clone)]
28struct MemoryLink {
29    from_memory_id: String,
30    to_memory_id: String,
31    #[allow(dead_code)]
32    relationship: RelationshipKind,
33    #[allow(dead_code)]
34    metadata: HashMap<String, String>,
35}
36
37/// In-memory adapter state.
38struct InMemoryState {
39    /// project -> scope_id -> ScopeState
40    scopes: HashMap<String, HashMap<String, ScopeState>>,
41    /// project -> memory_id -> MemoryRecord
42    memories: HashMap<String, HashMap<String, MemoryRecord>>,
43    /// project -> links
44    links: HashMap<String, Vec<MemoryLink>>,
45}
46
47/// In-memory implementation of [`MemoryAdapter`].
48///
49/// Suitable for testing and development. Not persistent.
50pub struct InMemoryAdapter {
51    state: RwLock<InMemoryState>,
52}
53
54impl InMemoryAdapter {
55    /// Create a new empty adapter.
56    pub fn new() -> Self {
57        Self {
58            state: RwLock::new(InMemoryState {
59                scopes: HashMap::new(),
60                memories: HashMap::new(),
61                links: HashMap::new(),
62            }),
63        }
64    }
65
66    /// Count all memories across all projects.
67    pub fn memory_count(&self) -> usize {
68        let state = self.state.read().unwrap();
69        state.memories.values().map(|m| m.len()).sum()
70    }
71
72    /// Count links for a project.
73    pub fn link_count(&self, project: &str) -> usize {
74        let state = self.state.read().unwrap();
75        state.links.get(project).map(|l| l.len()).unwrap_or(0)
76    }
77}
78
79impl Default for InMemoryAdapter {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl MemoryAdapter for InMemoryAdapter {
86    async fn store(
87        &self,
88        project: &str,
89        request: InsertMemory,
90    ) -> Result<MemoryRecord, MemoryError> {
91        // Section 14.4: redaction check
92        if content_may_contain_secrets(&request.content) {
93            return Err(MemoryError::RedactionRequired);
94        }
95
96        let mut state = self.state.write().unwrap();
97
98        // Check scope exists and is not closed
99        if let Some(scopes) = state.scopes.get(project) {
100            if let Some(scope) = scopes.get(request.scope_id.as_str()) {
101                if scope.closed {
102                    return Err(MemoryError::ScopeClosed(
103                        request.scope_id.as_str().to_string(),
104                    ));
105                }
106            }
107        }
108
109        let now = Utc::now();
110        let memory_id = Uuid::now_v7().to_string();
111
112        let record = MemoryRecord {
113            memory_id: memory_id.clone(),
114            scope_id: request.scope_id,
115            memory_class: request.memory_class,
116            content: request.content,
117            metadata: request.metadata,
118            relevance_score: 0.0,
119            retrieval_count: 0,
120            created_at: now,
121            updated_at: now,
122        };
123
124        state
125            .memories
126            .entry(project.to_string())
127            .or_default()
128            .insert(memory_id, record.clone());
129
130        Ok(record)
131    }
132
133    async fn query(
134        &self,
135        project: &str,
136        query: MemoryQuery,
137    ) -> Result<Vec<MemoryRecord>, MemoryError> {
138        let state = self.state.read().unwrap();
139
140        let memories = match state.memories.get(project) {
141            Some(m) => m,
142            None => return Ok(vec![]),
143        };
144
145        let query_lower = query.query_text.to_lowercase();
146        let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
147
148        let mut results: Vec<MemoryRecord> = memories
149            .values()
150            .filter(|m| {
151                // Filter by scope
152                if m.scope_id != query.scope_id {
153                    return false;
154                }
155                // Filter by class
156                if let Some(class) = query.memory_class {
157                    if m.memory_class != class {
158                        return false;
159                    }
160                }
161                // Simple text matching (BM25 approximation for in-memory)
162                if query_terms.is_empty() {
163                    return true;
164                }
165                let content_lower = m.content.to_lowercase();
166                query_terms.iter().any(|term| content_lower.contains(term))
167            })
168            .cloned()
169            .enumerate()
170            .map(|(i, mut m)| {
171                // Simple relevance scoring: count matching terms
172                let content_lower = m.content.to_lowercase();
173                let match_count = query_terms
174                    .iter()
175                    .filter(|t| content_lower.contains(**t))
176                    .count();
177                m.relevance_score = match_count as f64 / query_terms.len().max(1) as f64;
178                m.retrieval_count += 1;
179                let _ = i; // suppress unused
180                m
181            })
182            .collect();
183
184        // Sort by relevance descending
185        results.sort_by(|a, b| {
186            b.relevance_score
187                .partial_cmp(&a.relevance_score)
188                .unwrap_or(std::cmp::Ordering::Equal)
189        });
190
191        // Apply limit
192        results.truncate(query.limit as usize);
193
194        Ok(results)
195    }
196
197    async fn get(&self, project: &str, memory_id: &str) -> Result<MemoryRecord, MemoryError> {
198        let state = self.state.read().unwrap();
199
200        state
201            .memories
202            .get(project)
203            .and_then(|m| m.get(memory_id))
204            .cloned()
205            .ok_or_else(|| MemoryError::NotFound(memory_id.to_string()))
206    }
207
208    async fn delete(&self, project: &str, memory_id: &str) -> Result<(), MemoryError> {
209        let mut state = self.state.write().unwrap();
210
211        let removed = state
212            .memories
213            .get_mut(project)
214            .and_then(|m| m.remove(memory_id));
215
216        if removed.is_some() {
217            // Also remove any links involving this memory
218            if let Some(links) = state.links.get_mut(project) {
219                links.retain(|l| l.from_memory_id != memory_id && l.to_memory_id != memory_id);
220            }
221            Ok(())
222        } else {
223            Err(MemoryError::NotFound(memory_id.to_string()))
224        }
225    }
226
227    async fn link(&self, project: &str, link: LinkMemories) -> Result<(), MemoryError> {
228        let mut state = self.state.write().unwrap();
229
230        // Verify both memories exist
231        let memories = state
232            .memories
233            .get(project)
234            .ok_or_else(|| MemoryError::NotFound(link.from_memory_id.clone()))?;
235
236        if !memories.contains_key(&link.from_memory_id) {
237            return Err(MemoryError::NotFound(link.from_memory_id));
238        }
239        if !memories.contains_key(&link.to_memory_id) {
240            return Err(MemoryError::NotFound(link.to_memory_id));
241        }
242
243        state
244            .links
245            .entry(project.to_string())
246            .or_default()
247            .push(MemoryLink {
248                from_memory_id: link.from_memory_id,
249                to_memory_id: link.to_memory_id,
250                relationship: link.relationship,
251                metadata: link.metadata,
252            });
253
254        Ok(())
255    }
256
257    async fn unlink(
258        &self,
259        project: &str,
260        from_memory_id: &str,
261        to_memory_id: &str,
262    ) -> Result<(), MemoryError> {
263        let mut state = self.state.write().unwrap();
264
265        if let Some(links) = state.links.get_mut(project) {
266            let before = links.len();
267            links.retain(|l| {
268                !(l.from_memory_id == from_memory_id && l.to_memory_id == to_memory_id)
269            });
270            if links.len() == before {
271                return Err(MemoryError::NotFound(format!(
272                    "link {from_memory_id} -> {to_memory_id}"
273                )));
274            }
275        } else {
276            return Err(MemoryError::NotFound(format!(
277                "link {from_memory_id} -> {to_memory_id}"
278            )));
279        }
280
281        Ok(())
282    }
283
284    async fn create_scope(
285        &self,
286        project: &str,
287        scope_id: &ScopeId,
288        parent: Option<&ScopeId>,
289    ) -> Result<(), MemoryError> {
290        let mut state = self.state.write().unwrap();
291
292        state.scopes.entry(project.to_string()).or_default().insert(
293            scope_id.as_str().to_string(),
294            ScopeState {
295                parent: parent.cloned(),
296                closed: false,
297            },
298        );
299
300        Ok(())
301    }
302
303    async fn close_scope(&self, project: &str, scope_id: &ScopeId) -> Result<(), MemoryError> {
304        let mut state = self.state.write().unwrap();
305
306        let scope = state
307            .scopes
308            .get_mut(project)
309            .and_then(|s| s.get_mut(scope_id.as_str()))
310            .ok_or_else(|| MemoryError::ScopeNotFound {
311                project: project.to_string(),
312                scope: scope_id.as_str().to_string(),
313            })?;
314
315        scope.closed = true;
316        Ok(())
317    }
318
319    async fn health_check(&self) -> Result<bool, MemoryError> {
320        Ok(true)
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::yarli_memory::types::MemoryClass;

    /// Project name shared by the tests.
    const PROJECT: &str = "proj";

    /// Scope used by tests that do not care which scope they write to.
    fn test_scope() -> ScopeId {
        ScopeId::for_run(uuid::Uuid::nil())
    }

    /// Store `content` in [`PROJECT`], panicking if the adapter rejects it.
    async fn remember(
        adapter: &InMemoryAdapter,
        scope: &ScopeId,
        class: MemoryClass,
        content: &str,
    ) -> MemoryRecord {
        adapter
            .store(PROJECT, InsertMemory::new(scope.clone(), class, content))
            .await
            .unwrap()
    }

    /// Assert that storing `content` fails with `RedactionRequired`.
    async fn assert_redaction_required(content: &str) {
        let adapter = InMemoryAdapter::new();
        let req = InsertMemory::new(test_scope(), MemoryClass::Working, content);
        let err = adapter.store(PROJECT, req).await.unwrap_err();
        assert!(matches!(err, MemoryError::RedactionRequired));
    }

    /// Build a link request from `from` to `to` with empty metadata.
    fn link_between(
        from: &MemoryRecord,
        to: &MemoryRecord,
        relationship: RelationshipKind,
    ) -> LinkMemories {
        LinkMemories {
            from_memory_id: from.memory_id.clone(),
            to_memory_id: to.memory_id.clone(),
            relationship,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn store_and_get() {
        let adapter = InMemoryAdapter::new();
        let req = InsertMemory::new(test_scope(), MemoryClass::Working, "test content");
        let record = adapter.store(PROJECT, req).await.unwrap();
        assert_eq!(record.content, "test content");
        assert_eq!(record.memory_class, MemoryClass::Working);

        let fetched = adapter.get(PROJECT, &record.memory_id).await.unwrap();
        assert_eq!(fetched.memory_id, record.memory_id);
    }

    #[tokio::test]
    async fn get_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter.get(PROJECT, "nonexistent").await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn store_rejects_secrets() {
        assert_redaction_required("my password=hunter2").await;
    }

    #[tokio::test]
    async fn store_rejects_private_key() {
        assert_redaction_required("-----BEGIN RSA PRIVATE KEY-----\nMIIE...").await;
    }

    #[tokio::test]
    async fn store_rejects_api_key() {
        assert_redaction_required("config: api_key=sk-proj-abc123").await;
    }

    #[tokio::test]
    async fn query_returns_matching() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        remember(&adapter, &scope, MemoryClass::Semantic, "cargo build failed").await;
        remember(&adapter, &scope, MemoryClass::Semantic, "tests passed ok").await;

        let results = adapter
            .query(PROJECT, MemoryQuery::new(scope, "cargo"))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("cargo"));
    }

    #[tokio::test]
    async fn query_filters_by_class() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        remember(&adapter, &scope, MemoryClass::Working, "working note").await;
        remember(&adapter, &scope, MemoryClass::Semantic, "semantic lesson").await;

        let results = adapter
            .query(
                PROJECT,
                MemoryQuery::new(scope, "note lesson").with_class(MemoryClass::Working),
            )
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("working"));
    }

    #[tokio::test]
    async fn query_respects_limit() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        for i in 0..10 {
            remember(&adapter, &scope, MemoryClass::Working, &format!("item {i}")).await;
        }

        let results = adapter
            .query(PROJECT, MemoryQuery::new(scope, "item").with_limit(3))
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn query_empty_project() {
        let adapter = InMemoryAdapter::new();
        let results = adapter
            .query("nonexistent", MemoryQuery::new(test_scope(), "anything"))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn delete_memory() {
        let adapter = InMemoryAdapter::new();
        let record = remember(&adapter, &test_scope(), MemoryClass::Working, "to delete").await;

        adapter.delete(PROJECT, &record.memory_id).await.unwrap();
        let err = adapter.get(PROJECT, &record.memory_id).await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn delete_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter.delete(PROJECT, "nonexistent").await.unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn link_and_unlink() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        let m1 = remember(&adapter, &scope, MemoryClass::Episodic, "incident occurred").await;
        let m2 = remember(&adapter, &scope, MemoryClass::Semantic, "fix applied").await;

        adapter
            .link(
                PROJECT,
                link_between(&m1, &m2, RelationshipKind::CauseEffect),
            )
            .await
            .unwrap();
        assert_eq!(adapter.link_count(PROJECT), 1);

        adapter
            .unlink(PROJECT, &m1.memory_id, &m2.memory_id)
            .await
            .unwrap();
        assert_eq!(adapter.link_count(PROJECT), 0);
    }

    #[tokio::test]
    async fn link_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter
            .link(
                PROJECT,
                LinkMemories {
                    from_memory_id: "no-such".to_string(),
                    to_memory_id: "no-such-2".to_string(),
                    relationship: RelationshipKind::RelatesTo,
                    metadata: HashMap::new(),
                },
            )
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn unlink_not_found() {
        let adapter = InMemoryAdapter::new();
        let err = adapter
            .unlink(PROJECT, "no-such", "no-such-2")
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::NotFound(_)));
    }

    #[tokio::test]
    async fn scope_lifecycle() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        adapter.create_scope(PROJECT, &scope, None).await.unwrap();

        // Can store into open scope
        remember(&adapter, &scope, MemoryClass::Working, "before close").await;

        // Close the scope
        adapter.close_scope(PROJECT, &scope).await.unwrap();

        // Writes to closed scope are rejected
        let err = adapter
            .store(
                PROJECT,
                InsertMemory::new(scope, MemoryClass::Working, "after close"),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::ScopeClosed(_)));
    }

    #[tokio::test]
    async fn close_nonexistent_scope() {
        let adapter = InMemoryAdapter::new();
        let scope = ScopeId::for_agent("ghost");
        let err = adapter.close_scope(PROJECT, &scope).await.unwrap_err();
        assert!(matches!(err, MemoryError::ScopeNotFound { .. }));
    }

    #[tokio::test]
    async fn health_check_ok() {
        let adapter = InMemoryAdapter::new();
        assert!(adapter.health_check().await.unwrap());
    }

    // The next test needs no async runtime, so it is a plain `#[test]`.
    #[test]
    fn scope_id_constructors() {
        let run_scope = ScopeId::for_run(uuid::Uuid::nil());
        assert!(run_scope.as_str().starts_with("run/"));

        let task_scope = ScopeId::for_task(uuid::Uuid::nil());
        assert!(task_scope.as_str().starts_with("task/"));

        let agent_scope = ScopeId::for_agent("my-agent");
        assert_eq!(agent_scope.as_str(), "agent/my-agent");
    }

    #[test]
    fn memory_class_roundtrip() {
        let json = serde_json::to_string(&MemoryClass::Semantic).unwrap();
        assert_eq!(json, "\"semantic\"");
        let decoded: MemoryClass = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, MemoryClass::Semantic);
    }

    #[tokio::test]
    async fn insert_with_metadata() {
        let adapter = InMemoryAdapter::new();
        let req = InsertMemory::new(test_scope(), MemoryClass::Episodic, "incident timeline")
            .with_metadata("run_id", "abc123")
            .with_metadata("severity", "high");
        let record = adapter.store(PROJECT, req).await.unwrap();
        assert_eq!(record.metadata.get("run_id").unwrap(), "abc123");
        assert_eq!(record.metadata.get("severity").unwrap(), "high");
    }

    #[tokio::test]
    async fn delete_removes_associated_links() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        let m1 = remember(&adapter, &scope, MemoryClass::Working, "first").await;
        let m2 = remember(&adapter, &scope, MemoryClass::Working, "second").await;

        adapter
            .link(PROJECT, link_between(&m1, &m2, RelationshipKind::DependsOn))
            .await
            .unwrap();
        assert_eq!(adapter.link_count(PROJECT), 1);

        // Deleting m1 should also remove the link
        adapter.delete(PROJECT, &m1.memory_id).await.unwrap();
        assert_eq!(adapter.link_count(PROJECT), 0);
    }

    #[tokio::test]
    async fn query_relevance_scoring() {
        let adapter = InMemoryAdapter::new();
        let scope = test_scope();

        remember(&adapter, &scope, MemoryClass::Semantic, "rust cargo build test").await;
        remember(&adapter, &scope, MemoryClass::Semantic, "just cargo").await;

        let results = adapter
            .query(PROJECT, MemoryQuery::new(scope, "cargo build"))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        // The record matching both terms must rank strictly above the one
        // matching a single term.
        assert_eq!(results[0].content, "rust cargo build test");
        assert_eq!(results[1].content, "just cargo");
        assert!(results[0].relevance_score > results[1].relevance_score);
    }

    #[tokio::test]
    async fn memory_count() {
        let adapter = InMemoryAdapter::new();
        assert_eq!(adapter.memory_count(), 0);

        remember(&adapter, &test_scope(), MemoryClass::Working, "one").await;
        assert_eq!(adapter.memory_count(), 1);
    }

    #[tokio::test]
    async fn scope_with_parent() {
        let adapter = InMemoryAdapter::new();
        let parent = ScopeId::for_run(uuid::Uuid::nil());
        let child = ScopeId::for_task(uuid::Uuid::nil());

        adapter.create_scope(PROJECT, &parent, None).await.unwrap();
        adapter
            .create_scope(PROJECT, &child, Some(&parent))
            .await
            .unwrap();

        // Both scopes accept writes
        remember(&adapter, &parent, MemoryClass::Working, "parent note").await;
        remember(&adapter, &child, MemoryClass::Working, "child note").await;

        assert_eq!(adapter.memory_count(), 2);
    }

    #[test]
    fn content_secret_detection() {
        let flagged = [
            "password=hunter2",
            "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            "-----BEGIN RSA PRIVATE KEY-----",
            "token=abc123",
            "export AWS_SECRET_ACCESS_KEY=xxx",
            "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
            "sk-proj-abc123",
        ];
        for content in flagged {
            assert!(
                content_may_contain_secrets(content),
                "expected {content:?} to be flagged"
            );
        }

        let clean = ["normal text about code", "the password field was empty"];
        for content in clean {
            assert!(
                !content_may_contain_secrets(content),
                "expected {content:?} to pass"
            );
        }
    }

    #[test]
    fn relationship_kind_roundtrip() {
        let json = serde_json::to_string(&RelationshipKind::CauseEffect).unwrap();
        assert_eq!(json, "\"cause_effect\"");
        let decoded: RelationshipKind = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, RelationshipKind::CauseEffect);
    }

    #[tokio::test]
    async fn query_different_scopes_isolated() {
        let adapter = InMemoryAdapter::new();
        let scope1 = ScopeId::for_run(uuid::Uuid::from_u128(1));
        let scope2 = ScopeId::for_run(uuid::Uuid::from_u128(2));

        remember(&adapter, &scope1, MemoryClass::Working, "scope1 data").await;
        remember(&adapter, &scope2, MemoryClass::Working, "scope2 data").await;

        let r1 = adapter
            .query(PROJECT, MemoryQuery::new(scope1, "data"))
            .await
            .unwrap();
        assert_eq!(r1.len(), 1);
        assert!(r1[0].content.contains("scope1"));

        let r2 = adapter
            .query(PROJECT, MemoryQuery::new(scope2, "data"))
            .await
            .unwrap();
        assert_eq!(r2.len(), 1);
        assert!(r2[0].content.contains("scope2"));
    }
}