// vein_database/vector/lancedb.rs

use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures_util::StreamExt;
use lancedb::{index::Index, query::ExecutableQuery, query::QueryBase, Table};

use crate::paths::Paths;

/// Row count a table must reach before `LanceDb::post` automatically builds
/// an index on the `vector` column.
///
/// NOTE(review): presumably chosen so index training has enough sample rows
/// (e.g. PQ codebooks commonly train 256 centroids) — confirm before tuning.
const MIN_ROWS_FOR_INDEX: usize = 300;

/// A row returned by [`LanceDb::get`]: its identifier and text content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Entry {
    /// Identifier the row was stored under.
    pub id: String,
    /// Text content stored alongside the row's embedding vector.
    pub content: String,
}

18fn get_db_path() -> PathBuf {
19    Paths::get_insights_db()
20}
21
22fn validate_id(id: &str) -> Result<()> {
23    if id.is_empty() || id.len() > 100 {
24        bail!("Invalid id: must be non-empty and <= 100 characters");
25    }
26    if id.contains('\'') || id.contains(';') || id.contains('"') {
27        bail!("Invalid id: contains forbidden characters");
28    }
29    Ok(())
30}
31
32pub struct LanceDb {
33    table: Table,
34    indexed: bool,
35    vector_dim: usize,
36}
37
38impl LanceDb {
39    pub async fn new(table_name: &str, vector_dim: usize) -> Result<Self> {
40        let db_path = get_db_path();
41        let db = lancedb::connect(db_path.to_str().context("Invalid DB path")?)
42            .execute()
43            .await?;
44
45        let schema = Schema::new(vec![
46            Field::new("id", DataType::Utf8, false),
47            Field::new("content", DataType::Utf8, false),
48            Field::new(
49                "vector",
50                DataType::FixedSizeList(
51                    Arc::new(Field::new("item", DataType::Float32, false)),
52                    vector_dim as i32,
53                ),
54                false,
55            ),
56        ]);
57
58        let schema_ref = Arc::new(schema);
59
60        let names = db.table_names().execute().await?;
61        let (table, indexed) = if names.contains(&table_name.to_string()) {
62            let tbl = db.open_table(table_name).execute().await?;
63            let existing_schema = tbl.schema().await?;
64
65            // Check if existing schema matches requested vector dimension
66            if let Ok(vector_field) = existing_schema.field_with_name("vector")
67                && let DataType::FixedSizeList(_, existing_dim) = vector_field.data_type()
68                && *existing_dim as usize != vector_dim {
69                    bail!(
70                        "Table '{}' exists with vector dimension {} but current embedding model produces dimension {}. \
71                         Please delete the table and recreate it to switch models.",
72                        table_name, existing_dim, vector_dim
73                    );
74            }
75
76            let indices = tbl.list_indices().await?;
77            let indexed = !indices.is_empty();
78            (tbl, indexed)
79        } else {
80            let tbl = db
81                .create_table(table_name, vec![RecordBatch::new_empty(schema_ref.clone())])
82                .execute()
83                .await?;
84            (tbl, false)
85        };
86
87        Ok(Self { table, indexed, vector_dim })
88    }
89
90    pub async fn post(&self, id: &str, content: &str, vector: Vec<f32>) -> Result<String> {
91        if vector.len() != self.vector_dim {
92            bail!("vector dimension must be {}, got {}", self.vector_dim, vector.len());
93        }
94        validate_id(id)?;
95
96        if self.exists_by_content(content).await? {
97            return Ok(id.to_string());
98        }
99
100        let schema = self.table.schema().await?;
101
102        let vector_array = FixedSizeListArray::try_new(
103            Arc::new(Field::new("item", DataType::Float32, true)),
104            self.vector_dim as i32,
105            Arc::new(Float32Array::from(vector)),
106            None,
107        )?;
108
109        let batch = RecordBatch::try_new(
110            schema,
111            vec![
112                Arc::new(StringArray::from(vec![id.to_string()])),
113                Arc::new(StringArray::from(vec![content])),
114                Arc::new(vector_array),
115            ],
116        )?;
117
118        self.table.add(vec![batch]).execute().await?;
119
120        if !self.indexed {
121            let count = self.table.count_rows(None).await?;
122            if count >= MIN_ROWS_FOR_INDEX {
123                self.table
124                    .create_index(&["vector"], Index::Auto)
125                    .execute()
126                    .await?;
127            }
128        }
129
130        Ok(id.to_string())
131    }
132
133    pub async fn get(&self, query_vector: &[f32], limit: usize) -> Result<Vec<Entry>> {
134        if query_vector.len() != self.vector_dim {
135            bail!("query vector dimension must be {}, got {}", self.vector_dim, query_vector.len());
136        }
137
138        let stream = self
139            .table
140            .query()
141            .nearest_to(query_vector)?
142            .limit(limit)
143            .execute()
144            .await?;
145
146        let mut entries = Vec::new();
147        let mut stream = stream;
148
149        while let Some(batch_result) = stream.next().await {
150            let batch: RecordBatch = batch_result?;
151            let id_array = batch.column(0);
152            let content_array = batch.column(1);
153
154            for i in 0..batch.num_rows() {
155                let id = id_array
156                    .as_any()
157                    .downcast_ref::<StringArray>()
158                    .map(|arr| arr.value(i).to_string())
159                    .unwrap_or_default();
160                let content = content_array
161                    .as_any()
162                    .downcast_ref::<StringArray>()
163                    .map(|arr| arr.value(i).to_string())
164                    .unwrap_or_default();
165
166                entries.push(Entry { id, content });
167            }
168        }
169
170        Ok(entries)
171    }
172
173    pub async fn exists_by_content(&self, content: &str) -> Result<bool> {
174        let escaped = content.replace('\'', "''");
175        let count = self.table.count_rows(Some(format!("content = '{}'", escaped))).await?;
176        Ok(count > 0)
177    }
178
179    pub async fn patch(&self, id: &str, new_content: &str, new_vector: Vec<f32>) -> Result<()> {
180        if new_vector.len() != self.vector_dim {
181            bail!("vector dimension must be {}, got {}", self.vector_dim, new_vector.len());
182        }
183
184        if self.exists_by_content(new_content).await? {
185            return Ok(());
186        }
187
188        self.delete(id).await?;
189        self.post(id, new_content, new_vector).await?;
190        Ok(())
191    }
192
193    pub async fn delete(&self, id: &str) -> Result<()> {
194        validate_id(id)?;
195        self.table.delete(&format!("id = '{}'", id)).await?;
196        Ok(())
197    }
198
199    pub async fn rebuild_index(&self) -> Result<()> {
200        self.table
201            .create_index(&["vector"], Index::Auto)
202            .execute()
203            .await?;
204        Ok(())
205    }
206}