tinyquant_pgvector/
adapter.rs1use std::sync::Arc;
14
15use tinyquant_core::backend::{SearchBackend, SearchResult};
16use tinyquant_core::types::VectorId;
17
18use crate::errors::{adapter_err, BackendError};
19use crate::sql::validate_table_name;
20use crate::wire::encode_vector;
21
/// Search backend adapter that keeps its vectors in a PostgreSQL table using
/// the pgvector extension.
///
/// Without the `live-db` feature the type carries no connection factory and
/// every database-touching operation reports an adapter error.
pub struct PgvectorAdapter {
    /// Name of the backing table; passed through `validate_table_name` by the
    /// constructors before it is ever interpolated into SQL.
    table: String,
    /// Expected vector dimension, when one has been recorded.
    dim: Option<usize>,
    /// Callback that yields a PostgreSQL client; it is invoked once by each
    /// operation that needs to talk to the database.
    #[cfg(feature = "live-db")]
    factory: Box<dyn Fn() -> Result<postgres::Client, postgres::Error> + Send + Sync>,
}
36
37impl PgvectorAdapter {
38 #[cfg(feature = "live-db")]
44 pub fn new(
45 factory: impl Fn() -> Result<postgres::Client, postgres::Error> + Send + Sync + 'static,
46 table: &str,
47 dim: u32,
48 ) -> Result<Self, BackendError> {
49 let table = table.to_string();
50 validate_table_name(&table)?;
51 let mut adapter = Self {
52 table,
53 dim: None,
54 factory: Box::new(factory),
55 };
56 adapter.dim = Some(dim as usize);
57 Ok(adapter)
58 }
59
60 #[cfg(not(feature = "live-db"))]
69 pub fn new(table: impl Into<String>) -> Result<Self, BackendError> {
70 let table = table.into();
71 validate_table_name(&table)?;
72 Ok(Self { table, dim: None })
73 }
74
75 pub fn ensure_schema(&mut self, dim: usize) -> Result<(), BackendError> {
83 #[cfg(not(feature = "live-db"))]
84 {
85 let _ = dim;
86 return Err(adapter_err(
87 "live-db feature required to connect to PostgreSQL",
88 ));
89 }
90 #[cfg(feature = "live-db")]
91 {
92 use crate::errors::from_pg;
93 let mut client = (self.factory)().map_err(from_pg)?;
94 client
95 .batch_execute("CREATE EXTENSION IF NOT EXISTS vector;")
96 .map_err(from_pg)?;
97 let sql = format!(
98 "CREATE TABLE IF NOT EXISTS {} (id TEXT PRIMARY KEY, embedding vector({}));",
99 self.table, dim
100 );
101 client.batch_execute(&sql).map_err(from_pg)?;
102 self.dim = Some(dim);
103 Ok(())
104 }
105 }
106
107 pub fn ensure_index(&self, lists: u32) -> Result<(), BackendError> {
116 #[cfg(not(feature = "live-db"))]
117 {
118 let _ = lists;
119 return Err(adapter_err(
120 "live-db feature required to connect to PostgreSQL",
121 ));
122 }
123 #[cfg(feature = "live-db")]
124 {
125 use crate::errors::from_pg;
126 let effective_lists = if lists == 0 { 100 } else { lists };
127 let mut client = (self.factory)().map_err(from_pg)?;
128 let sql = format!(
129 "CREATE INDEX IF NOT EXISTS {table}_embedding_idx \
130 ON {table} USING ivfflat (embedding vector_cosine_ops) \
131 WITH (lists = {lists});",
132 table = self.table,
133 lists = effective_lists
134 );
135 client.batch_execute(&sql).map_err(from_pg)
136 }
137 }
138
139 pub fn table(&self) -> &str {
141 &self.table
142 }
143}
144
145impl SearchBackend for PgvectorAdapter {
146 fn ingest(&mut self, vectors: &[(VectorId, Vec<f32>)]) -> Result<(), BackendError> {
147 if vectors.is_empty() {
148 return Ok(());
149 }
150 if let Some(expected) = self.dim {
153 for (_, v) in vectors {
154 if v.len() != expected {
155 return Err(BackendError::Adapter(Arc::from(format!(
156 "dimension mismatch: expected {expected}, got {}",
157 v.len()
158 ))));
159 }
160 }
161 }
162 let encoded: Vec<(&VectorId, String)> = vectors
164 .iter()
165 .map(|(id, v)| encode_vector(v).map(|enc| (id, enc)))
166 .collect::<Result<_, _>>()?;
167
168 #[cfg(not(feature = "live-db"))]
169 {
170 let _ = encoded;
171 return Err(adapter_err(
172 "live-db feature required to connect to PostgreSQL",
173 ));
174 }
175 #[cfg(feature = "live-db")]
176 {
177 use crate::errors::from_pg;
178 let mut client = (self.factory)().map_err(from_pg)?;
179 let sql = format!(
180 "INSERT INTO {table} (id, embedding) VALUES ($1, $2::vector) \
181 ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;",
182 table = self.table
183 );
184 for (id, enc) in &encoded {
185 client
186 .execute(sql.as_str(), &[&id.as_ref(), enc])
187 .map_err(from_pg)?;
188 }
189 Ok(())
190 }
191 }
192
193 fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<SearchResult>, BackendError> {
194 if top_k == 0 {
195 return Err(BackendError::InvalidTopK);
196 }
197 let encoded = encode_vector(query)?;
198
199 #[cfg(not(feature = "live-db"))]
200 {
201 let _ = encoded;
202 return Err(adapter_err(
203 "live-db feature required to connect to PostgreSQL",
204 ));
205 }
206 #[cfg(feature = "live-db")]
207 {
208 use crate::errors::from_pg;
209 let sql = format!(
210 "SELECT id, 1 - (embedding <=> $1::vector) AS score \
211 FROM {table} \
212 ORDER BY embedding <=> $1::vector \
213 LIMIT $2;",
214 table = self.table
215 );
216 let mut client = (self.factory)().map_err(from_pg)?;
217 let top_k_i64 = i64::try_from(top_k).unwrap_or(i64::MAX);
218 let rows = client
219 .query(sql.as_str(), &[&encoded, &top_k_i64])
220 .map_err(from_pg)?;
221 let mut results = Vec::with_capacity(rows.len());
222 for row in rows {
223 let id: String = row.get(0);
224 let score: f64 = row.get(1);
225 #[allow(clippy::cast_possible_truncation)]
226 results.push(SearchResult {
227 vector_id: Arc::from(id.as_str()),
228 score: score as f32,
229 });
230 }
231 Ok(results)
232 }
233 }
234
235 fn remove(&mut self, vector_ids: &[VectorId]) -> Result<(), BackendError> {
236 if vector_ids.is_empty() {
237 return Ok(());
238 }
239 #[cfg(not(feature = "live-db"))]
240 {
241 return Err(adapter_err(
242 "live-db feature required to connect to PostgreSQL",
243 ));
244 }
245 #[cfg(feature = "live-db")]
246 {
247 use crate::errors::from_pg;
248 let mut client = (self.factory)().map_err(from_pg)?;
249 let sql = format!("DELETE FROM {table} WHERE id = $1;", table = self.table);
250 for id in vector_ids {
251 client
252 .execute(sql.as_str(), &[&id.as_ref()])
253 .map_err(from_pg)?;
254 }
255 Ok(())
256 }
257 }
258
259 fn len(&self) -> usize {
260 #[cfg(not(feature = "live-db"))]
261 {
262 0
263 }
264 #[cfg(feature = "live-db")]
265 {
266 let Ok(mut client) = (self.factory)() else {
267 return 0;
268 };
269 let sql = format!("SELECT COUNT(*) FROM {};", self.table);
270 let Ok(row) = client.query_one(sql.as_str(), &[]) else {
271 return 0;
272 };
273 let count: i64 = row.get(0);
274 usize::try_from(count).unwrap_or(0)
275 }
276 }
277
278 fn dim(&self) -> Option<usize> {
279 self.dim
280 }
281}