smooth_operator_adapter_postgres/knowledge.rs
1//! pgvector-backed [`KnowledgeBase`] with hybrid dense + sparse retrieval.
2//!
3//! smooth-operator's [`KnowledgeBase`](smooth_operator_core::KnowledgeBase) trait is
4//! **synchronous** (the engine calls `ingest`/`query` directly), but both
//! embedding and Postgres access are async here. We bridge the two by `spawn`ing
//! the async work onto the captured runtime [`Handle`] (so its I/O makes
//! progress on that runtime's reactor), waiting on the task's `JoinHandle` from
//! a throwaway OS thread, and parking the calling thread on a channel until
//! that thread sends the result back — never calling `Handle::block_on` on a
//! runtime worker thread (which panics "Cannot start a runtime from within a
//! runtime"). See [`PgKnowledgeBase::run_blocking`].
11//!
12//! ## Retrieval
13//!
14//! 1. **Dense**: embed the query, rank rows by pgvector cosine distance
15//! (`embedding <=> $query`), take the top-K.
16//! 2. **Sparse**: `content_tsv @@ plainto_tsquery('english', $query)`, ranked by
17//! `ts_rank`, top-K.
18//! 3. **Fuse**: Reciprocal Rank Fusion (RRF) over the two ranked lists —
19//! `score = Σ 1/(k + rank)` (k=60) — then return the top-K fused chunks.
20//!
21//! This mirrors smooai's `knowledge_vectors` retrieval (dense HNSW ∪ sparse BM25
22//! → RRF).
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use anyhow::{anyhow, Result};
28use deadpool_postgres::Pool;
29use tokio::runtime::Handle;
30
31use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
32
33use smooth_operator::access_control::{AccessContext, DocAcl};
34use smooth_operator::embedding::{Embedder, InputType};
35
/// RRF constant `k` in `1 / (k + rank)`, where `rank` is a row's 1-based
/// position within one retrieval arm's result list.
///
/// 60 is the canonical value from the original RRF paper; it
/// damps the contribution of low-ranked items without ignoring them.
const RRF_K: f32 = 60.0;
39
/// pgvector knowledge base.
///
/// Cloning copies the pool, embedder and runtime handles plus the owned
/// `organization_id` / `access` values. The sync [`KnowledgeBase`] methods clone
/// `self` on every call so an owned copy can move into the spawned task.
#[derive(Clone)]
pub struct PgKnowledgeBase {
    /// Connection pool; a client is checked out for each ingest and each query.
    pool: Pool,
    /// Produces the embedding vector for ingested documents and for queries.
    embedder: Arc<dyn Embedder>,
    /// Runtime handle onto which the async ingest/query work is spawned.
    handle: Handle,
    /// Optional org scoping. When set, ingest stamps and query filters on it.
    organization_id: Option<String>,
    /// Optional document-level access control (feature gap G3). When set, every
    /// query filters rows by this requester's entitlements against the stored
    /// `acl` column **in SQL** (so a restricted document is never even fetched).
    /// `None` ⇒ no within-org ACL filtering (org isolation only) — the handle
    /// returned by [`StorageAdapter::knowledge`]. The chat path uses
    /// [`StorageAdapter::knowledge_for_access`], which sets this.
    access: Option<AccessContext>,
}
56
57impl PgKnowledgeBase {
58 pub(crate) fn new(
59 pool: Pool,
60 embedder: Arc<dyn Embedder>,
61 handle: Handle,
62 organization_id: Option<String>,
63 ) -> Self {
64 Self {
65 pool,
66 embedder,
67 handle,
68 organization_id,
69 access: None,
70 }
71 }
72
73 /// A clone of this knowledge base whose queries enforce the given
74 /// [`AccessContext`]'s document-level entitlements (in SQL, against the
75 /// stored `acl` column). Used by
76 /// [`PostgresAdapter::knowledge_for_access`](crate::PostgresAdapter) on the
77 /// chat retrieval path.
78 ///
79 /// When the context carries an [`organization_id`](AccessContext::organization_id)
80 /// (a multi-tenant host threads the turn's org through), it **overrides** the
81 /// adapter's construction-time org for this query — so one adapter instance can
82 /// serve per-turn tenants instead of being pinned to a single static org. The
83 /// org is still a cheap SQL pre-filter (`organization_id = $1`). A context with
84 /// no org leaves the construction-time org unchanged, so the single-tenant
85 /// default is behavior-preserving.
86 #[must_use]
87 pub fn with_access(&self, access: AccessContext) -> Self {
88 let organization_id = access
89 .organization_id
90 .clone()
91 .or_else(|| self.organization_id.clone());
92 Self {
93 organization_id,
94 access: Some(access),
95 ..self.clone()
96 }
97 }
98
99 /// Format a vector as a pgvector literal: `[0.1,0.2,...]`.
100 fn vector_literal(v: &[f32]) -> String {
101 let mut s = String::with_capacity(v.len() * 8 + 2);
102 s.push('[');
103 for (i, x) in v.iter().enumerate() {
104 if i > 0 {
105 s.push(',');
106 }
107 s.push_str(&x.to_string());
108 }
109 s.push(']');
110 s
111 }
112
113 async fn ingest_async(&self, doc: Document) -> Result<()> {
114 let embeddings = self
115 .embedder
116 .embed(std::slice::from_ref(&doc.content), InputType::Document)
117 .await?;
118 let embedding = embeddings
119 .into_iter()
120 .next()
121 .ok_or_else(|| anyhow!("embedder returned no vector"))?;
122 let literal = Self::vector_literal(&embedding);
123 let metadata = serde_json::to_value(&doc.metadata)?;
124 // Persist the document's ACL (feature gap G3) as a discrete column so it
125 // survives the ingest→serve process boundary and can be filtered in SQL
126 // at read. Parsed from the same `acl_v2` metadata key the in-memory
127 // store records. `None` ⇒ NULL ⇒ org-public (backward-compatible).
128 let acl: Option<serde_json::Value> = DocAcl::from_metadata(&doc.metadata)
129 .map(|a| serde_json::to_value(&a))
130 .transpose()?;
131 // Stable per-chunk id: the document is stored as a single chunk keyed by
132 // its document id, so re-ingesting the same doc upserts in place.
133 let row_id = doc.id.clone();
134
135 let client = self.pool.get().await?;
136 client
137 .execute(
138 "INSERT INTO knowledge_vectors
139 (id, document_id, organization_id, source, content, embedding, metadata, acl)
140 VALUES ($1, $2, $3, $4, $5, $6::text::vector, $7, $8)
141 ON CONFLICT (id) DO UPDATE SET
142 document_id = EXCLUDED.document_id,
143 organization_id = EXCLUDED.organization_id,
144 source = EXCLUDED.source,
145 content = EXCLUDED.content,
146 embedding = EXCLUDED.embedding,
147 metadata = EXCLUDED.metadata,
148 acl = EXCLUDED.acl",
149 &[
150 &row_id,
151 &doc.id,
152 &self.organization_id,
153 &doc.source,
154 &doc.content,
155 &literal,
156 &metadata,
157 &acl,
158 ],
159 )
160 .await?;
161 Ok(())
162 }
163
164 async fn query_async(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
165 let embeddings = self
166 .embedder
167 .embed(&[query.to_string()], InputType::Query)
168 .await?;
169 let embedding = embeddings
170 .into_iter()
171 .next()
172 .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
173 let literal = Self::vector_literal(&embedding);
174
175 // Pull a generous candidate pool from each arm so RRF has something to
176 // fuse, then truncate after fusion.
177 let candidate_n: i64 = i64::try_from((limit * 4).max(20)).unwrap_or(20);
178 let client = self.pool.get().await?;
179
180 // --- ACL filter (feature gap G3) ---
181 //
182 // When this handle is access-bound, every row must pass the requester's
183 // document-level entitlement **in SQL** — a restricted document is never
184 // even fetched. A row is visible when ANY holds:
185 // - `acl IS NULL` → no ACL recorded ⇒ org-public default
186 // - `acl->>'public'` is true → explicitly public
187 // - requester user id ∈ acl.users (jsonb `?` key-exists)
188 // - any requester group ∈ acl.groups (jsonb `?|` any-key-exists)
189 // `$4` is the requester user id (text, NULL ⇒ anonymous), `$5` the
190 // requester groups (text[]). Both are bound below. When the handle is
191 // NOT access-bound (`access` is None) the predicate is `TRUE` — org
192 // isolation only, no within-org filtering.
193 // Build the ACL predicate + the extra bound params ONLY when this handle
194 // is access-bound. Postgres rejects a prepared statement that binds a
195 // parameter the SQL never references, so the raw (org-isolation-only)
196 // path must not add `$4`/`$5`.
197 let acl_user: Option<String> = self.access.as_ref().and_then(|c| c.user_id.clone());
198 let acl_groups: Vec<String> = self
199 .access
200 .as_ref()
201 .map(|c| c.groups.clone())
202 .unwrap_or_default();
203 let acl_predicate = if self.access.is_some() {
204 // A row is visible when it has no recorded ACL (org-public), is
205 // explicitly public, names the requester's user id, or names any of
206 // the requester's groups. `?` / `?|` are jsonb key-exists operators.
207 "(acl IS NULL \
208 OR (acl->>'public')::boolean IS TRUE \
209 OR ($4::text IS NOT NULL AND acl->'users' ? $4) \
210 OR (acl->'groups' ?| $5::text[]))"
211 } else {
212 "TRUE"
213 };
214 // The query text as an owned param so the borrowed trait objects below
215 // don't tie the param vec to the input `&str` lifetime.
216 let query_owned = query.to_string();
217
218 // --- dense arm: cosine distance via pgvector `<=>` ---
219 let dense_sql = format!(
220 "SELECT id, document_id, source, content
221 FROM knowledge_vectors
222 WHERE ($1::text IS NULL OR organization_id = $1)
223 AND {acl_predicate}
224 ORDER BY embedding <=> $2::text::vector
225 LIMIT $3"
226 );
227 let mut dense_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
228 vec![&self.organization_id, &literal, &candidate_n];
229 if self.access.is_some() {
230 dense_params.push(&acl_user);
231 dense_params.push(&acl_groups);
232 }
233 let dense_rows = client.query(&dense_sql, &dense_params).await?;
234
235 // --- sparse arm: tsvector BM25-style match, ranked by ts_rank ---
236 let sparse_sql = format!(
237 "SELECT id, document_id, source, content
238 FROM knowledge_vectors
239 WHERE ($1::text IS NULL OR organization_id = $1)
240 AND content_tsv @@ plainto_tsquery('english', $2)
241 AND {acl_predicate}
242 ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $2)) DESC
243 LIMIT $3"
244 );
245 let mut sparse_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
246 vec![&self.organization_id, &query_owned, &candidate_n];
247 if self.access.is_some() {
248 sparse_params.push(&acl_user);
249 sparse_params.push(&acl_groups);
250 }
251 let sparse_rows = client.query(&sparse_sql, &sparse_params).await?;
252
253 // --- Reciprocal Rank Fusion ---
254 struct Hit {
255 document_id: String,
256 source: String,
257 content: String,
258 score: f32,
259 }
260 let mut fused: HashMap<String, Hit> = HashMap::new();
261
262 let mut fuse = |rows: &[tokio_postgres::Row]| {
263 for (rank, row) in rows.iter().enumerate() {
264 let id: String = row.get(0);
265 let document_id: String = row.get(1);
266 let source: String = row.get(2);
267 let content: String = row.get(3);
268 #[allow(clippy::cast_precision_loss)]
269 let contribution = 1.0 / (RRF_K + (rank as f32) + 1.0);
270 fused
271 .entry(id)
272 .and_modify(|h| h.score += contribution)
273 .or_insert(Hit {
274 document_id,
275 source,
276 content,
277 score: contribution,
278 });
279 }
280 };
281 fuse(&dense_rows);
282 fuse(&sparse_rows);
283
284 let mut hits: Vec<Hit> = fused.into_values().collect();
285 hits.sort_by(|a, b| {
286 b.score
287 .partial_cmp(&a.score)
288 .unwrap_or(std::cmp::Ordering::Equal)
289 });
290 hits.truncate(limit);
291
292 Ok(hits
293 .into_iter()
294 .map(|h| KnowledgeResult {
295 document_id: h.document_id,
296 chunk: h.content,
297 score: h.score,
298 source: h.source,
299 })
300 .collect())
301 }
302}
303
304impl PgKnowledgeBase {
305 /// Drive an async future to completion from a *synchronous* trait method.
306 ///
307 /// `KnowledgeBase` is sync, but our work (embedding + deadpool) is async.
308 /// `Handle::block_on` can't be called from a runtime worker thread (it panics
309 /// "Cannot start a runtime from within a runtime"), and `block_in_place` only
310 /// relieves the *blocking-budget* concern, not that one. So we `spawn` the
311 /// future onto the runtime (where it can make progress) and block the calling
312 /// thread on a oneshot channel — wrapped in `block_in_place` when we happen to
313 /// be on a multi-thread worker so we don't starve the scheduler.
314 fn run_blocking<F, T>(&self, fut: F) -> Result<T>
315 where
316 F: std::future::Future<Output = Result<T>> + Send + 'static,
317 T: Send + 'static,
318 {
319 // Spawn the real work onto the captured runtime so its async I/O
320 // (deadpool, embedding HTTP) makes progress on that runtime's reactor.
321 let join = self.handle.spawn(fut);
322
323 // Block on the JoinHandle from a throwaway OS thread that owns a tiny
324 // current-thread runtime. This never calls `Handle::block_on` on a worker
325 // thread (which panics "Cannot start a runtime from within a runtime"),
326 // so it's safe whether the caller is on a runtime worker or not.
327 let (tx, rx) = std::sync::mpsc::channel();
328 std::thread::spawn(move || {
329 let result = (|| -> Result<T> {
330 let rt = tokio::runtime::Builder::new_current_thread()
331 .enable_all()
332 .build()?;
333 let joined = rt.block_on(join);
334 joined.map_err(|e| anyhow!("knowledge task panicked or was cancelled: {e}"))?
335 })();
336 let _ = tx.send(result);
337 });
338 rx.recv()
339 .map_err(|e| anyhow!("knowledge task channel closed: {e}"))?
340 }
341}
342
343impl KnowledgeBase for PgKnowledgeBase {
344 fn ingest(&self, doc: Document) -> Result<()> {
345 let this = self.clone();
346 self.run_blocking(async move { this.ingest_async(doc).await })
347 }
348
349 fn query(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
350 let this = self.clone();
351 let query = query.to_string();
352 self.run_blocking(async move { this.query_async(&query, limit).await })
353 }
354}