smooth_operator_adapter_postgres/knowledge.rs
1//! pgvector-backed [`KnowledgeBase`] with hybrid dense + sparse retrieval.
2//!
3//! smooth-operator's [`KnowledgeBase`](smooth_operator_core::KnowledgeBase) trait is
4//! **synchronous** (the engine calls `ingest`/`query` directly), but both
5//! embedding and Postgres access are async here. We bridge the two by `spawn`ing
6//! the async work onto the captured runtime [`Handle`] (so its I/O makes
7//! progress on that runtime's reactor) and blocking the calling thread on the
8//! task's `JoinHandle` from a throwaway OS thread — never calling
9//! `Handle::block_on` on a runtime worker thread (which panics "Cannot start a
10//! runtime from within a runtime"). See [`PgKnowledgeBase::run_blocking`].
11//!
12//! ## Retrieval
13//!
14//! 1. **Dense**: embed the query, rank rows by pgvector cosine distance
15//! (`embedding <=> $query`), take the top-K.
16//! 2. **Sparse**: `content_tsv @@ plainto_tsquery('english', $query)`, ranked by
17//! `ts_rank`, top-K.
18//! 3. **Fuse**: Reciprocal Rank Fusion (RRF) over the two ranked lists —
19//! `score = Σ 1/(k + rank)` (k=60) — then return the top-K fused chunks.
20//!
21//! This mirrors smooai's `knowledge_vectors` retrieval (dense HNSW ∪ sparse BM25
22//! → RRF).
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use anyhow::{anyhow, Result};
28use deadpool_postgres::Pool;
29use tokio::runtime::Handle;
30
31use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
32
33use smooth_operator::access_control::{AccessContext, DocAcl};
34use smooth_operator::embedding::{Embedder, InputType};
35
/// RRF constant. 60 is the canonical value from the original RRF paper; it
/// damps the contribution of low-ranked items without ignoring them.
///
/// Each retrieval arm contributes `1 / (RRF_K + rank)` to a row's fused score,
/// with `rank` counted from 1 within that arm (see `query_async`).
const RRF_K: f32 = 60.0;
39
/// pgvector knowledge base.
///
/// `Clone` is derived and used on every sync trait call (the clone is moved
/// into the spawned async task). Cloning copies the pool, embedder and runtime
/// handles plus the optional organization id and access context.
#[derive(Clone)]
pub struct PgKnowledgeBase {
    /// Postgres connection pool; a client is checked out per ingest/query call.
    pool: Pool,
    /// Produces the embedding vector for documents (ingest) and queries.
    embedder: Arc<dyn Embedder>,
    /// Runtime handle that `run_blocking` spawns the async work onto.
    handle: Handle,
    /// Optional org scoping. When set, ingest stamps and query filters on it.
    /// When `None`, ingest stores a NULL `organization_id` and query applies
    /// no organization filter at all.
    organization_id: Option<String>,
    /// Optional document-level access control (feature gap G3). When set, every
    /// query filters rows by this requester's entitlements against the stored
    /// `acl` column **in SQL** (so a restricted document is never even fetched).
    /// `None` ⇒ no within-org ACL filtering (org isolation only) — the handle
    /// returned by [`StorageAdapter::knowledge`]. The chat path uses
    /// [`StorageAdapter::knowledge_for_access`], which sets this.
    access: Option<AccessContext>,
}
56
57impl PgKnowledgeBase {
58 pub(crate) fn new(
59 pool: Pool,
60 embedder: Arc<dyn Embedder>,
61 handle: Handle,
62 organization_id: Option<String>,
63 ) -> Self {
64 Self {
65 pool,
66 embedder,
67 handle,
68 organization_id,
69 access: None,
70 }
71 }
72
73 /// A clone of this knowledge base whose queries enforce the given
74 /// [`AccessContext`]'s document-level entitlements (in SQL, against the
75 /// stored `acl` column). Used by
76 /// [`PostgresAdapter::knowledge_for_access`](crate::PostgresAdapter) on the
77 /// chat retrieval path.
78 #[must_use]
79 pub fn with_access(&self, access: AccessContext) -> Self {
80 Self {
81 access: Some(access),
82 ..self.clone()
83 }
84 }
85
86 /// Format a vector as a pgvector literal: `[0.1,0.2,...]`.
87 fn vector_literal(v: &[f32]) -> String {
88 let mut s = String::with_capacity(v.len() * 8 + 2);
89 s.push('[');
90 for (i, x) in v.iter().enumerate() {
91 if i > 0 {
92 s.push(',');
93 }
94 s.push_str(&x.to_string());
95 }
96 s.push(']');
97 s
98 }
99
100 async fn ingest_async(&self, doc: Document) -> Result<()> {
101 let embeddings = self
102 .embedder
103 .embed(std::slice::from_ref(&doc.content), InputType::Document)
104 .await?;
105 let embedding = embeddings
106 .into_iter()
107 .next()
108 .ok_or_else(|| anyhow!("embedder returned no vector"))?;
109 let literal = Self::vector_literal(&embedding);
110 let metadata = serde_json::to_value(&doc.metadata)?;
111 // Persist the document's ACL (feature gap G3) as a discrete column so it
112 // survives the ingest→serve process boundary and can be filtered in SQL
113 // at read. Parsed from the same `acl_v2` metadata key the in-memory
114 // store records. `None` ⇒ NULL ⇒ org-public (backward-compatible).
115 let acl: Option<serde_json::Value> = DocAcl::from_metadata(&doc.metadata)
116 .map(|a| serde_json::to_value(&a))
117 .transpose()?;
118 // Stable per-chunk id: the document is stored as a single chunk keyed by
119 // its document id, so re-ingesting the same doc upserts in place.
120 let row_id = doc.id.clone();
121
122 let client = self.pool.get().await?;
123 client
124 .execute(
125 "INSERT INTO knowledge_vectors
126 (id, document_id, organization_id, source, content, embedding, metadata, acl)
127 VALUES ($1, $2, $3, $4, $5, $6::text::vector, $7, $8)
128 ON CONFLICT (id) DO UPDATE SET
129 document_id = EXCLUDED.document_id,
130 organization_id = EXCLUDED.organization_id,
131 source = EXCLUDED.source,
132 content = EXCLUDED.content,
133 embedding = EXCLUDED.embedding,
134 metadata = EXCLUDED.metadata,
135 acl = EXCLUDED.acl",
136 &[
137 &row_id,
138 &doc.id,
139 &self.organization_id,
140 &doc.source,
141 &doc.content,
142 &literal,
143 &metadata,
144 &acl,
145 ],
146 )
147 .await?;
148 Ok(())
149 }
150
151 async fn query_async(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
152 let embeddings = self
153 .embedder
154 .embed(&[query.to_string()], InputType::Query)
155 .await?;
156 let embedding = embeddings
157 .into_iter()
158 .next()
159 .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
160 let literal = Self::vector_literal(&embedding);
161
162 // Pull a generous candidate pool from each arm so RRF has something to
163 // fuse, then truncate after fusion.
164 let candidate_n: i64 = i64::try_from((limit * 4).max(20)).unwrap_or(20);
165 let client = self.pool.get().await?;
166
167 // --- ACL filter (feature gap G3) ---
168 //
169 // When this handle is access-bound, every row must pass the requester's
170 // document-level entitlement **in SQL** — a restricted document is never
171 // even fetched. A row is visible when ANY holds:
172 // - `acl IS NULL` → no ACL recorded ⇒ org-public default
173 // - `acl->>'public'` is true → explicitly public
174 // - requester user id ∈ acl.users (jsonb `?` key-exists)
175 // - any requester group ∈ acl.groups (jsonb `?|` any-key-exists)
176 // `$4` is the requester user id (text, NULL ⇒ anonymous), `$5` the
177 // requester groups (text[]). Both are bound below. When the handle is
178 // NOT access-bound (`access` is None) the predicate is `TRUE` — org
179 // isolation only, no within-org filtering.
180 // Build the ACL predicate + the extra bound params ONLY when this handle
181 // is access-bound. Postgres rejects a prepared statement that binds a
182 // parameter the SQL never references, so the raw (org-isolation-only)
183 // path must not add `$4`/`$5`.
184 let acl_user: Option<String> = self.access.as_ref().and_then(|c| c.user_id.clone());
185 let acl_groups: Vec<String> = self
186 .access
187 .as_ref()
188 .map(|c| c.groups.clone())
189 .unwrap_or_default();
190 let acl_predicate = if self.access.is_some() {
191 // A row is visible when it has no recorded ACL (org-public), is
192 // explicitly public, names the requester's user id, or names any of
193 // the requester's groups. `?` / `?|` are jsonb key-exists operators.
194 "(acl IS NULL \
195 OR (acl->>'public')::boolean IS TRUE \
196 OR ($4::text IS NOT NULL AND acl->'users' ? $4) \
197 OR (acl->'groups' ?| $5::text[]))"
198 } else {
199 "TRUE"
200 };
201 // The query text as an owned param so the borrowed trait objects below
202 // don't tie the param vec to the input `&str` lifetime.
203 let query_owned = query.to_string();
204
205 // --- dense arm: cosine distance via pgvector `<=>` ---
206 let dense_sql = format!(
207 "SELECT id, document_id, source, content
208 FROM knowledge_vectors
209 WHERE ($1::text IS NULL OR organization_id = $1)
210 AND {acl_predicate}
211 ORDER BY embedding <=> $2::text::vector
212 LIMIT $3"
213 );
214 let mut dense_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
215 vec![&self.organization_id, &literal, &candidate_n];
216 if self.access.is_some() {
217 dense_params.push(&acl_user);
218 dense_params.push(&acl_groups);
219 }
220 let dense_rows = client.query(&dense_sql, &dense_params).await?;
221
222 // --- sparse arm: tsvector BM25-style match, ranked by ts_rank ---
223 let sparse_sql = format!(
224 "SELECT id, document_id, source, content
225 FROM knowledge_vectors
226 WHERE ($1::text IS NULL OR organization_id = $1)
227 AND content_tsv @@ plainto_tsquery('english', $2)
228 AND {acl_predicate}
229 ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $2)) DESC
230 LIMIT $3"
231 );
232 let mut sparse_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
233 vec![&self.organization_id, &query_owned, &candidate_n];
234 if self.access.is_some() {
235 sparse_params.push(&acl_user);
236 sparse_params.push(&acl_groups);
237 }
238 let sparse_rows = client.query(&sparse_sql, &sparse_params).await?;
239
240 // --- Reciprocal Rank Fusion ---
241 struct Hit {
242 document_id: String,
243 source: String,
244 content: String,
245 score: f32,
246 }
247 let mut fused: HashMap<String, Hit> = HashMap::new();
248
249 let mut fuse = |rows: &[tokio_postgres::Row]| {
250 for (rank, row) in rows.iter().enumerate() {
251 let id: String = row.get(0);
252 let document_id: String = row.get(1);
253 let source: String = row.get(2);
254 let content: String = row.get(3);
255 #[allow(clippy::cast_precision_loss)]
256 let contribution = 1.0 / (RRF_K + (rank as f32) + 1.0);
257 fused
258 .entry(id)
259 .and_modify(|h| h.score += contribution)
260 .or_insert(Hit {
261 document_id,
262 source,
263 content,
264 score: contribution,
265 });
266 }
267 };
268 fuse(&dense_rows);
269 fuse(&sparse_rows);
270
271 let mut hits: Vec<Hit> = fused.into_values().collect();
272 hits.sort_by(|a, b| {
273 b.score
274 .partial_cmp(&a.score)
275 .unwrap_or(std::cmp::Ordering::Equal)
276 });
277 hits.truncate(limit);
278
279 Ok(hits
280 .into_iter()
281 .map(|h| KnowledgeResult {
282 document_id: h.document_id,
283 chunk: h.content,
284 score: h.score,
285 source: h.source,
286 })
287 .collect())
288 }
289}
290
291impl PgKnowledgeBase {
292 /// Drive an async future to completion from a *synchronous* trait method.
293 ///
294 /// `KnowledgeBase` is sync, but our work (embedding + deadpool) is async.
295 /// `Handle::block_on` can't be called from a runtime worker thread (it panics
296 /// "Cannot start a runtime from within a runtime"), and `block_in_place` only
297 /// relieves the *blocking-budget* concern, not that one. So we `spawn` the
298 /// future onto the runtime (where it can make progress) and block the calling
299 /// thread on a oneshot channel — wrapped in `block_in_place` when we happen to
300 /// be on a multi-thread worker so we don't starve the scheduler.
301 fn run_blocking<F, T>(&self, fut: F) -> Result<T>
302 where
303 F: std::future::Future<Output = Result<T>> + Send + 'static,
304 T: Send + 'static,
305 {
306 // Spawn the real work onto the captured runtime so its async I/O
307 // (deadpool, embedding HTTP) makes progress on that runtime's reactor.
308 let join = self.handle.spawn(fut);
309
310 // Block on the JoinHandle from a throwaway OS thread that owns a tiny
311 // current-thread runtime. This never calls `Handle::block_on` on a worker
312 // thread (which panics "Cannot start a runtime from within a runtime"),
313 // so it's safe whether the caller is on a runtime worker or not.
314 let (tx, rx) = std::sync::mpsc::channel();
315 std::thread::spawn(move || {
316 let result = (|| -> Result<T> {
317 let rt = tokio::runtime::Builder::new_current_thread()
318 .enable_all()
319 .build()?;
320 let joined = rt.block_on(join);
321 joined.map_err(|e| anyhow!("knowledge task panicked or was cancelled: {e}"))?
322 })();
323 let _ = tx.send(result);
324 });
325 rx.recv()
326 .map_err(|e| anyhow!("knowledge task channel closed: {e}"))?
327 }
328}
329
330impl KnowledgeBase for PgKnowledgeBase {
331 fn ingest(&self, doc: Document) -> Result<()> {
332 let this = self.clone();
333 self.run_blocking(async move { this.ingest_async(doc).await })
334 }
335
336 fn query(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
337 let this = self.clone();
338 let query = query.to_string();
339 self.run_blocking(async move { this.query_async(&query, limit).await })
340 }
341}