vein_database/vector/
lancedb.rs1use anyhow::{bail, Context, Result};
2use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray};
3use arrow_schema::{DataType, Field, Schema};
4use futures_util::StreamExt;
5use lancedb::{index::Index, query::ExecutableQuery, query::QueryBase, Table};
6use std::sync::Arc;
7use std::path::PathBuf;
8use crate::paths::Paths;
9
/// Minimum number of rows the table must hold before `LanceDb::post`
/// attempts to build the index on the `vector` column.
const MIN_ROWS_FOR_INDEX: usize = 300;
11
/// A stored record as returned by a similarity query: its id and text content.
///
/// The embedding vector itself is not part of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entry {
    /// Identifier stored in the table's `id` column.
    pub id: String,
    /// Text stored in the table's `content` column.
    pub content: String,
}
17
18fn get_db_path() -> PathBuf {
19 Paths::get_insights_db()
20}
21
22fn validate_id(id: &str) -> Result<()> {
23 if id.is_empty() || id.len() > 100 {
24 bail!("Invalid id: must be non-empty and <= 100 characters");
25 }
26 if id.contains('\'') || id.contains(';') || id.contains('"') {
27 bail!("Invalid id: contains forbidden characters");
28 }
29 Ok(())
30}
31
32pub struct LanceDb {
33 table: Table,
34 indexed: bool,
35 vector_dim: usize,
36}
37
38impl LanceDb {
39 pub async fn new(table_name: &str, vector_dim: usize) -> Result<Self> {
40 let db_path = get_db_path();
41 let db = lancedb::connect(db_path.to_str().context("Invalid DB path")?)
42 .execute()
43 .await?;
44
45 let schema = Schema::new(vec![
46 Field::new("id", DataType::Utf8, false),
47 Field::new("content", DataType::Utf8, false),
48 Field::new(
49 "vector",
50 DataType::FixedSizeList(
51 Arc::new(Field::new("item", DataType::Float32, false)),
52 vector_dim as i32,
53 ),
54 false,
55 ),
56 ]);
57
58 let schema_ref = Arc::new(schema);
59
60 let names = db.table_names().execute().await?;
61 let (table, indexed) = if names.contains(&table_name.to_string()) {
62 let tbl = db.open_table(table_name).execute().await?;
63 let existing_schema = tbl.schema().await?;
64
65 if let Ok(vector_field) = existing_schema.field_with_name("vector")
67 && let DataType::FixedSizeList(_, existing_dim) = vector_field.data_type()
68 && *existing_dim as usize != vector_dim {
69 bail!(
70 "Table '{}' exists with vector dimension {} but current embedding model produces dimension {}. \
71 Please delete the table and recreate it to switch models.",
72 table_name, existing_dim, vector_dim
73 );
74 }
75
76 let indices = tbl.list_indices().await?;
77 let indexed = !indices.is_empty();
78 (tbl, indexed)
79 } else {
80 let tbl = db
81 .create_table(table_name, vec![RecordBatch::new_empty(schema_ref.clone())])
82 .execute()
83 .await?;
84 (tbl, false)
85 };
86
87 Ok(Self { table, indexed, vector_dim })
88 }
89
90 pub async fn post(&self, id: &str, content: &str, vector: Vec<f32>) -> Result<String> {
91 if vector.len() != self.vector_dim {
92 bail!("vector dimension must be {}, got {}", self.vector_dim, vector.len());
93 }
94 validate_id(id)?;
95
96 if self.exists_by_content(content).await? {
97 return Ok(id.to_string());
98 }
99
100 let schema = self.table.schema().await?;
101
102 let vector_array = FixedSizeListArray::try_new(
103 Arc::new(Field::new("item", DataType::Float32, true)),
104 self.vector_dim as i32,
105 Arc::new(Float32Array::from(vector)),
106 None,
107 )?;
108
109 let batch = RecordBatch::try_new(
110 schema,
111 vec![
112 Arc::new(StringArray::from(vec![id.to_string()])),
113 Arc::new(StringArray::from(vec![content])),
114 Arc::new(vector_array),
115 ],
116 )?;
117
118 self.table.add(vec![batch]).execute().await?;
119
120 if !self.indexed {
121 let count = self.table.count_rows(None).await?;
122 if count >= MIN_ROWS_FOR_INDEX {
123 self.table
124 .create_index(&["vector"], Index::Auto)
125 .execute()
126 .await?;
127 }
128 }
129
130 Ok(id.to_string())
131 }
132
133 pub async fn get(&self, query_vector: &[f32], limit: usize) -> Result<Vec<Entry>> {
134 if query_vector.len() != self.vector_dim {
135 bail!("query vector dimension must be {}, got {}", self.vector_dim, query_vector.len());
136 }
137
138 let stream = self
139 .table
140 .query()
141 .nearest_to(query_vector)?
142 .limit(limit)
143 .execute()
144 .await?;
145
146 let mut entries = Vec::new();
147 let mut stream = stream;
148
149 while let Some(batch_result) = stream.next().await {
150 let batch: RecordBatch = batch_result?;
151 let id_array = batch.column(0);
152 let content_array = batch.column(1);
153
154 for i in 0..batch.num_rows() {
155 let id = id_array
156 .as_any()
157 .downcast_ref::<StringArray>()
158 .map(|arr| arr.value(i).to_string())
159 .unwrap_or_default();
160 let content = content_array
161 .as_any()
162 .downcast_ref::<StringArray>()
163 .map(|arr| arr.value(i).to_string())
164 .unwrap_or_default();
165
166 entries.push(Entry { id, content });
167 }
168 }
169
170 Ok(entries)
171 }
172
173 pub async fn exists_by_content(&self, content: &str) -> Result<bool> {
174 let escaped = content.replace('\'', "''");
175 let count = self.table.count_rows(Some(format!("content = '{}'", escaped))).await?;
176 Ok(count > 0)
177 }
178
179 pub async fn patch(&self, id: &str, new_content: &str, new_vector: Vec<f32>) -> Result<()> {
180 if new_vector.len() != self.vector_dim {
181 bail!("vector dimension must be {}, got {}", self.vector_dim, new_vector.len());
182 }
183
184 if self.exists_by_content(new_content).await? {
185 return Ok(());
186 }
187
188 self.delete(id).await?;
189 self.post(id, new_content, new_vector).await?;
190 Ok(())
191 }
192
193 pub async fn delete(&self, id: &str) -> Result<()> {
194 validate_id(id)?;
195 self.table.delete(&format!("id = '{}'", id)).await?;
196 Ok(())
197 }
198
199 pub async fn rebuild_index(&self) -> Result<()> {
200 self.table
201 .create_index(&["vector"], Index::Auto)
202 .execute()
203 .await?;
204 Ok(())
205 }
206}