Skip to main content

walrus_memory/sqlite/
memory.rs

//! [`Memory`] trait implementation for [`SqliteMemory`].
2
3use crate::{
4    Embedder, Memory, MemoryEntry, RecallOptions,
5    sqlite::{SqliteMemory, sql},
6    utils::now_unix,
7};
8use anyhow::Result;
9use std::future::Future;
10
11impl<E: Embedder> Memory for SqliteMemory<E> {
12    fn get(&self, key: &str) -> Option<String> {
13        let conn = self.conn.lock().unwrap();
14        let now = now_unix();
15        conn.execute(sql::TOUCH_ACCESS, rusqlite::params![now as i64, key])
16            .ok();
17        conn.query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
18            .ok()
19    }
20
21    fn entries(&self) -> Vec<(String, String)> {
22        let conn = self.conn.lock().unwrap();
23        let mut stmt = conn.prepare(sql::SELECT_ENTRIES).unwrap();
24        stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
25            .unwrap()
26            .filter_map(|r| r.ok())
27            .collect()
28    }
29
30    fn set(&self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
31        let key = key.into();
32        let value = value.into();
33        let conn = self.conn.lock().unwrap();
34        let now = now_unix() as i64;
35
36        let old: Option<String> = conn
37            .query_row(sql::SELECT_VALUE, [&key], |row| row.get(0))
38            .ok();
39
40        conn.execute(sql::UPSERT, rusqlite::params![key, value, now])
41            .ok();
42
43        old
44    }
45
46    fn remove(&self, key: &str) -> Option<String> {
47        let conn = self.conn.lock().unwrap();
48        let old: Option<String> = conn
49            .query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
50            .ok();
51        if old.is_some() {
52            conn.execute(sql::DELETE, [key]).ok();
53        }
54        old
55    }
56
57    fn store(
58        &self,
59        key: impl Into<String> + Send,
60        value: impl Into<String> + Send,
61    ) -> impl Future<Output = Result<()>> + Send {
62        let key = key.into();
63        let value = value.into();
64
65        async move {
66            // Auto-embed when embedder is present.
67            let embedding = if let Some(embedder) = &self.embedder {
68                let emb = embedder.embed(&value).await;
69                if emb.is_empty() { None } else { Some(emb) }
70            } else {
71                None
72            };
73
74            self.store_with_metadata(&key, &value, None, embedding.as_deref())?;
75            Ok(())
76        }
77    }
78
79    fn recall(
80        &self,
81        query: &str,
82        options: RecallOptions,
83    ) -> impl Future<Output = Result<Vec<MemoryEntry>>> + Send {
84        let query = query.to_owned();
85
86        async move {
87            // Embed query when embedder is present.
88            let query_embedding = if let Some(embedder) = &self.embedder {
89                let emb = embedder.embed(&query).await;
90                if emb.is_empty() { None } else { Some(emb) }
91            } else {
92                None
93            };
94
95            self.recall_sync(&query, &options, query_embedding.as_deref())
96        }
97    }
98
99    fn compile_relevant(&self, query: &str) -> impl Future<Output = String> + Send {
100        let query = query.to_owned();
101
102        async move {
103            let opts = RecallOptions {
104                limit: 5,
105                ..Default::default()
106            };
107
108            // Embed query when embedder is present.
109            let query_embedding = if let Some(embedder) = &self.embedder {
110                let emb = embedder.embed(&query).await;
111                if emb.is_empty() { None } else { Some(emb) }
112            } else {
113                None
114            };
115
116            let entries = self
117                .recall_sync(&query, &opts, query_embedding.as_deref())
118                .unwrap_or_default();
119
120            if entries.is_empty() {
121                return String::new();
122            }
123
124            let mut out = String::from("<memory>\n");
125            for entry in &entries {
126                out.push_str(&format!("<{}>\n", entry.key));
127                out.push_str(&entry.value);
128                if !entry.value.ends_with('\n') {
129                    out.push('\n');
130                }
131                out.push_str(&format!("</{}>\n", entry.key));
132            }
133            out.push_str("</memory>");
134            out
135        }
136    }
137}