Skip to main content

smooth_operator_adapter_postgres/
memory.rs

1//! pgvector-backed [`Memory`] — persistent + semantic cross-thread agent memory.
2//!
3//! Parity gap Phase 3 / SMOODEV-1470. The core only shipped
4//! [`InMemoryMemory`](smooth_operator_core::InMemoryMemory) (a `Vec` behind a
5//! `Mutex`, keyword recall, gone on restart). The general agent needs
6//! cross-thread *user* memory that survives restarts and recalls by **semantic
7//! similarity** — the TS side does this with a Postgres `store`/`store_vectors`
8//! namespaced by `['memories', orgId, userId]`. [`PgMemory`] is the Rust
9//! equivalent.
10//!
11//! ## Namespace scoping
12//!
13//! The core [`Memory`] trait's `recall(query, limit)` carries **no** org/user
14//! scoping in its signature (same shape as [`KnowledgeBase`]'s `query`). So,
15//! exactly like [`PgKnowledgeBase`](crate::PgKnowledgeBase) binds an
16//! `organization_id`, a [`PgMemory`] instance is **bound to one
17//! `(organization_id, user_id)` namespace at construction**. Every `store`
18//! stamps the namespace onto the row; every `recall` filters on it in SQL
19//! *before* ANN ranking. A `PgMemory` for org A can never recall org B's rows
20//! (or another user's, when user-scoped). `user_id = None` ⇒ org-wide memory.
21//!
22//! ## Semantic recall
23//!
24//! `store` embeds the entry content (through the shared [`Embedder`] seam —
25//! [`DeterministicEmbedder`] offline so tests need no network, [`GatewayEmbedder`]
26//! live) and upserts the row. `recall` embeds the query and ranks the namespace's
27//! rows by pgvector cosine distance (`embedding <=> $query`) under the HNSW index,
28//! returning the top-K with `relevance` set to the cosine similarity (`1 -
29//! distance`). Dimension handling matches `knowledge_vectors`: the `vector(N)`
30//! column width comes from `embedder.dim()`.
31//!
32//! ## Sync trait over async work
33//!
34//! [`Memory`] is **synchronous** (the engine calls it directly) but embedding +
35//! deadpool are async. We bridge identically to
36//! [`PgKnowledgeBase`](crate::PgKnowledgeBase): spawn the async future onto the
37//! captured runtime [`Handle`] and block on it from a throwaway OS thread — never
38//! `Handle::block_on` on a runtime worker (which panics "Cannot start a runtime
39//! from within a runtime"). See [`PgMemory::run_blocking`].
40
41use std::sync::Arc;
42
43use anyhow::{anyhow, Result};
44use deadpool_postgres::Pool;
45use tokio::runtime::Handle;
46
47use smooth_operator_core::{Memory, MemoryEntry, MemoryType};
48
49use smooth_operator::embedding::{Embedder, InputType};
50
51/// pgvector-backed agent memory, bound to one `(organization_id, user_id)`
52/// namespace. Cheap to clone (all fields are `Arc`/pool handles + small strings).
53#[derive(Clone)]
54pub struct PgMemory {
55    pool: Pool,
56    embedder: Arc<dyn Embedder>,
57    handle: Handle,
58    /// Org scope — required. Every `store` stamps it; every `recall` filters on it.
59    organization_id: String,
60    /// User scope — `None` ⇒ org-wide memory. When `Some`, recall is isolated to
61    /// this user's rows (and org-wide rows are NOT mixed in — strict namespace
62    /// match, mirroring the TS `['memories', orgId, userId]` key tuple).
63    user_id: Option<String>,
64}
65
66impl PgMemory {
67    /// Build a memory handle bound to `(organization_id, user_id)`. Pass
68    /// `user_id = None` for org-wide memory.
69    pub(crate) fn new(
70        pool: Pool,
71        embedder: Arc<dyn Embedder>,
72        handle: Handle,
73        organization_id: impl Into<String>,
74        user_id: Option<String>,
75    ) -> Self {
76        Self {
77            pool,
78            embedder,
79            handle,
80            organization_id: organization_id.into(),
81            user_id,
82        }
83    }
84
85    /// Format a vector as a pgvector literal: `[0.1,0.2,...]`. Same wire shape as
86    /// [`PgKnowledgeBase`](crate::PgKnowledgeBase).
87    fn vector_literal(v: &[f32]) -> String {
88        let mut s = String::with_capacity(v.len() * 8 + 2);
89        s.push('[');
90        for (i, x) in v.iter().enumerate() {
91            if i > 0 {
92                s.push(',');
93            }
94            s.push_str(&x.to_string());
95        }
96        s.push(']');
97        s
98    }
99
100    /// Serialize a [`MemoryType`] to its stored text form (the serde tag, e.g.
101    /// `"ShortTerm"`). Round-trips through [`Self::memory_type_from_str`].
102    fn memory_type_to_str(mt: MemoryType) -> Result<String> {
103        // serde serializes a unit enum variant as a JSON string `"Variant"`;
104        // strip the quotes for a clean column value.
105        let json = serde_json::to_string(&mt)?;
106        Ok(json.trim_matches('"').to_string())
107    }
108
109    /// Parse a stored `memory_type` text back into a [`MemoryType`].
110    fn memory_type_from_str(s: &str) -> Result<MemoryType> {
111        Ok(serde_json::from_str(&format!("\"{s}\""))?)
112    }
113
114    async fn store_async(&self, entry: MemoryEntry) -> Result<()> {
115        let embeddings = self
116            .embedder
117            .embed(std::slice::from_ref(&entry.content), InputType::Document)
118            .await?;
119        let embedding = embeddings
120            .into_iter()
121            .next()
122            .ok_or_else(|| anyhow!("embedder returned no vector"))?;
123        let literal = Self::vector_literal(&embedding);
124        let metadata = serde_json::to_value(&entry.metadata)?;
125        let memory_type = Self::memory_type_to_str(entry.memory_type)?;
126
127        let client = self.pool.get().await?;
128        // Upsert by entry id so re-storing the same logical memory replaces it in
129        // place (re-embedding content that may have changed).
130        client
131            .execute(
132                "INSERT INTO memories
133                    (id, organization_id, user_id, content, memory_type, relevance,
134                     metadata, embedding, created_at, last_accessed)
135                 VALUES ($1, $2, $3, $4, $5, $6, $7, $8::text::vector, $9, $10)
136                 ON CONFLICT (id) DO UPDATE SET
137                    organization_id = EXCLUDED.organization_id,
138                    user_id         = EXCLUDED.user_id,
139                    content         = EXCLUDED.content,
140                    memory_type     = EXCLUDED.memory_type,
141                    relevance       = EXCLUDED.relevance,
142                    metadata        = EXCLUDED.metadata,
143                    embedding       = EXCLUDED.embedding,
144                    last_accessed   = EXCLUDED.last_accessed",
145                &[
146                    &entry.id,
147                    &self.organization_id,
148                    &self.user_id,
149                    &entry.content,
150                    &memory_type,
151                    &entry.relevance,
152                    &metadata,
153                    &literal,
154                    &entry.created_at,
155                    &entry.last_accessed,
156                ],
157            )
158            .await?;
159        Ok(())
160    }
161
162    async fn recall_async(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
163        let embeddings = self
164            .embedder
165            .embed(&[query.to_string()], InputType::Query)
166            .await?;
167        let embedding = embeddings
168            .into_iter()
169            .next()
170            .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
171        let literal = Self::vector_literal(&embedding);
172        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
173
174        let client = self.pool.get().await?;
175        // Strict namespace match: org always; user_id matched with NULL-safe
176        // equality (`IS NOT DISTINCT FROM`) so a `None`-scoped handle recalls
177        // exactly the org-wide rows and a `Some`-scoped handle recalls exactly
178        // that user's rows — neither leaks into the other. Ranking is pgvector
179        // cosine distance under the HNSW index; `1 - distance` is the similarity
180        // surfaced as `relevance`.
181        let rows = client
182            .query(
183                "SELECT id, content, memory_type, metadata, created_at, last_accessed,
184                        1 - (embedding <=> $3::text::vector) AS similarity
185                 FROM memories
186                 WHERE organization_id = $1
187                   AND user_id IS NOT DISTINCT FROM $2
188                 ORDER BY embedding <=> $3::text::vector
189                 LIMIT $4",
190                &[&self.organization_id, &self.user_id, &literal, &limit_i64],
191            )
192            .await?;
193
194        rows.iter()
195            .map(|row| {
196                let memory_type =
197                    Self::memory_type_from_str(row.get::<_, String>("memory_type").as_str())?;
198                let metadata_json: serde_json::Value = row.get("metadata");
199                let metadata = serde_json::from_value(metadata_json)?;
200                #[allow(clippy::cast_possible_truncation)]
201                let similarity = row.get::<_, f64>("similarity") as f32;
202                Ok(MemoryEntry {
203                    id: row.get("id"),
204                    content: row.get("content"),
205                    memory_type,
206                    relevance: similarity,
207                    metadata,
208                    created_at: row.get("created_at"),
209                    last_accessed: row.get("last_accessed"),
210                })
211            })
212            .collect()
213    }
214
215    async fn forget_async(&self, id: &str) -> Result<()> {
216        let client = self.pool.get().await?;
217        // Scope the delete to this handle's namespace so one tenant can't forget
218        // another's memory by guessing an id.
219        client
220            .execute(
221                "DELETE FROM memories
222                 WHERE id = $1
223                   AND organization_id = $2
224                   AND user_id IS NOT DISTINCT FROM $3",
225                &[&id, &self.organization_id, &self.user_id],
226            )
227            .await?;
228        Ok(())
229    }
230
231    /// Drive an async future to completion from a *synchronous* trait method.
232    ///
233    /// Identical bridge to [`PgKnowledgeBase::run_blocking`](crate::PgKnowledgeBase):
234    /// spawn onto the captured runtime so its I/O makes progress on that
235    /// runtime's reactor, then block on the `JoinHandle` from a throwaway OS
236    /// thread running a tiny current-thread runtime — never `Handle::block_on` on
237    /// a worker thread (which panics "Cannot start a runtime from within a
238    /// runtime").
239    fn run_blocking<F, T>(&self, fut: F) -> Result<T>
240    where
241        F: std::future::Future<Output = Result<T>> + Send + 'static,
242        T: Send + 'static,
243    {
244        let join = self.handle.spawn(fut);
245        let (tx, rx) = std::sync::mpsc::channel();
246        std::thread::spawn(move || {
247            let result = (|| -> Result<T> {
248                let rt = tokio::runtime::Builder::new_current_thread()
249                    .enable_all()
250                    .build()?;
251                let joined = rt.block_on(join);
252                joined.map_err(|e| anyhow!("memory task panicked or was cancelled: {e}"))?
253            })();
254            let _ = tx.send(result);
255        });
256        rx.recv()
257            .map_err(|e| anyhow!("memory task channel closed: {e}"))?
258    }
259}
260
261impl Memory for PgMemory {
262    fn store(&self, entry: MemoryEntry) -> Result<()> {
263        let this = self.clone();
264        self.run_blocking(async move { this.store_async(entry).await })
265    }
266
267    fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
268        let this = self.clone();
269        let query = query.to_string();
270        self.run_blocking(async move { this.recall_async(&query, limit).await })
271    }
272
273    fn forget(&self, id: &str) -> Result<()> {
274        let this = self.clone();
275        let id = id.to_string();
276        self.run_blocking(async move { this.forget_async(&id).await })
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed variant must map to an unquoted text form and parse back to
    /// the same variant.
    #[test]
    fn memory_type_round_trips_through_text() {
        let variants = [
            MemoryType::ShortTerm,
            MemoryType::LongTerm,
            MemoryType::Entity,
            MemoryType::User,
            MemoryType::Feedback,
            MemoryType::Project,
            MemoryType::Reference,
        ];
        for variant in variants {
            let s = PgMemory::memory_type_to_str(variant).expect("to_str");
            // The column holds the bare serde tag, so no quote may survive.
            let has_quote = s.contains('"');
            assert!(!has_quote, "stored memory_type must be unquoted: {s:?}");
            let restored = PgMemory::memory_type_from_str(&s).expect("from_str");
            assert_eq!(restored, variant);
        }
    }

    /// The literal is bracketed, comma-separated, and `[]` when empty.
    #[test]
    fn vector_literal_shape() {
        let populated = PgMemory::vector_literal(&[0.5, -1.0, 2.0]);
        assert_eq!(populated, "[0.5,-1,2]");

        let empty = PgMemory::vector_literal(&[]);
        assert_eq!(empty, "[]");
    }
}
306    #[test]
307    fn vector_literal_shape() {
308        assert_eq!(PgMemory::vector_literal(&[0.5, -1.0, 2.0]), "[0.5,-1,2]");
309        assert_eq!(PgMemory::vector_literal(&[]), "[]");
310    }
311}