Skip to main content

walrus_memory/sqlite/
memory.rs

//! [`Memory`] trait implementation for [`SqliteMemory`].

use crate::sqlite::{SqliteMemory, now_unix, sql};
use anyhow::Result;
use std::fmt::Write as _;
use std::future::Future;
use wcore::{Embedder, Memory, MemoryEntry, RecallOptions};

8impl<E: Embedder> Memory for SqliteMemory<E> {
9    fn get(&self, key: &str) -> Option<String> {
10        let conn = self.conn.lock().unwrap();
11        let now = now_unix();
12        conn.execute(sql::TOUCH_ACCESS, rusqlite::params![now as i64, key])
13            .ok();
14        conn.query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
15            .ok()
16    }
17
18    fn entries(&self) -> Vec<(String, String)> {
19        let conn = self.conn.lock().unwrap();
20        let mut stmt = conn.prepare(sql::SELECT_ENTRIES).unwrap();
21        stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
22            .unwrap()
23            .filter_map(|r| r.ok())
24            .collect()
25    }
26
27    fn set(&self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
28        let key = key.into();
29        let value = value.into();
30        let conn = self.conn.lock().unwrap();
31        let now = now_unix() as i64;
32
33        let old: Option<String> = conn
34            .query_row(sql::SELECT_VALUE, [&key], |row| row.get(0))
35            .ok();
36
37        conn.execute(sql::UPSERT, rusqlite::params![key, value, now])
38            .ok();
39
40        old
41    }
42
43    fn remove(&self, key: &str) -> Option<String> {
44        let conn = self.conn.lock().unwrap();
45        let old: Option<String> = conn
46            .query_row(sql::SELECT_VALUE, [key], |row| row.get(0))
47            .ok();
48        if old.is_some() {
49            conn.execute(sql::DELETE, [key]).ok();
50        }
51        old
52    }
53
54    fn store(
55        &self,
56        key: impl Into<String> + Send,
57        value: impl Into<String> + Send,
58    ) -> impl Future<Output = Result<()>> + Send {
59        let key = key.into();
60        let value = value.into();
61
62        async move {
63            // Auto-embed when embedder is present.
64            let embedding = if let Some(embedder) = &self.embedder {
65                let emb = embedder.embed(&value).await;
66                if emb.is_empty() { None } else { Some(emb) }
67            } else {
68                None
69            };
70
71            self.store_with_metadata(&key, &value, None, embedding.as_deref())?;
72            Ok(())
73        }
74    }
75
76    fn recall(
77        &self,
78        query: &str,
79        options: RecallOptions,
80    ) -> impl Future<Output = Result<Vec<MemoryEntry>>> + Send {
81        let query = query.to_owned();
82
83        async move {
84            // Embed query when embedder is present.
85            let query_embedding = if let Some(embedder) = &self.embedder {
86                let emb = embedder.embed(&query).await;
87                if emb.is_empty() { None } else { Some(emb) }
88            } else {
89                None
90            };
91
92            self.recall_sync(&query, &options, query_embedding.as_deref())
93        }
94    }
95
96    fn compile_relevant(&self, query: &str) -> impl Future<Output = String> + Send {
97        let query = query.to_owned();
98
99        async move {
100            let opts = RecallOptions {
101                limit: 5,
102                ..Default::default()
103            };
104
105            // Embed query when embedder is present.
106            let query_embedding = if let Some(embedder) = &self.embedder {
107                let emb = embedder.embed(&query).await;
108                if emb.is_empty() { None } else { Some(emb) }
109            } else {
110                None
111            };
112
113            let entries = self
114                .recall_sync(&query, &opts, query_embedding.as_deref())
115                .unwrap_or_default();
116
117            if entries.is_empty() {
118                return String::new();
119            }
120
121            let mut out = String::from("<memory>\n");
122            for entry in &entries {
123                out.push_str(&format!("<{}>\n", entry.key));
124                out.push_str(&entry.value);
125                if !entry.value.ends_with('\n') {
126                    out.push('\n');
127                }
128                out.push_str(&format!("</{}>\n", entry.key));
129            }
130            out.push_str("</memory>");
131            out
132        }
133    }
134}