Skip to main content

smooth_operator_adapter_postgres/
knowledge.rs

1//! pgvector-backed [`KnowledgeBase`] with hybrid dense + sparse retrieval.
2//!
3//! smooth-operator's [`KnowledgeBase`](smooth_operator_core::KnowledgeBase) trait is
4//! **synchronous** (the engine calls `ingest`/`query` directly), but both
//! embedding and Postgres access are async here. We bridge the two by `spawn`ing
//! the async work onto the captured runtime [`Handle`] (so its I/O makes
//! progress on that runtime's reactor) and blocking the calling thread until
//! that task reports its result — never calling `Handle::block_on` on a runtime
//! worker thread (which panics "Cannot start a runtime from within a runtime").
//! See `PgKnowledgeBase::run_blocking` for the details.
11//!
12//! ## Retrieval
13//!
14//! 1. **Dense**: embed the query, rank rows by pgvector cosine distance
15//!    (`embedding <=> $query`), take the top-K.
16//! 2. **Sparse**: `content_tsv @@ plainto_tsquery('english', $query)`, ranked by
17//!    `ts_rank`, top-K.
18//! 3. **Fuse**: Reciprocal Rank Fusion (RRF) over the two ranked lists —
19//!    `score = Σ 1/(k + rank)` (k=60) — then return the top-K fused chunks.
20//!
21//! This mirrors smooai's `knowledge_vectors` retrieval (dense HNSW ∪ sparse BM25
22//! → RRF).
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use anyhow::{anyhow, Result};
28use deadpool_postgres::Pool;
29use tokio::runtime::Handle;
30
31use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
32
33use smooth_operator::access_control::{AccessContext, DocAcl};
34use smooth_operator::embedding::{Embedder, InputType};
35
/// Smoothing constant `k` in the Reciprocal Rank Fusion term `1 / (k + rank)`.
///
/// 60 is the value recommended by the original RRF paper. Larger values shrink
/// the gap between top- and lower-ranked items, so deep hits still count a bit.
const RRF_K: f32 = 60.0;
39
40/// pgvector knowledge base. Cheap to clone (all fields are `Arc`/pool handles).
41#[derive(Clone)]
42pub struct PgKnowledgeBase {
43    pool: Pool,
44    embedder: Arc<dyn Embedder>,
45    handle: Handle,
46    /// Optional org scoping. When set, ingest stamps and query filters on it.
47    organization_id: Option<String>,
48    /// Optional document-level access control (feature gap G3). When set, every
49    /// query filters rows by this requester's entitlements against the stored
50    /// `acl` column **in SQL** (so a restricted document is never even fetched).
51    /// `None` ⇒ no within-org ACL filtering (org isolation only) — the handle
52    /// returned by [`StorageAdapter::knowledge`]. The chat path uses
53    /// [`StorageAdapter::knowledge_for_access`], which sets this.
54    access: Option<AccessContext>,
55}
56
57impl PgKnowledgeBase {
58    pub(crate) fn new(
59        pool: Pool,
60        embedder: Arc<dyn Embedder>,
61        handle: Handle,
62        organization_id: Option<String>,
63    ) -> Self {
64        Self {
65            pool,
66            embedder,
67            handle,
68            organization_id,
69            access: None,
70        }
71    }
72
73    /// A clone of this knowledge base whose queries enforce the given
74    /// [`AccessContext`]'s document-level entitlements (in SQL, against the
75    /// stored `acl` column). Used by
76    /// [`PostgresAdapter::knowledge_for_access`](crate::PostgresAdapter) on the
77    /// chat retrieval path.
78    #[must_use]
79    pub fn with_access(&self, access: AccessContext) -> Self {
80        Self {
81            access: Some(access),
82            ..self.clone()
83        }
84    }
85
86    /// Format a vector as a pgvector literal: `[0.1,0.2,...]`.
87    fn vector_literal(v: &[f32]) -> String {
88        let mut s = String::with_capacity(v.len() * 8 + 2);
89        s.push('[');
90        for (i, x) in v.iter().enumerate() {
91            if i > 0 {
92                s.push(',');
93            }
94            s.push_str(&x.to_string());
95        }
96        s.push(']');
97        s
98    }
99
100    async fn ingest_async(&self, doc: Document) -> Result<()> {
101        let embeddings = self
102            .embedder
103            .embed(std::slice::from_ref(&doc.content), InputType::Document)
104            .await?;
105        let embedding = embeddings
106            .into_iter()
107            .next()
108            .ok_or_else(|| anyhow!("embedder returned no vector"))?;
109        let literal = Self::vector_literal(&embedding);
110        let metadata = serde_json::to_value(&doc.metadata)?;
111        // Persist the document's ACL (feature gap G3) as a discrete column so it
112        // survives the ingest→serve process boundary and can be filtered in SQL
113        // at read. Parsed from the same `acl_v2` metadata key the in-memory
114        // store records. `None` ⇒ NULL ⇒ org-public (backward-compatible).
115        let acl: Option<serde_json::Value> = DocAcl::from_metadata(&doc.metadata)
116            .map(|a| serde_json::to_value(&a))
117            .transpose()?;
118        // Stable per-chunk id: the document is stored as a single chunk keyed by
119        // its document id, so re-ingesting the same doc upserts in place.
120        let row_id = doc.id.clone();
121
122        let client = self.pool.get().await?;
123        client
124            .execute(
125                "INSERT INTO knowledge_vectors
126                    (id, document_id, organization_id, source, content, embedding, metadata, acl)
127                 VALUES ($1, $2, $3, $4, $5, $6::text::vector, $7, $8)
128                 ON CONFLICT (id) DO UPDATE SET
129                    document_id     = EXCLUDED.document_id,
130                    organization_id = EXCLUDED.organization_id,
131                    source          = EXCLUDED.source,
132                    content         = EXCLUDED.content,
133                    embedding       = EXCLUDED.embedding,
134                    metadata        = EXCLUDED.metadata,
135                    acl             = EXCLUDED.acl",
136                &[
137                    &row_id,
138                    &doc.id,
139                    &self.organization_id,
140                    &doc.source,
141                    &doc.content,
142                    &literal,
143                    &metadata,
144                    &acl,
145                ],
146            )
147            .await?;
148        Ok(())
149    }
150
151    async fn query_async(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
152        let embeddings = self
153            .embedder
154            .embed(&[query.to_string()], InputType::Query)
155            .await?;
156        let embedding = embeddings
157            .into_iter()
158            .next()
159            .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
160        let literal = Self::vector_literal(&embedding);
161
162        // Pull a generous candidate pool from each arm so RRF has something to
163        // fuse, then truncate after fusion.
164        let candidate_n: i64 = i64::try_from((limit * 4).max(20)).unwrap_or(20);
165        let client = self.pool.get().await?;
166
167        // --- ACL filter (feature gap G3) ---
168        //
169        // When this handle is access-bound, every row must pass the requester's
170        // document-level entitlement **in SQL** — a restricted document is never
171        // even fetched. A row is visible when ANY holds:
172        //   - `acl IS NULL`              → no ACL recorded ⇒ org-public default
173        //   - `acl->>'public'` is true   → explicitly public
174        //   - requester user id ∈ acl.users   (jsonb `?` key-exists)
175        //   - any requester group ∈ acl.groups (jsonb `?|` any-key-exists)
176        // `$4` is the requester user id (text, NULL ⇒ anonymous), `$5` the
177        // requester groups (text[]). Both are bound below. When the handle is
178        // NOT access-bound (`access` is None) the predicate is `TRUE` — org
179        // isolation only, no within-org filtering.
180        // Build the ACL predicate + the extra bound params ONLY when this handle
181        // is access-bound. Postgres rejects a prepared statement that binds a
182        // parameter the SQL never references, so the raw (org-isolation-only)
183        // path must not add `$4`/`$5`.
184        let acl_user: Option<String> = self.access.as_ref().and_then(|c| c.user_id.clone());
185        let acl_groups: Vec<String> = self
186            .access
187            .as_ref()
188            .map(|c| c.groups.clone())
189            .unwrap_or_default();
190        let acl_predicate = if self.access.is_some() {
191            // A row is visible when it has no recorded ACL (org-public), is
192            // explicitly public, names the requester's user id, or names any of
193            // the requester's groups. `?` / `?|` are jsonb key-exists operators.
194            "(acl IS NULL \
195              OR (acl->>'public')::boolean IS TRUE \
196              OR ($4::text IS NOT NULL AND acl->'users' ? $4) \
197              OR (acl->'groups' ?| $5::text[]))"
198        } else {
199            "TRUE"
200        };
201        // The query text as an owned param so the borrowed trait objects below
202        // don't tie the param vec to the input `&str` lifetime.
203        let query_owned = query.to_string();
204
205        // --- dense arm: cosine distance via pgvector `<=>` ---
206        let dense_sql = format!(
207            "SELECT id, document_id, source, content
208             FROM knowledge_vectors
209             WHERE ($1::text IS NULL OR organization_id = $1)
210               AND {acl_predicate}
211             ORDER BY embedding <=> $2::text::vector
212             LIMIT $3"
213        );
214        let mut dense_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
215            vec![&self.organization_id, &literal, &candidate_n];
216        if self.access.is_some() {
217            dense_params.push(&acl_user);
218            dense_params.push(&acl_groups);
219        }
220        let dense_rows = client.query(&dense_sql, &dense_params).await?;
221
222        // --- sparse arm: tsvector BM25-style match, ranked by ts_rank ---
223        let sparse_sql = format!(
224            "SELECT id, document_id, source, content
225             FROM knowledge_vectors
226             WHERE ($1::text IS NULL OR organization_id = $1)
227               AND content_tsv @@ plainto_tsquery('english', $2)
228               AND {acl_predicate}
229             ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $2)) DESC
230             LIMIT $3"
231        );
232        let mut sparse_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
233            vec![&self.organization_id, &query_owned, &candidate_n];
234        if self.access.is_some() {
235            sparse_params.push(&acl_user);
236            sparse_params.push(&acl_groups);
237        }
238        let sparse_rows = client.query(&sparse_sql, &sparse_params).await?;
239
240        // --- Reciprocal Rank Fusion ---
241        struct Hit {
242            document_id: String,
243            source: String,
244            content: String,
245            score: f32,
246        }
247        let mut fused: HashMap<String, Hit> = HashMap::new();
248
249        let mut fuse = |rows: &[tokio_postgres::Row]| {
250            for (rank, row) in rows.iter().enumerate() {
251                let id: String = row.get(0);
252                let document_id: String = row.get(1);
253                let source: String = row.get(2);
254                let content: String = row.get(3);
255                #[allow(clippy::cast_precision_loss)]
256                let contribution = 1.0 / (RRF_K + (rank as f32) + 1.0);
257                fused
258                    .entry(id)
259                    .and_modify(|h| h.score += contribution)
260                    .or_insert(Hit {
261                        document_id,
262                        source,
263                        content,
264                        score: contribution,
265                    });
266            }
267        };
268        fuse(&dense_rows);
269        fuse(&sparse_rows);
270
271        let mut hits: Vec<Hit> = fused.into_values().collect();
272        hits.sort_by(|a, b| {
273            b.score
274                .partial_cmp(&a.score)
275                .unwrap_or(std::cmp::Ordering::Equal)
276        });
277        hits.truncate(limit);
278
279        Ok(hits
280            .into_iter()
281            .map(|h| KnowledgeResult {
282                document_id: h.document_id,
283                chunk: h.content,
284                score: h.score,
285                source: h.source,
286            })
287            .collect())
288    }
289}
290
291impl PgKnowledgeBase {
292    /// Drive an async future to completion from a *synchronous* trait method.
293    ///
294    /// `KnowledgeBase` is sync, but our work (embedding + deadpool) is async.
295    /// `Handle::block_on` can't be called from a runtime worker thread (it panics
296    /// "Cannot start a runtime from within a runtime"), and `block_in_place` only
297    /// relieves the *blocking-budget* concern, not that one. So we `spawn` the
298    /// future onto the runtime (where it can make progress) and block the calling
299    /// thread on a oneshot channel — wrapped in `block_in_place` when we happen to
300    /// be on a multi-thread worker so we don't starve the scheduler.
301    fn run_blocking<F, T>(&self, fut: F) -> Result<T>
302    where
303        F: std::future::Future<Output = Result<T>> + Send + 'static,
304        T: Send + 'static,
305    {
306        // Spawn the real work onto the captured runtime so its async I/O
307        // (deadpool, embedding HTTP) makes progress on that runtime's reactor.
308        let join = self.handle.spawn(fut);
309
310        // Block on the JoinHandle from a throwaway OS thread that owns a tiny
311        // current-thread runtime. This never calls `Handle::block_on` on a worker
312        // thread (which panics "Cannot start a runtime from within a runtime"),
313        // so it's safe whether the caller is on a runtime worker or not.
314        let (tx, rx) = std::sync::mpsc::channel();
315        std::thread::spawn(move || {
316            let result = (|| -> Result<T> {
317                let rt = tokio::runtime::Builder::new_current_thread()
318                    .enable_all()
319                    .build()?;
320                let joined = rt.block_on(join);
321                joined.map_err(|e| anyhow!("knowledge task panicked or was cancelled: {e}"))?
322            })();
323            let _ = tx.send(result);
324        });
325        rx.recv()
326            .map_err(|e| anyhow!("knowledge task channel closed: {e}"))?
327    }
328}
329
330impl KnowledgeBase for PgKnowledgeBase {
331    fn ingest(&self, doc: Document) -> Result<()> {
332        let this = self.clone();
333        self.run_blocking(async move { this.ingest_async(doc).await })
334    }
335
336    fn query(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
337        let this = self.clone();
338        let query = query.to_string();
339        self.run_blocking(async move { this.query_async(&query, limit).await })
340    }
341}