Skip to main content

tinyquant_pgvector/
adapter.rs

//! `PgvectorAdapter` — `SearchBackend` backed by a `PostgreSQL` `pgvector` table.
//!
//! # Feature flags
//!
//! The adapter is always present in the public API but only performs actual
//! database operations when the `live-db` feature is enabled.  Without
//! `live-db`, every fallible method that needs a database connection returns
//! a descriptive `BackendError::Adapter` indicating that the feature is
//! required, and `len` reports `0`.
//!
//! Integration tests requiring a live server are additionally gated behind
//! the `test-containers` feature.
12
13use std::sync::Arc;
14
15use tinyquant_core::backend::{SearchBackend, SearchResult};
16use tinyquant_core::types::VectorId;
17
18use crate::errors::{adapter_err, BackendError};
19use crate::sql::validate_table_name;
20use crate::wire::encode_vector;
21
/// A search backend that stores and queries vectors in a pgvector table.
///
/// The connection factory, and with it any ability to reach the database,
/// exists only when the crate is built with the `live-db` feature.  Without
/// that feature, operations that need a connection fail with the error
/// produced by `adapter_err("live-db feature required to connect to PostgreSQL")`.
pub struct PgvectorAdapter {
    /// Table name.  The constructor runs it through `validate_table_name`
    /// before storing it, because it is interpolated directly into SQL text.
    table: String,
    /// Dimension of the vector column, when known.
    ///
    /// `Some` after the `live-db` constructor (which takes the dimension as
    /// an argument) or after a successful `ensure_schema`; `None` otherwise.
    dim: Option<usize>,
    /// Connection factory, only present under `live-db`.  It is invoked each
    /// time an operation needs a database client.
    #[cfg(feature = "live-db")]
    factory: Box<dyn Fn() -> Result<postgres::Client, postgres::Error> + Send + Sync>,
}
36
37impl PgvectorAdapter {
38    /// Create a new adapter for the given table.
39    ///
40    /// # Errors
41    ///
42    /// Returns `Err` if `table` fails the allowlist regex validation.
43    #[cfg(feature = "live-db")]
44    pub fn new(
45        factory: impl Fn() -> Result<postgres::Client, postgres::Error> + Send + Sync + 'static,
46        table: &str,
47        dim: u32,
48    ) -> Result<Self, BackendError> {
49        let table = table.to_string();
50        validate_table_name(&table)?;
51        let mut adapter = Self {
52            table,
53            dim: None,
54            factory: Box::new(factory),
55        };
56        adapter.dim = Some(dim as usize);
57        Ok(adapter)
58    }
59
60    /// Create a new adapter for the given table (stub without `live-db`).
61    ///
62    /// Only validates the table name; all operational methods will return an
63    /// error until the `live-db` feature is enabled.
64    ///
65    /// # Errors
66    ///
67    /// Returns `Err` if `table` fails the allowlist regex validation.
68    #[cfg(not(feature = "live-db"))]
69    pub fn new(table: impl Into<String>) -> Result<Self, BackendError> {
70        let table = table.into();
71        validate_table_name(&table)?;
72        Ok(Self { table, dim: None })
73    }
74
75    /// Create the `vector` extension and the vectors table if they do not
76    /// exist.
77    ///
78    /// # Errors
79    ///
80    /// Returns `Err` when the `live-db` feature is disabled, or on any
81    /// `postgres::Error`.
82    pub fn ensure_schema(&mut self, dim: usize) -> Result<(), BackendError> {
83        #[cfg(not(feature = "live-db"))]
84        {
85            let _ = dim;
86            return Err(adapter_err(
87                "live-db feature required to connect to PostgreSQL",
88            ));
89        }
90        #[cfg(feature = "live-db")]
91        {
92            use crate::errors::from_pg;
93            let mut client = (self.factory)().map_err(from_pg)?;
94            client
95                .batch_execute("CREATE EXTENSION IF NOT EXISTS vector;")
96                .map_err(from_pg)?;
97            let sql = format!(
98                "CREATE TABLE IF NOT EXISTS {} (id TEXT PRIMARY KEY, embedding vector({}));",
99                self.table, dim
100            );
101            client.batch_execute(&sql).map_err(from_pg)?;
102            self.dim = Some(dim);
103            Ok(())
104        }
105    }
106
107    /// Create an approximate nearest-neighbour index on the embedding column.
108    ///
109    /// `lists` controls the number of `IVFFlat` lists.  If `0`, defaults to 100.
110    ///
111    /// # Errors
112    ///
113    /// Returns `Err` when the `live-db` feature is disabled, or on any
114    /// `postgres::Error`.
115    pub fn ensure_index(&self, lists: u32) -> Result<(), BackendError> {
116        #[cfg(not(feature = "live-db"))]
117        {
118            let _ = lists;
119            return Err(adapter_err(
120                "live-db feature required to connect to PostgreSQL",
121            ));
122        }
123        #[cfg(feature = "live-db")]
124        {
125            use crate::errors::from_pg;
126            let effective_lists = if lists == 0 { 100 } else { lists };
127            let mut client = (self.factory)().map_err(from_pg)?;
128            let sql = format!(
129                "CREATE INDEX IF NOT EXISTS {table}_embedding_idx \
130                 ON {table} USING ivfflat (embedding vector_cosine_ops) \
131                 WITH (lists = {lists});",
132                table = self.table,
133                lists = effective_lists
134            );
135            client.batch_execute(&sql).map_err(from_pg)
136        }
137    }
138
139    /// The table name used by this adapter.
140    pub fn table(&self) -> &str {
141        &self.table
142    }
143}
144
145impl SearchBackend for PgvectorAdapter {
146    fn ingest(&mut self, vectors: &[(VectorId, Vec<f32>)]) -> Result<(), BackendError> {
147        if vectors.is_empty() {
148            return Ok(());
149        }
150        // Dimension-lock check and wire encode validation happen before
151        // any DB connection so we can test them without `live-db`.
152        if let Some(expected) = self.dim {
153            for (_, v) in vectors {
154                if v.len() != expected {
155                    return Err(BackendError::Adapter(Arc::from(format!(
156                        "dimension mismatch: expected {expected}, got {}",
157                        v.len()
158                    ))));
159                }
160            }
161        }
162        // Encode each vector — rejects NaN/Inf before any DB call.
163        let encoded: Vec<(&VectorId, String)> = vectors
164            .iter()
165            .map(|(id, v)| encode_vector(v).map(|enc| (id, enc)))
166            .collect::<Result<_, _>>()?;
167
168        #[cfg(not(feature = "live-db"))]
169        {
170            let _ = encoded;
171            return Err(adapter_err(
172                "live-db feature required to connect to PostgreSQL",
173            ));
174        }
175        #[cfg(feature = "live-db")]
176        {
177            use crate::errors::from_pg;
178            let mut client = (self.factory)().map_err(from_pg)?;
179            let sql = format!(
180                "INSERT INTO {table} (id, embedding) VALUES ($1, $2::vector) \
181                 ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;",
182                table = self.table
183            );
184            for (id, enc) in &encoded {
185                client
186                    .execute(sql.as_str(), &[&id.as_ref(), enc])
187                    .map_err(from_pg)?;
188            }
189            Ok(())
190        }
191    }
192
193    fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<SearchResult>, BackendError> {
194        if top_k == 0 {
195            return Err(BackendError::InvalidTopK);
196        }
197        let encoded = encode_vector(query)?;
198
199        #[cfg(not(feature = "live-db"))]
200        {
201            let _ = encoded;
202            return Err(adapter_err(
203                "live-db feature required to connect to PostgreSQL",
204            ));
205        }
206        #[cfg(feature = "live-db")]
207        {
208            use crate::errors::from_pg;
209            let sql = format!(
210                "SELECT id, 1 - (embedding <=> $1::vector) AS score \
211                 FROM {table} \
212                 ORDER BY embedding <=> $1::vector \
213                 LIMIT $2;",
214                table = self.table
215            );
216            let mut client = (self.factory)().map_err(from_pg)?;
217            let top_k_i64 = i64::try_from(top_k).unwrap_or(i64::MAX);
218            let rows = client
219                .query(sql.as_str(), &[&encoded, &top_k_i64])
220                .map_err(from_pg)?;
221            let mut results = Vec::with_capacity(rows.len());
222            for row in rows {
223                let id: String = row.get(0);
224                let score: f64 = row.get(1);
225                #[allow(clippy::cast_possible_truncation)]
226                results.push(SearchResult {
227                    vector_id: Arc::from(id.as_str()),
228                    score: score as f32,
229                });
230            }
231            Ok(results)
232        }
233    }
234
235    fn remove(&mut self, vector_ids: &[VectorId]) -> Result<(), BackendError> {
236        if vector_ids.is_empty() {
237            return Ok(());
238        }
239        #[cfg(not(feature = "live-db"))]
240        {
241            return Err(adapter_err(
242                "live-db feature required to connect to PostgreSQL",
243            ));
244        }
245        #[cfg(feature = "live-db")]
246        {
247            use crate::errors::from_pg;
248            let mut client = (self.factory)().map_err(from_pg)?;
249            let sql = format!("DELETE FROM {table} WHERE id = $1;", table = self.table);
250            for id in vector_ids {
251                client
252                    .execute(sql.as_str(), &[&id.as_ref()])
253                    .map_err(from_pg)?;
254            }
255            Ok(())
256        }
257    }
258
259    fn len(&self) -> usize {
260        #[cfg(not(feature = "live-db"))]
261        {
262            0
263        }
264        #[cfg(feature = "live-db")]
265        {
266            let Ok(mut client) = (self.factory)() else {
267                return 0;
268            };
269            let sql = format!("SELECT COUNT(*) FROM {};", self.table);
270            let Ok(row) = client.query_one(sql.as_str(), &[]) else {
271                return 0;
272            };
273            let count: i64 = row.get(0);
274            usize::try_from(count).unwrap_or(0)
275        }
276    }
277
278    fn dim(&self) -> Option<usize> {
279        self.dim
280    }
281}