Skip to main content

smooth_operator_adapter_postgres/
knowledge.rs

1//! pgvector-backed [`KnowledgeBase`] with hybrid dense + sparse retrieval.
2//!
3//! smooth-operator's [`KnowledgeBase`](smooth_operator_core::KnowledgeBase) trait is
4//! **synchronous** (the engine calls `ingest`/`query` directly), but both
5//! embedding and Postgres access are async here. We bridge the two by `spawn`ing
6//! the async work onto the captured runtime [`Handle`] (so its I/O makes
7//! progress on that runtime's reactor) and blocking the calling thread on the
8//! task's `JoinHandle` from a throwaway OS thread — never calling
9//! `Handle::block_on` on a runtime worker thread (which panics "Cannot start a
10//! runtime from within a runtime"). See [`PgKnowledgeBase::run_blocking`].
11//!
12//! ## Retrieval
13//!
14//! 1. **Dense**: embed the query, rank rows by pgvector cosine distance
15//!    (`embedding <=> $query`), take the top-K.
16//! 2. **Sparse**: `content_tsv @@ plainto_tsquery('english', $query)`, ranked by
17//!    `ts_rank`, top-K.
18//! 3. **Fuse**: Reciprocal Rank Fusion (RRF) over the two ranked lists —
19//!    `score = Σ 1/(k + rank)` (k=60) — then return the top-K fused chunks.
20//!
21//! This mirrors smooai's `knowledge_vectors` retrieval (dense HNSW ∪ sparse BM25
22//! → RRF).
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use anyhow::{anyhow, Result};
28use deadpool_postgres::Pool;
29use tokio::runtime::Handle;
30
31use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
32
33use smooth_operator::access_control::{AccessContext, DocAcl};
34use smooth_operator::embedding::{Embedder, InputType};
35
/// Smoothing constant `k` in the Reciprocal Rank Fusion score `1 / (k + rank)`.
///
/// 60 is the value used in the original RRF paper. A larger `k` flattens the
/// gap between top-ranked and lower-ranked items without discarding the latter.
const RRF_K: f32 = 60.;
39
40/// pgvector knowledge base. Cheap to clone (all fields are `Arc`/pool handles).
41#[derive(Clone)]
42pub struct PgKnowledgeBase {
43    pool: Pool,
44    embedder: Arc<dyn Embedder>,
45    handle: Handle,
46    /// Optional org scoping. When set, ingest stamps and query filters on it.
47    organization_id: Option<String>,
48    /// Optional document-level access control (feature gap G3). When set, every
49    /// query filters rows by this requester's entitlements against the stored
50    /// `acl` column **in SQL** (so a restricted document is never even fetched).
51    /// `None` ⇒ no within-org ACL filtering (org isolation only) — the handle
52    /// returned by [`StorageAdapter::knowledge`]. The chat path uses
53    /// [`StorageAdapter::knowledge_for_access`], which sets this.
54    access: Option<AccessContext>,
55}
56
57impl PgKnowledgeBase {
58    pub(crate) fn new(
59        pool: Pool,
60        embedder: Arc<dyn Embedder>,
61        handle: Handle,
62        organization_id: Option<String>,
63    ) -> Self {
64        Self {
65            pool,
66            embedder,
67            handle,
68            organization_id,
69            access: None,
70        }
71    }
72
73    /// A clone of this knowledge base whose queries enforce the given
74    /// [`AccessContext`]'s document-level entitlements (in SQL, against the
75    /// stored `acl` column). Used by
76    /// [`PostgresAdapter::knowledge_for_access`](crate::PostgresAdapter) on the
77    /// chat retrieval path.
78    ///
79    /// When the context carries an [`organization_id`](AccessContext::organization_id)
80    /// (a multi-tenant host threads the turn's org through), it **overrides** the
81    /// adapter's construction-time org for this query — so one adapter instance can
82    /// serve per-turn tenants instead of being pinned to a single static org. The
83    /// org is still a cheap SQL pre-filter (`organization_id = $1`). A context with
84    /// no org leaves the construction-time org unchanged, so the single-tenant
85    /// default is behavior-preserving.
86    #[must_use]
87    pub fn with_access(&self, access: AccessContext) -> Self {
88        let organization_id = access
89            .organization_id
90            .clone()
91            .or_else(|| self.organization_id.clone());
92        Self {
93            organization_id,
94            access: Some(access),
95            ..self.clone()
96        }
97    }
98
99    /// Format a vector as a pgvector literal: `[0.1,0.2,...]`.
100    fn vector_literal(v: &[f32]) -> String {
101        let mut s = String::with_capacity(v.len() * 8 + 2);
102        s.push('[');
103        for (i, x) in v.iter().enumerate() {
104            if i > 0 {
105                s.push(',');
106            }
107            s.push_str(&x.to_string());
108        }
109        s.push(']');
110        s
111    }
112
113    async fn ingest_async(&self, doc: Document) -> Result<()> {
114        let embeddings = self
115            .embedder
116            .embed(std::slice::from_ref(&doc.content), InputType::Document)
117            .await?;
118        let embedding = embeddings
119            .into_iter()
120            .next()
121            .ok_or_else(|| anyhow!("embedder returned no vector"))?;
122        let literal = Self::vector_literal(&embedding);
123        let metadata = serde_json::to_value(&doc.metadata)?;
124        // Persist the document's ACL (feature gap G3) as a discrete column so it
125        // survives the ingest→serve process boundary and can be filtered in SQL
126        // at read. Parsed from the same `acl_v2` metadata key the in-memory
127        // store records. `None` ⇒ NULL ⇒ org-public (backward-compatible).
128        let acl: Option<serde_json::Value> = DocAcl::from_metadata(&doc.metadata)
129            .map(|a| serde_json::to_value(&a))
130            .transpose()?;
131        // Stable per-chunk id: the document is stored as a single chunk keyed by
132        // its document id, so re-ingesting the same doc upserts in place.
133        let row_id = doc.id.clone();
134
135        let client = self.pool.get().await?;
136        client
137            .execute(
138                "INSERT INTO knowledge_vectors
139                    (id, document_id, organization_id, source, content, embedding, metadata, acl)
140                 VALUES ($1, $2, $3, $4, $5, $6::text::vector, $7, $8)
141                 ON CONFLICT (id) DO UPDATE SET
142                    document_id     = EXCLUDED.document_id,
143                    organization_id = EXCLUDED.organization_id,
144                    source          = EXCLUDED.source,
145                    content         = EXCLUDED.content,
146                    embedding       = EXCLUDED.embedding,
147                    metadata        = EXCLUDED.metadata,
148                    acl             = EXCLUDED.acl",
149                &[
150                    &row_id,
151                    &doc.id,
152                    &self.organization_id,
153                    &doc.source,
154                    &doc.content,
155                    &literal,
156                    &metadata,
157                    &acl,
158                ],
159            )
160            .await?;
161        Ok(())
162    }
163
164    async fn query_async(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
165        let embeddings = self
166            .embedder
167            .embed(&[query.to_string()], InputType::Query)
168            .await?;
169        let embedding = embeddings
170            .into_iter()
171            .next()
172            .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
173        let literal = Self::vector_literal(&embedding);
174
175        // Pull a generous candidate pool from each arm so RRF has something to
176        // fuse, then truncate after fusion.
177        let candidate_n: i64 = i64::try_from((limit * 4).max(20)).unwrap_or(20);
178        let client = self.pool.get().await?;
179
180        // --- ACL filter (feature gap G3) ---
181        //
182        // When this handle is access-bound, every row must pass the requester's
183        // document-level entitlement **in SQL** — a restricted document is never
184        // even fetched. A row is visible when ANY holds:
185        //   - `acl IS NULL`              → no ACL recorded ⇒ org-public default
186        //   - `acl->>'public'` is true   → explicitly public
187        //   - requester user id ∈ acl.users   (jsonb `?` key-exists)
188        //   - any requester group ∈ acl.groups (jsonb `?|` any-key-exists)
189        // `$4` is the requester user id (text, NULL ⇒ anonymous), `$5` the
190        // requester groups (text[]). Both are bound below. When the handle is
191        // NOT access-bound (`access` is None) the predicate is `TRUE` — org
192        // isolation only, no within-org filtering.
193        // Build the ACL predicate + the extra bound params ONLY when this handle
194        // is access-bound. Postgres rejects a prepared statement that binds a
195        // parameter the SQL never references, so the raw (org-isolation-only)
196        // path must not add `$4`/`$5`.
197        let acl_user: Option<String> = self.access.as_ref().and_then(|c| c.user_id.clone());
198        let acl_groups: Vec<String> = self
199            .access
200            .as_ref()
201            .map(|c| c.groups.clone())
202            .unwrap_or_default();
203        let acl_predicate = if self.access.is_some() {
204            // A row is visible when it has no recorded ACL (org-public), is
205            // explicitly public, names the requester's user id, or names any of
206            // the requester's groups. `?` / `?|` are jsonb key-exists operators.
207            "(acl IS NULL \
208              OR (acl->>'public')::boolean IS TRUE \
209              OR ($4::text IS NOT NULL AND acl->'users' ? $4) \
210              OR (acl->'groups' ?| $5::text[]))"
211        } else {
212            "TRUE"
213        };
214        // The query text as an owned param so the borrowed trait objects below
215        // don't tie the param vec to the input `&str` lifetime.
216        let query_owned = query.to_string();
217
218        // --- dense arm: cosine distance via pgvector `<=>` ---
219        let dense_sql = format!(
220            "SELECT id, document_id, source, content
221             FROM knowledge_vectors
222             WHERE ($1::text IS NULL OR organization_id = $1)
223               AND {acl_predicate}
224             ORDER BY embedding <=> $2::text::vector
225             LIMIT $3"
226        );
227        let mut dense_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
228            vec![&self.organization_id, &literal, &candidate_n];
229        if self.access.is_some() {
230            dense_params.push(&acl_user);
231            dense_params.push(&acl_groups);
232        }
233        let dense_rows = client.query(&dense_sql, &dense_params).await?;
234
235        // --- sparse arm: tsvector BM25-style match, ranked by ts_rank ---
236        let sparse_sql = format!(
237            "SELECT id, document_id, source, content
238             FROM knowledge_vectors
239             WHERE ($1::text IS NULL OR organization_id = $1)
240               AND content_tsv @@ plainto_tsquery('english', $2)
241               AND {acl_predicate}
242             ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $2)) DESC
243             LIMIT $3"
244        );
245        let mut sparse_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
246            vec![&self.organization_id, &query_owned, &candidate_n];
247        if self.access.is_some() {
248            sparse_params.push(&acl_user);
249            sparse_params.push(&acl_groups);
250        }
251        let sparse_rows = client.query(&sparse_sql, &sparse_params).await?;
252
253        // --- Reciprocal Rank Fusion ---
254        struct Hit {
255            document_id: String,
256            source: String,
257            content: String,
258            score: f32,
259        }
260        let mut fused: HashMap<String, Hit> = HashMap::new();
261
262        let mut fuse = |rows: &[tokio_postgres::Row]| {
263            for (rank, row) in rows.iter().enumerate() {
264                let id: String = row.get(0);
265                let document_id: String = row.get(1);
266                let source: String = row.get(2);
267                let content: String = row.get(3);
268                #[allow(clippy::cast_precision_loss)]
269                let contribution = 1.0 / (RRF_K + (rank as f32) + 1.0);
270                fused
271                    .entry(id)
272                    .and_modify(|h| h.score += contribution)
273                    .or_insert(Hit {
274                        document_id,
275                        source,
276                        content,
277                        score: contribution,
278                    });
279            }
280        };
281        fuse(&dense_rows);
282        fuse(&sparse_rows);
283
284        let mut hits: Vec<Hit> = fused.into_values().collect();
285        hits.sort_by(|a, b| {
286            b.score
287                .partial_cmp(&a.score)
288                .unwrap_or(std::cmp::Ordering::Equal)
289        });
290        hits.truncate(limit);
291
292        Ok(hits
293            .into_iter()
294            .map(|h| KnowledgeResult {
295                document_id: h.document_id,
296                chunk: h.content,
297                score: h.score,
298                source: h.source,
299            })
300            .collect())
301    }
302}
303
304impl PgKnowledgeBase {
305    /// Drive an async future to completion from a *synchronous* trait method.
306    ///
307    /// `KnowledgeBase` is sync, but our work (embedding + deadpool) is async.
308    /// `Handle::block_on` can't be called from a runtime worker thread (it panics
309    /// "Cannot start a runtime from within a runtime"), and `block_in_place` only
310    /// relieves the *blocking-budget* concern, not that one. So we `spawn` the
311    /// future onto the runtime (where it can make progress) and block the calling
312    /// thread on a oneshot channel — wrapped in `block_in_place` when we happen to
313    /// be on a multi-thread worker so we don't starve the scheduler.
314    fn run_blocking<F, T>(&self, fut: F) -> Result<T>
315    where
316        F: std::future::Future<Output = Result<T>> + Send + 'static,
317        T: Send + 'static,
318    {
319        // Spawn the real work onto the captured runtime so its async I/O
320        // (deadpool, embedding HTTP) makes progress on that runtime's reactor.
321        let join = self.handle.spawn(fut);
322
323        // Block on the JoinHandle from a throwaway OS thread that owns a tiny
324        // current-thread runtime. This never calls `Handle::block_on` on a worker
325        // thread (which panics "Cannot start a runtime from within a runtime"),
326        // so it's safe whether the caller is on a runtime worker or not.
327        let (tx, rx) = std::sync::mpsc::channel();
328        std::thread::spawn(move || {
329            let result = (|| -> Result<T> {
330                let rt = tokio::runtime::Builder::new_current_thread()
331                    .enable_all()
332                    .build()?;
333                let joined = rt.block_on(join);
334                joined.map_err(|e| anyhow!("knowledge task panicked or was cancelled: {e}"))?
335            })();
336            let _ = tx.send(result);
337        });
338        rx.recv()
339            .map_err(|e| anyhow!("knowledge task channel closed: {e}"))?
340    }
341}
342
343impl KnowledgeBase for PgKnowledgeBase {
344    fn ingest(&self, doc: Document) -> Result<()> {
345        let this = self.clone();
346        self.run_blocking(async move { this.ingest_async(doc).await })
347    }
348
349    fn query(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
350        let this = self.clone();
351        let query = query.to_string();
352        self.run_blocking(async move { this.query_async(&query, limit).await })
353    }
354}