smooth_operator_adapter_postgres/memory.rs
1//! pgvector-backed [`Memory`] — persistent + semantic cross-thread agent memory.
2//!
3//! Parity gap Phase 3 / SMOODEV-1470. The core only shipped
4//! [`InMemoryMemory`](smooth_operator_core::InMemoryMemory) (a `Vec` behind a
5//! `Mutex`, keyword recall, gone on restart). The general agent needs
6//! cross-thread *user* memory that survives restarts and recalls by **semantic
7//! similarity** — the TS side does this with a Postgres `store`/`store_vectors`
8//! namespaced by `['memories', orgId, userId]`. [`PgMemory`] is the Rust
9//! equivalent.
10//!
11//! ## Namespace scoping
12//!
13//! The core [`Memory`] trait's `recall(query, limit)` carries **no** org/user
14//! scoping in its signature (same shape as [`KnowledgeBase`]'s `query`). So,
15//! exactly like [`PgKnowledgeBase`](crate::PgKnowledgeBase) binds an
16//! `organization_id`, a [`PgMemory`] instance is **bound to one
17//! `(organization_id, user_id)` namespace at construction**. Every `store`
18//! stamps the namespace onto the row; every `recall` filters on it in SQL. A
19//! `PgMemory` for org A can never recall org B's rows (or another user's, when
20//! user-scoped). `user_id = None` ⇒ org-wide memory.
21//!
22//! ## Semantic recall
23//!
24//! `store` embeds the entry content (through the shared [`Embedder`] seam —
25//! [`DeterministicEmbedder`] offline so tests need no network, [`GatewayEmbedder`]
26//! live) and upserts the row. `recall` embeds the query and ranks the namespace's
27//! rows by pgvector cosine distance (`embedding <=> $query`) under the HNSW index,
28//! returning the top-K with `relevance` set to the cosine similarity (`1 -
29//! distance`). Dimension handling matches `knowledge_vectors`: the `vector(N)`
30//! column width comes from `embedder.dim()`.
31//!
32//! ## Sync trait over async work
33//!
34//! [`Memory`] is **synchronous** (the engine calls it directly) but embedding +
35//! deadpool are async. We bridge identically to
36//! [`PgKnowledgeBase`](crate::PgKnowledgeBase): spawn the async future onto the
37//! captured runtime [`Handle`] and block on it from a throwaway OS thread — never
38//! `Handle::block_on` on a runtime worker (which panics "Cannot start a runtime
39//! from within a runtime"). See [`PgMemory::run_blocking`].
40
41use std::sync::Arc;
42
43use anyhow::{anyhow, Result};
44use deadpool_postgres::Pool;
45use tokio::runtime::Handle;
46
47use smooth_operator_core::{Memory, MemoryEntry, MemoryType};
48
49use smooth_operator::embedding::{Embedder, InputType};
50
51/// pgvector-backed agent memory, bound to one `(organization_id, user_id)`
52/// namespace. Cheap to clone (all fields are `Arc`/pool handles + small strings).
53#[derive(Clone)]
54pub struct PgMemory {
55 pool: Pool,
56 embedder: Arc<dyn Embedder>,
57 handle: Handle,
58 /// Org scope — required. Every `store` stamps it; every `recall` filters on it.
59 organization_id: String,
60 /// User scope — `None` ⇒ org-wide memory. When `Some`, recall is isolated to
61 /// this user's rows (and org-wide rows are NOT mixed in — strict namespace
62 /// match, mirroring the TS `['memories', orgId, userId]` key tuple).
63 user_id: Option<String>,
64}
65
66impl PgMemory {
67 /// Build a memory handle bound to `(organization_id, user_id)`. Pass
68 /// `user_id = None` for org-wide memory.
69 pub(crate) fn new(
70 pool: Pool,
71 embedder: Arc<dyn Embedder>,
72 handle: Handle,
73 organization_id: impl Into<String>,
74 user_id: Option<String>,
75 ) -> Self {
76 Self {
77 pool,
78 embedder,
79 handle,
80 organization_id: organization_id.into(),
81 user_id,
82 }
83 }
84
85 /// Format a vector as a pgvector literal: `[0.1,0.2,...]`. Same wire shape as
86 /// [`PgKnowledgeBase`](crate::PgKnowledgeBase).
87 fn vector_literal(v: &[f32]) -> String {
88 let mut s = String::with_capacity(v.len() * 8 + 2);
89 s.push('[');
90 for (i, x) in v.iter().enumerate() {
91 if i > 0 {
92 s.push(',');
93 }
94 s.push_str(&x.to_string());
95 }
96 s.push(']');
97 s
98 }
99
100 /// Serialize a [`MemoryType`] to its stored text form (the serde tag, e.g.
101 /// `"ShortTerm"`). Round-trips through [`Self::memory_type_from_str`].
102 fn memory_type_to_str(mt: MemoryType) -> Result<String> {
103 // serde serializes a unit enum variant as a JSON string `"Variant"`;
104 // strip the quotes for a clean column value.
105 let json = serde_json::to_string(&mt)?;
106 Ok(json.trim_matches('"').to_string())
107 }
108
109 /// Parse a stored `memory_type` text back into a [`MemoryType`].
110 fn memory_type_from_str(s: &str) -> Result<MemoryType> {
111 Ok(serde_json::from_str(&format!("\"{s}\""))?)
112 }
113
114 async fn store_async(&self, entry: MemoryEntry) -> Result<()> {
115 let embeddings = self
116 .embedder
117 .embed(std::slice::from_ref(&entry.content), InputType::Document)
118 .await?;
119 let embedding = embeddings
120 .into_iter()
121 .next()
122 .ok_or_else(|| anyhow!("embedder returned no vector"))?;
123 let literal = Self::vector_literal(&embedding);
124 let metadata = serde_json::to_value(&entry.metadata)?;
125 let memory_type = Self::memory_type_to_str(entry.memory_type)?;
126
127 let client = self.pool.get().await?;
128 // Upsert by entry id so re-storing the same logical memory replaces it in
129 // place (re-embedding content that may have changed).
130 client
131 .execute(
132 "INSERT INTO memories
133 (id, organization_id, user_id, content, memory_type, relevance,
134 metadata, embedding, created_at, last_accessed)
135 VALUES ($1, $2, $3, $4, $5, $6, $7, $8::text::vector, $9, $10)
136 ON CONFLICT (id) DO UPDATE SET
137 organization_id = EXCLUDED.organization_id,
138 user_id = EXCLUDED.user_id,
139 content = EXCLUDED.content,
140 memory_type = EXCLUDED.memory_type,
141 relevance = EXCLUDED.relevance,
142 metadata = EXCLUDED.metadata,
143 embedding = EXCLUDED.embedding,
144 last_accessed = EXCLUDED.last_accessed",
145 &[
146 &entry.id,
147 &self.organization_id,
148 &self.user_id,
149 &entry.content,
150 &memory_type,
151 &entry.relevance,
152 &metadata,
153 &literal,
154 &entry.created_at,
155 &entry.last_accessed,
156 ],
157 )
158 .await?;
159 Ok(())
160 }
161
162 async fn recall_async(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
163 let embeddings = self
164 .embedder
165 .embed(&[query.to_string()], InputType::Query)
166 .await?;
167 let embedding = embeddings
168 .into_iter()
169 .next()
170 .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
171 let literal = Self::vector_literal(&embedding);
172 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
173
174 let client = self.pool.get().await?;
175 // Strict namespace match: org always; user_id matched with NULL-safe
176 // equality (`IS NOT DISTINCT FROM`) so a `None`-scoped handle recalls
177 // exactly the org-wide rows and a `Some`-scoped handle recalls exactly
178 // that user's rows — neither leaks into the other. Ranking is pgvector
179 // cosine distance under the HNSW index; `1 - distance` is the similarity
180 // surfaced as `relevance`.
181 let rows = client
182 .query(
183 "SELECT id, content, memory_type, metadata, created_at, last_accessed,
184 1 - (embedding <=> $3::text::vector) AS similarity
185 FROM memories
186 WHERE organization_id = $1
187 AND user_id IS NOT DISTINCT FROM $2
188 ORDER BY embedding <=> $3::text::vector
189 LIMIT $4",
190 &[&self.organization_id, &self.user_id, &literal, &limit_i64],
191 )
192 .await?;
193
194 rows.iter()
195 .map(|row| {
196 let memory_type =
197 Self::memory_type_from_str(row.get::<_, String>("memory_type").as_str())?;
198 let metadata_json: serde_json::Value = row.get("metadata");
199 let metadata = serde_json::from_value(metadata_json)?;
200 #[allow(clippy::cast_possible_truncation)]
201 let similarity = row.get::<_, f64>("similarity") as f32;
202 Ok(MemoryEntry {
203 id: row.get("id"),
204 content: row.get("content"),
205 memory_type,
206 relevance: similarity,
207 metadata,
208 created_at: row.get("created_at"),
209 last_accessed: row.get("last_accessed"),
210 })
211 })
212 .collect()
213 }
214
215 async fn forget_async(&self, id: &str) -> Result<()> {
216 let client = self.pool.get().await?;
217 // Scope the delete to this handle's namespace so one tenant can't forget
218 // another's memory by guessing an id.
219 client
220 .execute(
221 "DELETE FROM memories
222 WHERE id = $1
223 AND organization_id = $2
224 AND user_id IS NOT DISTINCT FROM $3",
225 &[&id, &self.organization_id, &self.user_id],
226 )
227 .await?;
228 Ok(())
229 }
230
231 /// Drive an async future to completion from a *synchronous* trait method.
232 ///
233 /// Identical bridge to [`PgKnowledgeBase::run_blocking`](crate::PgKnowledgeBase):
234 /// spawn onto the captured runtime so its I/O makes progress on that
235 /// runtime's reactor, then block on the `JoinHandle` from a throwaway OS
236 /// thread running a tiny current-thread runtime — never `Handle::block_on` on
237 /// a worker thread (which panics "Cannot start a runtime from within a
238 /// runtime").
239 fn run_blocking<F, T>(&self, fut: F) -> Result<T>
240 where
241 F: std::future::Future<Output = Result<T>> + Send + 'static,
242 T: Send + 'static,
243 {
244 let join = self.handle.spawn(fut);
245 let (tx, rx) = std::sync::mpsc::channel();
246 std::thread::spawn(move || {
247 let result = (|| -> Result<T> {
248 let rt = tokio::runtime::Builder::new_current_thread()
249 .enable_all()
250 .build()?;
251 let joined = rt.block_on(join);
252 joined.map_err(|e| anyhow!("memory task panicked or was cancelled: {e}"))?
253 })();
254 let _ = tx.send(result);
255 });
256 rx.recv()
257 .map_err(|e| anyhow!("memory task channel closed: {e}"))?
258 }
259}
260
261impl Memory for PgMemory {
262 fn store(&self, entry: MemoryEntry) -> Result<()> {
263 let this = self.clone();
264 self.run_blocking(async move { this.store_async(entry).await })
265 }
266
267 fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
268 let this = self.clone();
269 let query = query.to_string();
270 self.run_blocking(async move { this.recall_async(&query, limit).await })
271 }
272
273 fn forget(&self, id: &str) -> Result<()> {
274 let this = self.clone();
275 let id = id.to_string();
276 self.run_blocking(async move { this.forget_async(&id).await })
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MemoryType` variant must survive text encoding and decoding, and
    /// must be stored as a bare tag (no JSON quoting).
    #[test]
    fn memory_type_round_trips_through_text() {
        let variants = [
            MemoryType::ShortTerm,
            MemoryType::LongTerm,
            MemoryType::Entity,
            MemoryType::User,
            MemoryType::Feedback,
            MemoryType::Project,
            MemoryType::Reference,
        ];
        for variant in variants {
            let s = PgMemory::memory_type_to_str(variant).expect("to_str");
            assert!(
                !s.contains('"'),
                "stored memory_type must be unquoted: {s:?}"
            );
            let decoded = PgMemory::memory_type_from_str(&s).expect("from_str");
            assert_eq!(decoded, variant);
        }
    }

    /// The pgvector literal is bracketed, comma-separated, with no spaces.
    #[test]
    fn vector_literal_shape() {
        let populated = PgMemory::vector_literal(&[0.5_f32, -1.0, 2.0]);
        assert_eq!(populated, "[0.5,-1,2]");

        let empty: &[f32] = &[];
        assert_eq!(PgMemory::vector_literal(empty), "[]");
    }
}