Skip to main content

walrus_memory/
memory.rs

//! [`Memory`] trait implementation for [`SqliteMemory`], backed by its SQLite connection.
2
3use crate::SqliteMemory;
4use crate::sql;
5use crate::utils::now_unix;
6use crate::{Embedder, Memory, MemoryEntry, RecallOptions};
7use anyhow::Result;
8use std::future::Future;
9
10impl<E: Embedder> Memory for SqliteMemory<E> {
11    fn get(&self, key: &str) -> Option<String> {
12        let conn = self.conn.lock().unwrap();
13        let now = now_unix();
14        conn.execute(sql::TOUCH_ACCESS, rusqlite::params![now as i64, key])
15            .ok();
16        conn.query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
17            .ok()
18    }
19
20    fn entries(&self) -> Vec<(String, String)> {
21        let conn = self.conn.lock().unwrap();
22        let mut stmt = conn.prepare(sql::SELECT_ENTRIES).unwrap();
23        stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
24            .unwrap()
25            .filter_map(|r| r.ok())
26            .collect()
27    }
28
29    fn set(&self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
30        let key = key.into();
31        let value = value.into();
32        let conn = self.conn.lock().unwrap();
33        let now = now_unix() as i64;
34
35        let old: Option<String> = conn
36            .query_row(sql::SELECT_VALUE, [&key], |row| row.get(0))
37            .ok();
38
39        conn.execute(sql::UPSERT, rusqlite::params![key, value, now])
40            .ok();
41
42        old
43    }
44
45    fn remove(&self, key: &str) -> Option<String> {
46        let conn = self.conn.lock().unwrap();
47        let old: Option<String> = conn
48            .query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
49            .ok();
50        if old.is_some() {
51            conn.execute(sql::DELETE, [key]).ok();
52        }
53        old
54    }
55
56    fn store(
57        &self,
58        key: impl Into<String> + Send,
59        value: impl Into<String> + Send,
60    ) -> impl Future<Output = Result<()>> + Send {
61        let key = key.into();
62        let value = value.into();
63
64        async move {
65            // Auto-embed when embedder is present.
66            let embedding = if let Some(embedder) = &self.embedder {
67                let emb = embedder.embed(&value).await;
68                if emb.is_empty() { None } else { Some(emb) }
69            } else {
70                None
71            };
72
73            self.store_with_metadata(&key, &value, None, embedding.as_deref())?;
74            Ok(())
75        }
76    }
77
78    fn recall(
79        &self,
80        query: &str,
81        options: RecallOptions,
82    ) -> impl Future<Output = Result<Vec<MemoryEntry>>> + Send {
83        let query = query.to_owned();
84
85        async move {
86            // Embed query when embedder is present.
87            let query_embedding = if let Some(embedder) = &self.embedder {
88                let emb = embedder.embed(&query).await;
89                if emb.is_empty() { None } else { Some(emb) }
90            } else {
91                None
92            };
93
94            self.recall_sync(&query, &options, query_embedding.as_deref())
95        }
96    }
97
98    fn compile_relevant(&self, query: &str) -> impl Future<Output = String> + Send {
99        let query = query.to_owned();
100
101        async move {
102            let opts = RecallOptions {
103                limit: 5,
104                ..Default::default()
105            };
106
107            // Embed query when embedder is present.
108            let query_embedding = if let Some(embedder) = &self.embedder {
109                let emb = embedder.embed(&query).await;
110                if emb.is_empty() { None } else { Some(emb) }
111            } else {
112                None
113            };
114
115            let entries = self
116                .recall_sync(&query, &opts, query_embedding.as_deref())
117                .unwrap_or_default();
118
119            if entries.is_empty() {
120                return String::new();
121            }
122
123            let mut out = String::from("<memory>\n");
124            for entry in &entries {
125                out.push_str(&format!("<{}>\n", entry.key));
126                out.push_str(&entry.value);
127                if !entry.value.ends_with('\n') {
128                    out.push('\n');
129                }
130                out.push_str(&format!("</{}>\n", entry.key));
131            }
132            out.push_str("</memory>");
133            out
134        }
135    }
136}